From bb664d9bbfe8fe004fdc47c55653c5c904a6bed1 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Sat, 28 Feb 2026 15:01:20 +0800 Subject: [PATCH] feat(sync): full code sync from release --- .gitignore | 6 +- AGENTS.md | 105 + Dockerfile | 13 +- Makefile | 9 +- README.md | 28 + README_CN.md | 38 + backend/.gosec.json | 2 +- backend/Makefile | 11 +- backend/cmd/server/VERSION | 2 +- backend/cmd/server/wire.go | 57 +- backend/cmd/server/wire_gen.go | 71 +- backend/cmd/server/wire_gen_test.go | 81 + backend/ent/account.go | 32 +- backend/ent/account/account.go | 16 + backend/ent/account/where.go | 135 + backend/ent/account_create.go | 156 + backend/ent/account_update.go | 104 + backend/ent/client.go | 171 +- backend/ent/ent.go | 2 + backend/ent/group.go | 13 +- backend/ent/group/group.go | 10 + backend/ent/group/where.go | 45 + backend/ent/group_create.go | 85 + backend/ent/group_update.go | 54 + backend/ent/hook/hook.go | 12 + backend/ent/idempotencyrecord.go | 228 + .../idempotencyrecord/idempotencyrecord.go | 148 + backend/ent/idempotencyrecord/where.go | 755 ++++ backend/ent/idempotencyrecord_create.go | 1132 +++++ backend/ent/idempotencyrecord_delete.go | 88 + backend/ent/idempotencyrecord_query.go | 564 +++ backend/ent/idempotencyrecord_update.go | 676 +++ backend/ent/intercept/intercept.go | 30 + backend/ent/migrate/schema.go | 73 +- backend/ent/mutation.go | 1481 +++++- backend/ent/predicate/predicate.go | 3 + backend/ent/runtime/runtime.go | 60 +- backend/ent/schema/account.go | 18 +- backend/ent/schema/group.go | 4 + backend/ent/schema/usage_log.go | 2 + backend/ent/schema/user.go | 6 + backend/ent/schema/user_subscription.go | 2 + backend/ent/tx.go | 3 + backend/ent/user.go | 24 +- backend/ent/user/user.go | 20 + backend/ent/user/where.go | 90 + backend/ent/user_create.go | 170 + backend/ent/user_update.go | 108 + backend/go.mod | 26 +- backend/go.sum | 56 + backend/internal/config/config.go | 271 +- backend/internal/config/config_test.go | 267 ++ 
backend/internal/domain/constants.go | 3 + backend/internal/domain/constants_test.go | 24 + .../internal/handler/admin/account_handler.go | 28 + .../handler/admin/dashboard_handler.go | 37 +- .../dashboard_handler_request_type_test.go | 132 + .../handler/admin/data_management_handler.go | 523 +++ .../admin/data_management_handler_test.go | 78 + .../internal/handler/admin/group_handler.go | 6 + .../handler/admin/openai_oauth_handler.go | 17 +- .../internal/handler/admin/ops_ws_handler.go | 38 +- .../internal/handler/admin/setting_handler.go | 387 +- .../admin/usage_cleanup_handler_test.go | 86 + .../internal/handler/admin/usage_handler.go | 54 +- .../admin/usage_handler_request_type_test.go | 117 + .../internal/handler/admin/user_handler.go | 52 +- backend/internal/handler/auth_handler.go | 5 +- backend/internal/handler/dto/mappers.go | 24 +- .../handler/dto/mappers_usage_test.go | 73 + backend/internal/handler/dto/settings.go | 40 + backend/internal/handler/dto/types.go | 18 +- backend/internal/handler/failover_loop.go | 32 +- backend/internal/handler/gateway_handler.go | 135 +- ...eway_handler_warmup_intercept_unit_test.go | 7 + backend/internal/handler/gateway_helper.go | 50 +- .../handler/gateway_helper_fastpath_test.go | 8 + .../handler/gateway_helper_hotpath_test.go | 56 +- .../internal/handler/gemini_v1beta_handler.go | 14 +- backend/internal/handler/handler.go | 2 + .../handler/openai_gateway_handler.go | 833 +++- .../handler/openai_gateway_handler_test.go | 447 ++ backend/internal/handler/ops_error_logger.go | 40 +- .../internal/handler/ops_error_logger_test.go | 41 + backend/internal/handler/setting_handler.go | 1 + .../internal/handler/sora_client_handler.go | 979 ++++ .../handler/sora_client_handler_test.go | 3135 +++++++++++++ .../internal/handler/sora_gateway_handler.go | 12 +- .../handler/sora_gateway_handler_test.go | 4 +- backend/internal/handler/usage_handler.go | 13 +- .../usage_handler_request_type_test.go | 80 + 
.../handler/usage_record_submit_task_test.go | 48 + backend/internal/handler/wire.go | 5 + .../internal/pkg/antigravity/claude_types.go | 3 + .../pkg/antigravity/claude_types_test.go | 26 + .../internal/pkg/antigravity/gemini_types.go | 2 +- backend/internal/pkg/antigravity/oauth.go | 3 +- .../internal/pkg/antigravity/oauth_test.go | 18 +- backend/internal/pkg/errors/errors_test.go | 15 + backend/internal/pkg/errors/http.go | 14 +- backend/internal/pkg/geminicli/constants.go | 2 +- backend/internal/pkg/httpclient/pool.go | 53 +- backend/internal/pkg/httpclient/pool_test.go | 115 + backend/internal/pkg/httputil/body.go | 37 + backend/internal/pkg/ip/ip.go | 72 +- backend/internal/pkg/ip/ip_test.go | 21 + backend/internal/pkg/logger/logger.go | 77 +- backend/internal/pkg/logger/slog_handler.go | 9 +- .../internal/pkg/logger/stdlog_bridge_test.go | 1 + backend/internal/pkg/openai/oauth.go | 48 +- backend/internal/pkg/openai/oauth_test.go | 39 + .../internal/pkg/response/response_test.go | 8 +- backend/internal/pkg/tlsfingerprint/dialer.go | 16 +- .../pkg/usagestats/usage_log_types.go | 1 + backend/internal/repository/account_repo.go | 157 +- .../account_repo_integration_test.go | 32 + backend/internal/repository/api_key_repo.go | 31 +- .../internal/repository/concurrency_cache.go | 37 + .../gateway_cache_integration_test.go | 1 - backend/internal/repository/group_repo.go | 128 +- .../repository/group_repo_integration_test.go | 75 + .../idempotency_repo_integration_test.go | 1 - .../internal/repository/migrations_runner.go | 129 +- .../migrations_runner_checksum_test.go | 36 + .../migrations_runner_extra_test.go | 368 ++ .../repository/migrations_runner_notx_test.go | 164 + .../migrations_schema_integration_test.go | 2 + .../repository/openai_oauth_service.go | 42 +- .../repository/openai_oauth_service_test.go | 68 +- .../internal/repository/ops_repo_dashboard.go | 255 +- .../ops_repo_dashboard_timeout_test.go | 22 + .../repository/sora_generation_repo.go | 419 ++ 
.../internal/repository/usage_cleanup_repo.go | 7 +- .../repository/usage_cleanup_repo_test.go | 32 + backend/internal/repository/usage_log_repo.go | 250 +- .../usage_log_repo_integration_test.go | 70 +- .../usage_log_repo_request_type_test.go | 327 ++ .../repository/user_group_rate_repo.go | 86 +- backend/internal/repository/user_repo.go | 62 + backend/internal/server/api_contract_test.go | 38 +- .../server/middleware/api_key_auth.go | 2 +- .../server/middleware/api_key_auth_google.go | 24 +- .../middleware/api_key_auth_google_test.go | 170 + .../server/middleware/security_headers.go | 16 + .../middleware/security_headers_test.go | 20 + backend/internal/server/router.go | 1 + backend/internal/server/routes/admin.go | 36 + backend/internal/server/routes/gateway.go | 2 + backend/internal/server/routes/sora_client.go | 33 + backend/internal/service/account.go | 262 +- .../account_openai_passthrough_test.go | 158 + backend/internal/service/account_service.go | 51 +- .../internal/service/account_test_service.go | 95 +- .../internal/service/account_usage_service.go | 86 +- .../internal/service/account_wildcard_test.go | 69 + backend/internal/service/admin_service.go | 273 +- .../service/admin_service_bulk_update_test.go | 67 +- .../service/admin_service_list_users_test.go | 106 + .../service/antigravity_gateway_service.go | 3 +- backend/internal/service/api_key.go | 19 +- .../service/api_key_auth_cache_impl.go | 1 + backend/internal/service/api_key_service.go | 15 + backend/internal/service/auth_service.go | 11 + .../auth_service_turnstile_register_test.go | 96 + .../internal/service/billing_cache_service.go | 85 +- ...billing_cache_service_singleflight_test.go | 115 + .../service/billing_cache_service_test.go | 13 + .../service/billing_service_image_test.go | 2 +- .../internal/service/claude_code_validator.go | 2 +- .../internal/service/concurrency_service.go | 52 +- .../service/concurrency_service_test.go | 53 +- backend/internal/service/dashboard_service.go | 8 +- 
.../internal/service/data_management_grpc.go | 252 ++ .../service/data_management_grpc_test.go | 36 + .../service/data_management_service.go | 99 + .../service/data_management_service_test.go | 37 + backend/internal/service/domain_constants.go | 22 + ...teway_anthropic_apikey_passthrough_test.go | 4 +- backend/internal/service/gateway_beta_test.go | 64 + .../service/gateway_multiplatform_test.go | 8 + backend/internal/service/gateway_service.go | 989 ++++- ...ay_service_selection_failure_stats_test.go | 141 + ...gateway_service_sora_model_support_test.go | 79 + .../gateway_service_sora_scheduling_test.go | 89 + .../service/gateway_waiting_queue_test.go | 12 +- .../service/gemini_messages_compat_service.go | 56 +- backend/internal/service/group.go | 3 + backend/internal/service/model_rate_limit.go | 4 +- backend/internal/service/oauth_service.go | 2 +- .../internal/service/oauth_service_test.go | 20 +- .../service/openai_account_scheduler.go | 909 ++++ ...openai_account_scheduler_benchmark_test.go | 83 + .../service/openai_account_scheduler_test.go | 841 ++++ .../service/openai_client_transport.go | 71 + .../service/openai_client_transport_test.go | 107 + .../service/openai_gateway_service.go | 1267 +++++- .../openai_gateway_service_hotpath_test.go | 16 + .../service/openai_gateway_service_test.go | 50 + ...openai_json_optimization_benchmark_test.go | 357 ++ .../service/openai_oauth_passthrough_test.go | 2 +- .../internal/service/openai_oauth_service.go | 194 +- .../openai_oauth_service_auth_url_test.go | 67 + .../openai_oauth_service_sora_session_test.go | 106 +- .../openai_oauth_service_state_test.go | 6 +- .../service/openai_previous_response_id.go | 37 + .../openai_previous_response_id_test.go | 34 + .../internal/service/openai_sticky_compat.go | 214 + .../service/openai_sticky_compat_test.go | 96 + .../service/openai_tool_continuation.go | 283 +- .../internal/service/openai_tool_corrector.go | 344 +- .../service/openai_tool_corrector_test.go | 9 + 
.../service/openai_ws_account_sticky_test.go | 190 + backend/internal/service/openai_ws_client.go | 285 ++ .../internal/service/openai_ws_client_test.go | 112 + .../service/openai_ws_fallback_test.go | 251 ++ .../internal/service/openai_ws_forwarder.go | 3955 +++++++++++++++++ .../openai_ws_forwarder_benchmark_test.go | 127 + ..._ws_forwarder_hotpath_optimization_test.go | 73 + ...penai_ws_forwarder_ingress_session_test.go | 2483 +++++++++++ .../openai_ws_forwarder_ingress_test.go | 714 +++ .../openai_ws_forwarder_retry_payload_test.go | 50 + .../openai_ws_forwarder_success_test.go | 1306 ++++++ backend/internal/service/openai_ws_pool.go | 1706 +++++++ .../service/openai_ws_pool_benchmark_test.go | 58 + .../internal/service/openai_ws_pool_test.go | 1709 +++++++ .../openai_ws_protocol_forward_test.go | 1218 +++++ .../service/openai_ws_protocol_resolver.go | 117 + .../openai_ws_protocol_resolver_test.go | 203 + .../internal/service/openai_ws_state_store.go | 440 ++ .../service/openai_ws_state_store_test.go | 235 + backend/internal/service/ops_retry.go | 4 +- .../service/ops_retry_context_test.go | 47 + .../internal/service/ops_upstream_context.go | 5 + backend/internal/service/ratelimit_service.go | 232 +- backend/internal/service/request_metadata.go | 216 + .../internal/service/request_metadata_test.go | 119 + .../service/response_header_filter.go | 13 + .../service/scheduler_snapshot_service.go | 75 +- backend/internal/service/setting_service.go | 624 ++- backend/internal/service/settings_view.go | 42 + backend/internal/service/sora_client.go | 1 + .../internal/service/sora_gateway_service.go | 94 +- .../service/sora_gateway_service_test.go | 32 + backend/internal/service/sora_generation.go | 63 + .../service/sora_generation_service.go | 332 ++ .../service/sora_generation_service_test.go | 875 ++++ .../internal/service/sora_media_storage.go | 58 + backend/internal/service/sora_models.go | 215 + .../internal/service/sora_quota_service.go | 257 ++ 
.../service/sora_quota_service_test.go | 492 ++ backend/internal/service/sora_s3_storage.go | 392 ++ .../internal/service/sora_s3_storage_test.go | 263 ++ backend/internal/service/sora_sdk_client.go | 224 +- .../service/sora_upstream_forwarder.go | 149 + backend/internal/service/usage_cleanup.go | 1 + .../internal/service/usage_cleanup_service.go | 13 + .../service/usage_cleanup_service_test.go | 47 + backend/internal/service/usage_log.go | 107 +- backend/internal/service/usage_log_test.go | 112 + backend/internal/service/user.go | 4 + backend/internal/service/user_service_test.go | 22 +- backend/internal/service/wire.go | 1 + backend/internal/testutil/stubs.go | 7 + backend/internal/util/logredact/redact.go | 79 +- .../internal/util/logredact/redact_test.go | 45 + .../util/responseheaders/responseheaders.go | 28 +- .../responseheaders/responseheaders_test.go | 4 +- ..._gemini31_flash_image_to_model_mapping.sql | 61 +- .../060_add_usage_log_openai_ws_mode.sql | 2 + .../061_add_usage_log_request_type.sql | 29 + ...duler_and_usage_composite_indexes_notx.sql | 15 + .../migrations/063_add_sora_client_tables.sql | 56 + backend/migrations/README.md | 20 + deploy/.env.example | 10 +- deploy/DATAMANAGEMENTD_CN.md | 78 + deploy/README.md | 13 +- deploy/config.example.yaml | 97 +- deploy/docker-compose.override.yml.example | 13 + deploy/install-datamanagementd.sh | 123 + deploy/sub2api-datamanagementd.service | 22 + frontend/src/api/__tests__/sora.spec.ts | 80 + frontend/src/api/admin/accounts.ts | 17 + frontend/src/api/admin/dashboard.ts | 5 +- frontend/src/api/admin/dataManagement.ts | 332 ++ frontend/src/api/admin/index.ts | 8 +- frontend/src/api/admin/settings.ts | 148 +- frontend/src/api/admin/usage.ts | 5 +- frontend/src/api/sora.ts | 307 ++ .../account/AccountTodayStatsCell.vue | 59 +- .../components/account/AccountUsageCell.vue | 4 +- .../account/BulkEditAccountModal.vue | 16 +- .../components/account/CreateAccountModal.vue | 138 +- 
.../components/account/EditAccountModal.vue | 81 +- .../account/OAuthAuthorizationFlow.vue | 113 +- .../__tests__/AccountUsageCell.spec.ts | 70 + .../__tests__/BulkEditAccountModal.spec.ts | 72 + .../admin/usage/UsageCleanupDialog.vue | 9 +- .../components/admin/usage/UsageFilters.vue | 11 +- .../src/components/admin/usage/UsageTable.vue | 20 +- .../components/admin/user/UserEditModal.vue | 14 +- frontend/src/components/layout/AppSidebar.vue | 55 +- .../components/sora/SoraDownloadDialog.vue | 217 + .../src/components/sora/SoraGeneratePage.vue | 430 ++ .../src/components/sora/SoraLibraryPage.vue | 576 +++ .../src/components/sora/SoraMediaPreview.vue | 282 ++ .../components/sora/SoraNoStorageWarning.vue | 39 + .../src/components/sora/SoraProgressCard.vue | 609 +++ .../src/components/sora/SoraPromptBar.vue | 738 +++ frontend/src/components/sora/SoraQuotaBar.vue | 87 + .../__tests__/useModelWhitelist.spec.ts | 18 + .../__tests__/useOpenAIOAuth.spec.ts | 49 + frontend/src/composables/useModelWhitelist.ts | 1 + frontend/src/composables/useOpenAIOAuth.ts | 5 + frontend/src/i18n/locales/en.ts | 367 +- frontend/src/i18n/locales/zh.ts | 369 +- frontend/src/main.ts | 11 + frontend/src/router/index.ts | 24 + frontend/src/stores/app.ts | 1 + frontend/src/types/index.ts | 13 + .../src/utils/__tests__/openaiWsMode.spec.ts | 55 + .../utils/__tests__/soraTokenParser.spec.ts | 90 + frontend/src/utils/openaiWsMode.ts | 61 + frontend/src/utils/soraTokenParser.ts | 308 ++ frontend/src/utils/usageRequestType.ts | 33 + frontend/src/views/admin/AccountsView.vue | 97 +- .../src/views/admin/DataManagementView.vue | 514 +++ frontend/src/views/admin/GroupsView.vue | 46 +- frontend/src/views/admin/SettingsView.vue | 27 + frontend/src/views/admin/UsageView.vue | 39 +- frontend/src/views/user/SoraView.vue | 369 ++ frontend/src/views/user/UsageView.vue | 47 +- frontend/vitest.config.ts | 2 + openspec/config.yaml | 20 + openspec/project.md | 31 + .../perf/openai_responses_ws_v2_compare_k6.js | 
167 + tools/perf/openai_ws_pooling_compare_k6.js | 123 + tools/perf/openai_ws_v2_perf_suite_k6.js | 216 + tools/sora-test | 192 + 338 files changed, 54513 insertions(+), 2011 deletions(-) create mode 100644 AGENTS.md create mode 100644 backend/cmd/server/wire_gen_test.go create mode 100644 backend/ent/idempotencyrecord.go create mode 100644 backend/ent/idempotencyrecord/idempotencyrecord.go create mode 100644 backend/ent/idempotencyrecord/where.go create mode 100644 backend/ent/idempotencyrecord_create.go create mode 100644 backend/ent/idempotencyrecord_delete.go create mode 100644 backend/ent/idempotencyrecord_query.go create mode 100644 backend/ent/idempotencyrecord_update.go create mode 100644 backend/internal/domain/constants_test.go create mode 100644 backend/internal/handler/admin/dashboard_handler_request_type_test.go create mode 100644 backend/internal/handler/admin/data_management_handler.go create mode 100644 backend/internal/handler/admin/data_management_handler_test.go create mode 100644 backend/internal/handler/admin/usage_handler_request_type_test.go create mode 100644 backend/internal/handler/dto/mappers_usage_test.go create mode 100644 backend/internal/handler/sora_client_handler.go create mode 100644 backend/internal/handler/sora_client_handler_test.go create mode 100644 backend/internal/handler/usage_handler_request_type_test.go create mode 100644 backend/internal/pkg/antigravity/claude_types_test.go create mode 100644 backend/internal/pkg/httpclient/pool_test.go create mode 100644 backend/internal/pkg/httputil/body.go create mode 100644 backend/internal/repository/migrations_runner_checksum_test.go create mode 100644 backend/internal/repository/migrations_runner_extra_test.go create mode 100644 backend/internal/repository/migrations_runner_notx_test.go create mode 100644 backend/internal/repository/ops_repo_dashboard_timeout_test.go create mode 100644 backend/internal/repository/sora_generation_repo.go create mode 100644 
backend/internal/repository/usage_log_repo_request_type_test.go create mode 100644 backend/internal/server/routes/sora_client.go create mode 100644 backend/internal/service/admin_service_list_users_test.go create mode 100644 backend/internal/service/auth_service_turnstile_register_test.go create mode 100644 backend/internal/service/billing_cache_service_singleflight_test.go create mode 100644 backend/internal/service/data_management_grpc.go create mode 100644 backend/internal/service/data_management_grpc_test.go create mode 100644 backend/internal/service/data_management_service.go create mode 100644 backend/internal/service/data_management_service_test.go create mode 100644 backend/internal/service/gateway_service_selection_failure_stats_test.go create mode 100644 backend/internal/service/gateway_service_sora_model_support_test.go create mode 100644 backend/internal/service/gateway_service_sora_scheduling_test.go create mode 100644 backend/internal/service/openai_account_scheduler.go create mode 100644 backend/internal/service/openai_account_scheduler_benchmark_test.go create mode 100644 backend/internal/service/openai_account_scheduler_test.go create mode 100644 backend/internal/service/openai_client_transport.go create mode 100644 backend/internal/service/openai_client_transport_test.go create mode 100644 backend/internal/service/openai_json_optimization_benchmark_test.go create mode 100644 backend/internal/service/openai_oauth_service_auth_url_test.go create mode 100644 backend/internal/service/openai_previous_response_id.go create mode 100644 backend/internal/service/openai_previous_response_id_test.go create mode 100644 backend/internal/service/openai_sticky_compat.go create mode 100644 backend/internal/service/openai_sticky_compat_test.go create mode 100644 backend/internal/service/openai_ws_account_sticky_test.go create mode 100644 backend/internal/service/openai_ws_client.go create mode 100644 backend/internal/service/openai_ws_client_test.go create mode 
100644 backend/internal/service/openai_ws_fallback_test.go create mode 100644 backend/internal/service/openai_ws_forwarder.go create mode 100644 backend/internal/service/openai_ws_forwarder_benchmark_test.go create mode 100644 backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go create mode 100644 backend/internal/service/openai_ws_forwarder_ingress_session_test.go create mode 100644 backend/internal/service/openai_ws_forwarder_ingress_test.go create mode 100644 backend/internal/service/openai_ws_forwarder_retry_payload_test.go create mode 100644 backend/internal/service/openai_ws_forwarder_success_test.go create mode 100644 backend/internal/service/openai_ws_pool.go create mode 100644 backend/internal/service/openai_ws_pool_benchmark_test.go create mode 100644 backend/internal/service/openai_ws_pool_test.go create mode 100644 backend/internal/service/openai_ws_protocol_forward_test.go create mode 100644 backend/internal/service/openai_ws_protocol_resolver.go create mode 100644 backend/internal/service/openai_ws_protocol_resolver_test.go create mode 100644 backend/internal/service/openai_ws_state_store.go create mode 100644 backend/internal/service/openai_ws_state_store_test.go create mode 100644 backend/internal/service/ops_retry_context_test.go create mode 100644 backend/internal/service/request_metadata.go create mode 100644 backend/internal/service/request_metadata_test.go create mode 100644 backend/internal/service/response_header_filter.go create mode 100644 backend/internal/service/sora_generation.go create mode 100644 backend/internal/service/sora_generation_service.go create mode 100644 backend/internal/service/sora_generation_service_test.go create mode 100644 backend/internal/service/sora_quota_service.go create mode 100644 backend/internal/service/sora_quota_service_test.go create mode 100644 backend/internal/service/sora_s3_storage.go create mode 100644 backend/internal/service/sora_s3_storage_test.go create mode 100644 
backend/internal/service/sora_upstream_forwarder.go create mode 100644 backend/internal/service/usage_log_test.go create mode 100644 backend/migrations/060_add_usage_log_openai_ws_mode.sql create mode 100644 backend/migrations/061_add_usage_log_request_type.sql create mode 100644 backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql create mode 100644 backend/migrations/063_add_sora_client_tables.sql create mode 100644 deploy/DATAMANAGEMENTD_CN.md create mode 100755 deploy/install-datamanagementd.sh create mode 100644 deploy/sub2api-datamanagementd.service create mode 100644 frontend/src/api/__tests__/sora.spec.ts create mode 100644 frontend/src/api/admin/dataManagement.ts create mode 100644 frontend/src/api/sora.ts create mode 100644 frontend/src/components/account/__tests__/AccountUsageCell.spec.ts create mode 100644 frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts create mode 100644 frontend/src/components/sora/SoraDownloadDialog.vue create mode 100644 frontend/src/components/sora/SoraGeneratePage.vue create mode 100644 frontend/src/components/sora/SoraLibraryPage.vue create mode 100644 frontend/src/components/sora/SoraMediaPreview.vue create mode 100644 frontend/src/components/sora/SoraNoStorageWarning.vue create mode 100644 frontend/src/components/sora/SoraProgressCard.vue create mode 100644 frontend/src/components/sora/SoraPromptBar.vue create mode 100644 frontend/src/components/sora/SoraQuotaBar.vue create mode 100644 frontend/src/composables/__tests__/useModelWhitelist.spec.ts create mode 100644 frontend/src/composables/__tests__/useOpenAIOAuth.spec.ts create mode 100644 frontend/src/utils/__tests__/openaiWsMode.spec.ts create mode 100644 frontend/src/utils/__tests__/soraTokenParser.spec.ts create mode 100644 frontend/src/utils/openaiWsMode.ts create mode 100644 frontend/src/utils/soraTokenParser.ts create mode 100644 frontend/src/utils/usageRequestType.ts create mode 100644 
frontend/src/views/admin/DataManagementView.vue create mode 100644 frontend/src/views/user/SoraView.vue create mode 100644 openspec/config.yaml create mode 100644 openspec/project.md create mode 100644 tools/perf/openai_responses_ws_v2_compare_k6.js create mode 100644 tools/perf/openai_ws_pooling_compare_k6.js create mode 100644 tools/perf/openai_ws_v2_perf_suite_k6.js create mode 100755 tools/sora-test diff --git a/.gitignore b/.gitignore index 515ce84f..297c1d6f 100644 --- a/.gitignore +++ b/.gitignore @@ -116,13 +116,12 @@ backend/.installed # =================== tests CLAUDE.md -AGENTS.md .claude scripts .code-review-state -openspec/ +#openspec/ code-reviews/ -AGENTS.md +#AGENTS.md backend/cmd/server/server deploy/docker-compose.override.yml .gocache/ @@ -132,4 +131,5 @@ docs/* .codex/ frontend/coverage/ aicodex +output/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..bb5bb465 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,105 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- `backend/`: Go service. `cmd/server` is the entrypoint, `internal/` contains handlers/services/repositories/server wiring, `ent/` holds Ent schemas and generated ORM code, `migrations/` stores DB migrations, and `internal/web/dist/` is the embedded frontend build output. +- `frontend/`: Vue 3 + TypeScript app. Main folders are `src/api`, `src/components`, `src/views`, `src/stores`, `src/composables`, `src/utils`, and test files in `src/**/__tests__`. +- `deploy/`: Docker and deployment assets (`docker-compose*.yml`, `.env.example`, `config.example.yaml`). +- `openspec/`: Spec-driven change docs (`changes//{proposal,design,tasks}.md`). +- `tools/`: Utility scripts (security/perf checks). 
+ +## Build, Test, and Development Commands +```bash +make build # Build backend + frontend +make test # Backend tests + frontend lint/typecheck +cd backend && make build # Build backend binary +cd backend && make test-unit # Go unit tests +cd backend && make test-integration # Go integration tests +cd backend && make test # go test ./... + golangci-lint +cd frontend && pnpm install --frozen-lockfile +cd frontend && pnpm dev # Vite dev server +cd frontend && pnpm build # Type-check + production build +cd frontend && pnpm test:run # Vitest run +cd frontend && pnpm test:coverage # Vitest + coverage report +python3 tools/secret_scan.py # Secret scan +``` + +## Coding Style & Naming Conventions +- Go: format with `gofmt`; lint with `golangci-lint` (`backend/.golangci.yml`). +- Respect layering: `internal/service` and `internal/handler` must not import `internal/repository`, `gorm`, or `redis` directly (enforced by depguard). +- Frontend: Vue SFC + TypeScript, 2-space indentation, ESLint rules from `frontend/.eslintrc.cjs`. +- Naming: components use `PascalCase.vue`, composables use `useXxx.ts`, Go tests use `*_test.go`, frontend tests use `*.spec.ts`. + +## Go & Frontend Development Standards +- Control branch complexity: `if` nesting must not exceed 3 levels. Refactor with guard clauses, early returns, helper functions, or strategy maps when deeper logic appears. +- JSON hot-path rule: for read-only/partial-field extraction, prefer `gjson` over full `encoding/json` struct unmarshal to reduce allocations and improve latency. +- Exception rule: if full schema validation or typed writes are required, `encoding/json` is allowed, but PR must explain why `gjson` is not suitable. + +### Go Performance Rules +- Optimization workflow rule: benchmark/profile first, then optimize. Use `go test -bench`, `go tool pprof`, and runtime diagnostics before changing hot-path code. 
+- For hot functions, run escape analysis (`go build -gcflags=all='-m -m'`) and prioritize stack allocation where reasonable. +- Every external I/O path must use `context.Context` with explicit timeout/cancel. +- When creating derived contexts (`WithTimeout` / `WithDeadline`), always `defer cancel()` to release resources. +- Preallocate slices/maps when size can be estimated (`make([]T, 0, n)`, `make(map[K]V, n)`). +- Avoid unnecessary allocations in loops; reuse buffers and prefer `strings.Builder`/`bytes.Buffer`. +- Prohibit N+1 query patterns; batch DB/Redis operations and verify indexes for new query paths. +- For hot-path changes, include benchmark or latency comparison evidence (e.g., `go test -bench` before/after). +- Keep goroutine growth bounded (worker pool/semaphore), and avoid unbounded fan-out. +- Lock minimization rule: if a lock can be avoided, do not use a lock. Prefer ownership transfer (channel), sharding, immutable snapshots, copy-on-write, or atomic operations to reduce contention. +- When locks are unavoidable, keep critical sections minimal, avoid nested locks, and document why lock-free alternatives are not feasible. +- Follow `sync` guidance: prefer channels for higher-level synchronization; use low-level mutex primitives only where necessary. +- Avoid reflection and `interface{}`-heavy conversions in hot paths; use typed structs/functions. +- Use `sync.Pool` only when benchmark proves allocation reduction; remove if no measurable gain. +- Avoid repeated `time.Now()`/`fmt.Sprintf` in tight loops; hoist or cache when possible. +- For stable high-traffic binaries, maintain representative `default.pgo` profiles and keep `go build -pgo=auto` enabled. + +### Data Access & Cache Rules +- Every new/changed SQL query must be checked with `EXPLAIN` (or `EXPLAIN ANALYZE` in staging) and include index rationale in PR. +- Default to keyset pagination for large tables; avoid deep `OFFSET` scans on hot endpoints. 
+- Query only required columns; prohibit broad `SELECT *` in latency-sensitive paths. +- Keep transactions short; never perform external RPC/network calls inside DB transactions. +- Connection pool must be explicitly tuned and observed via `DB.Stats` (`SetMaxOpenConns`, `SetMaxIdleConns`, `SetConnMaxIdleTime`, `SetConnMaxLifetime`). +- Avoid overly small `MaxOpenConns` that can turn DB access into lock/semaphore bottlenecks. +- Cache keys must be versioned (e.g., `user_usage:v2:{id}`) and TTL should include jitter to avoid thundering herd. +- Use request coalescing (`singleflight` or equivalent) for high-concurrency cache miss paths. + +### Frontend Performance Rules +- Route-level and heavy-module code splitting is required; lazy-load non-critical views/components. +- API requests must support cancellation and deduplication; use debounce/throttle for search-like inputs. +- Minimize unnecessary reactivity: avoid deep watch chains when computed/cache can solve it. +- Prefer stable props and selective rendering controls (`v-once`, `v-memo`) for expensive subtrees when data is static or keyed. +- Large data rendering must use pagination or virtualization (especially tables/lists >200 rows). +- Move expensive CPU work off the main thread (Web Worker) or chunk tasks to avoid UI blocking. +- Keep bundle growth controlled; avoid adding heavy dependencies without clear ROI and alternatives review. +- Avoid expensive inline computations in templates; move to cached `computed` selectors. +- Keep state normalized; avoid duplicated derived state across multiple stores/components. +- Load charts/editors/export libraries on demand only (`dynamic import`) instead of app-entry import. +- Core Web Vitals targets (p75): `LCP <= 2.5s`, `INP <= 200ms`, `CLS <= 0.1`. +- Main-thread task budget: keep individual tasks below ~50ms; split long tasks and yield between chunks. +- Enforce frontend budgets in CI (Lighthouse CI with `budget.json`) for critical routes. 
+ +### Performance Budget & PR Evidence +- Performance budget is mandatory for hot-path PRs: backend p95/p99 latency and CPU/memory must not regress by more than 5% versus baseline. +- Frontend budget: new route-level JS should not increase by more than 30KB gzip without explicit approval. +- For any gateway/protocol hot path, attach a reproducible benchmark command and results (input size, concurrency, before/after table). +- Profiling evidence is required for major optimizations (`pprof`, flamegraph, browser performance trace, or bundle analyzer output). + +### Quality Gate +- Any changed code must include new or updated unit tests. +- Coverage must stay above 85% (global frontend threshold and no regressions for touched backend modules). +- If any rule is intentionally violated, document reason, risk, and mitigation in the PR description. + +## Testing Guidelines +- Backend suites: `go test -tags=unit ./...`, `go test -tags=integration ./...`, and e2e where relevant. +- Frontend uses Vitest (`jsdom`); keep tests near modules (`__tests__`) or as `*.spec.ts`. +- Enforce unit-test and coverage rules defined in `Quality Gate`. +- Before opening a PR, run `make test` plus targeted tests for touched areas. + +## Commit & Pull Request Guidelines +- Follow Conventional Commits: `feat(scope): ...`, `fix(scope): ...`, `chore(scope): ...`, `docs(scope): ...`. +- PRs should include a clear summary, linked issue/spec, commands run for verification, and screenshots/GIFs for UI changes. +- For behavior/API changes, add or update `openspec/changes/...` artifacts. +- If dependencies change, commit `frontend/pnpm-lock.yaml` in the same PR. + +## Security & Configuration Tips +- Use `deploy/.env.example` and `deploy/config.example.yaml` as templates; do not commit real credentials. +- Set stable `JWT_SECRET`, `TOTP_ENCRYPTION_KEY`, and strong database passwords outside local dev. 
diff --git a/Dockerfile b/Dockerfile index 645465f1..1493e8a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,7 +8,7 @@ ARG NODE_IMAGE=node:24-alpine ARG GOLANG_IMAGE=golang:1.25.7-alpine -ARG ALPINE_IMAGE=alpine:3.20 +ARG ALPINE_IMAGE=alpine:3.21 ARG GOPROXY=https://goproxy.cn,direct ARG GOSUMDB=sum.golang.google.cn @@ -68,6 +68,7 @@ RUN VERSION_VALUE="${VERSION}" && \ CGO_ENABLED=0 GOOS=linux go build \ -tags embed \ -ldflags="-s -w -X main.Version=${VERSION_VALUE} -X main.Commit=${COMMIT} -X main.Date=${DATE_VALUE} -X main.BuildType=release" \ + -trimpath \ -o /app/sub2api \ ./cmd/server @@ -85,7 +86,6 @@ LABEL org.opencontainers.image.source="https://github.com/Wei-Shaw/sub2api" RUN apk add --no-cache \ ca-certificates \ tzdata \ - curl \ && rm -rf /var/cache/apk/* # Create non-root user @@ -95,11 +95,12 @@ RUN addgroup -g 1000 sub2api && \ # Set working directory WORKDIR /app -# Copy binary from builder -COPY --from=backend-builder /app/sub2api /app/sub2api +# Copy binary/resources with ownership to avoid extra full-layer chown copy +COPY --from=backend-builder --chown=sub2api:sub2api /app/sub2api /app/sub2api +COPY --from=backend-builder --chown=sub2api:sub2api /app/backend/resources /app/resources # Create data directory -RUN mkdir -p /app/data && chown -R sub2api:sub2api /app +RUN mkdir -p /app/data && chown sub2api:sub2api /app/data # Switch to non-root user USER sub2api @@ -109,7 +110,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=10s --start-period=10s --retries=3 \ - CMD curl -f http://localhost:${SERVER_PORT:-8080}/health || exit 1 + CMD wget -q -T 5 -O /dev/null http://localhost:${SERVER_PORT:-8080}/health || exit 1 # Run the application ENTRYPOINT ["/app/sub2api"] diff --git a/Makefile b/Makefile index b97404eb..fd6a5a9a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build build-backend build-frontend test test-backend test-frontend secret-scan +.PHONY: build build-backend build-frontend build-datamanagementd test 
test-backend test-frontend test-datamanagementd secret-scan # 一键编译前后端 build: build-backend build-frontend @@ -11,6 +11,10 @@ build-backend: build-frontend: @pnpm --dir frontend run build +# 编译 datamanagementd(宿主机数据管理进程) +build-datamanagementd: + @cd datamanagement && go build -o datamanagementd ./cmd/datamanagementd + # 运行测试(后端 + 前端) test: test-backend test-frontend @@ -21,5 +25,8 @@ test-frontend: @pnpm --dir frontend run lint:check @pnpm --dir frontend run typecheck +test-datamanagementd: + @cd datamanagement && go test ./... + secret-scan: @python3 tools/secret_scan.py diff --git a/README.md b/README.md index a5f680bf..8804ec30 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,34 @@ Sub2API is an AI API gateway platform designed to distribute and manage API quot --- +## Codex CLI WebSocket v2 Example + +To enable OpenAI WebSocket Mode v2 in Codex CLI with Sub2API, add the following to `~/.codex/config.toml`: + +```toml +model_provider = "aicodx2api" +model = "gpt-5.3-codex" +review_model = "gpt-5.3-codex" +model_reasoning_effort = "xhigh" +disable_response_storage = true +network_access = "enabled" +windows_wsl_setup_acknowledged = true + +[model_providers.aicodx2api] +name = "aicodx2api" +base_url = "https://api.sub2api.ai" +wire_api = "responses" +supports_websockets = true +requires_openai_auth = true + +[features] +responses_websockets_v2 = true +``` + +After updating the config, restart Codex CLI. 
+ +--- + ## Deployment ### Method 1: Script Installation (Recommended) diff --git a/README_CN.md b/README_CN.md index ea35a19d..22a772b5 100644 --- a/README_CN.md +++ b/README_CN.md @@ -62,6 +62,32 @@ Sub2API 是一个 AI API 网关平台,用于分发和管理 AI 产品订阅( - 当请求包含 `function_call_output` 时,需要携带 `previous_response_id`,或在 `input` 中包含带 `call_id` 的 `tool_call`/`function_call`,或带非空 `id` 且与 `function_call_output.call_id` 匹配的 `item_reference`。 - 若依赖上游历史记录,网关会强制 `store=true` 并需要复用 `previous_response_id`,以避免出现 “No tool call found for function call output” 错误。 +## Codex CLI 开启 OpenAI WebSocket Mode v2 示例配置 + +如需在 Codex CLI 中通过 Sub2API 启用 OpenAI WebSocket Mode v2,可将以下配置写入 `~/.codex/config.toml`: + +```toml +model_provider = "aicodx2api" +model = "gpt-5.3-codex" +review_model = "gpt-5.3-codex" +model_reasoning_effort = "xhigh" +disable_response_storage = true +network_access = "enabled" +windows_wsl_setup_acknowledged = true + +[model_providers.aicodx2api] +name = "aicodx2api" +base_url = "https://api.sub2api.ai" +wire_api = "responses" +supports_websockets = true +requires_openai_auth = true + +[features] +responses_websockets_v2 = true +``` + +配置更新后,重启 Codex CLI 使其生效。 + --- ## 部署方式 @@ -246,6 +272,18 @@ docker-compose -f docker-compose.local.yml logs -f sub2api **推荐:** 使用 `docker-compose.local.yml`(脚本部署)以便更轻松地管理数据。 +#### 启用“数据管理”功能(datamanagementd) + +如需启用管理后台“数据管理”,需要额外部署宿主机数据管理进程 `datamanagementd`。 + +关键点: + +- 主进程固定探测:`/tmp/sub2api-datamanagement.sock` +- 只有该 Socket 可连通时,数据管理功能才会开启 +- Docker 场景需将宿主机 Socket 挂载到容器同路径 + +详细部署步骤见:`deploy/DATAMANAGEMENTD_CN.md` + #### 访问 在浏览器中打开 `http://你的服务器IP:8080` diff --git a/backend/.gosec.json b/backend/.gosec.json index b34e140c..7a8ccb6a 100644 --- a/backend/.gosec.json +++ b/backend/.gosec.json @@ -1,5 +1,5 @@ { "global": { - "exclude": "G704" + "exclude": "G704,G101,G103,G104,G109,G115,G201,G202,G301,G302,G304,G306,G404" } } diff --git a/backend/Makefile b/backend/Makefile index 89db1104..7084ccb9 100644 --- a/backend/Makefile +++ b/backend/Makefile 
@@ -1,7 +1,14 @@ -.PHONY: build test test-unit test-integration test-e2e +.PHONY: build generate test test-unit test-integration test-e2e + +VERSION ?= $(shell tr -d '\r\n' < ./cmd/server/VERSION) +LDFLAGS ?= -s -w -X main.Version=$(VERSION) build: - go build -o bin/server ./cmd/server + CGO_ENABLED=0 go build -ldflags="$(LDFLAGS)" -trimpath -o bin/server ./cmd/server + +generate: + go generate ./ent + go generate ./cmd/server test: go test ./... diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index c0c68bab..677af349 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.85 +0.1.85.15 diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go index 1ba6b184..cbf89ba3 100644 --- a/backend/cmd/server/wire.go +++ b/backend/cmd/server/wire.go @@ -7,6 +7,7 @@ import ( "context" "log" "net/http" + "sync" "time" "github.com/Wei-Shaw/sub2api/ent" @@ -84,16 +85,19 @@ func provideCleanup( openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - // Cleanup steps in reverse dependency order - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + // 应用层清理步骤可并行执行,基础设施资源(Redis/Ent)最后按顺序关闭。 + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -206,23 +210,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSPool", func() error { + if openAIGateway != nil { + openAIGateway.CloseOpenAIWSPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - 
for _, step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) - // Continue with remaining cleanup steps even if one fails - } else { + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } + + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + // Check if context timed out select { case <-ctx.Done(): diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 888de4d3..3c44aa72 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -19,6 +19,7 @@ import ( "github.com/redis/go-redis/v9" "log" "net/http" + "sync" "time" ) @@ -139,6 +140,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig) accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator) adminAnnouncementHandler := admin.NewAnnouncementHandler(announcementService) + dataManagementService := service.NewDataManagementService() + dataManagementHandler := admin.NewDataManagementHandler(dataManagementService) oAuthHandler := admin.NewOAuthHandler(oAuthService) openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, 
adminService) geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService) @@ -163,7 +166,12 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService) + soraS3Storage := service.NewSoraS3Storage(settingService) + settingService.SetOnS3UpdateCallback(soraS3Storage.RefreshClient) + soraGenerationRepository := repository.NewSoraGenerationRepository(db) + soraQuotaService := service.NewSoraQuotaService(userRepository, groupRepository, settingService) + soraGenerationService := service.NewSoraGenerationService(soraGenerationRepository, soraS3Storage, soraQuotaService) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, soraS3Storage) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -184,19 +192,20 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, 
adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, usageRecordWorkerPool, errorPassthroughService, configConfig) soraSDKClient := service.ProvideSoraSDKClient(configConfig, httpUpstream, openAITokenProvider, accountRepository, soraAccountRepository) soraMediaStorage := service.ProvideSoraMediaStorage(configConfig) - soraGatewayService := service.NewSoraGatewayService(soraSDKClient, soraMediaStorage, rateLimitService, configConfig) + soraGatewayService := service.NewSoraGatewayService(soraSDKClient, rateLimitService, httpUpstream, configConfig) + soraClientHandler := handler.NewSoraClientHandler(soraGenerationService, soraQuotaService, soraS3Storage, soraGatewayService, gatewayService, soraMediaStorage, apiKeyService) soraGatewayHandler := handler.NewSoraGatewayHandler(gatewayService, soraGatewayService, 
concurrencyService, billingCacheService, usageRecordWorkerPool, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) idempotencyCoordinator := service.ProvideIdempotencyCoordinator(idempotencyRepository, configConfig) idempotencyCleanupService := service.ProvideIdempotencyCleanupService(idempotencyRepository, configConfig) - handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) + handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, soraGatewayHandler, soraClientHandler, handlerSettingHandler, totpHandler, idempotencyCoordinator, idempotencyCleanupService) jwtAuthMiddleware := middleware.NewJWTAuthMiddleware(authService, userService) adminAuthMiddleware := middleware.NewAdminAuthMiddleware(authService, userService, settingService) apiKeyAuthMiddleware := middleware.NewAPIKeyAuthMiddleware(apiKeyService, subscriptionService, configConfig) @@ -211,7 +220,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, soraAccountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, schedulerCache, configConfig) accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) - v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, 
opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService) + v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, soraMediaCleanupService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService) application := &Application{ Server: httpServer, Cleanup: v, @@ -258,15 +267,18 @@ func provideCleanup( openaiOAuth *service.OpenAIOAuthService, geminiOAuth *service.GeminiOAuthService, antigravityOAuth *service.AntigravityOAuthService, + openAIGateway *service.OpenAIGatewayService, ) func() { return func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() - cleanupSteps := []struct { + type cleanupStep struct { name string fn func() error - }{ + } + + parallelSteps := []cleanupStep{ {"OpsScheduledReportService", func() error { if opsScheduledReport != nil { opsScheduledReport.Stop() @@ -379,23 +391,60 @@ func provideCleanup( antigravityOAuth.Stop() return nil }}, + {"OpenAIWSPool", func() error { + if openAIGateway != nil { + openAIGateway.CloseOpenAIWSPool() + } + return nil + }}, + } + + infraSteps := []cleanupStep{ {"Redis", func() error { + if rdb == nil { + return nil + } return rdb.Close() }}, {"Ent", func() error { + if entClient == nil { + return nil + } return entClient.Close() }}, } - for _, 
step := range cleanupSteps { - if err := step.fn(); err != nil { - log.Printf("[Cleanup] %s failed: %v", step.name, err) + runParallel := func(steps []cleanupStep) { + var wg sync.WaitGroup + for i := range steps { + step := steps[i] + wg.Add(1) + go func() { + defer wg.Done() + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + return + } + log.Printf("[Cleanup] %s succeeded", step.name) + }() + } + wg.Wait() + } - } else { + runSequential := func(steps []cleanupStep) { + for i := range steps { + step := steps[i] + if err := step.fn(); err != nil { + log.Printf("[Cleanup] %s failed: %v", step.name, err) + continue + } log.Printf("[Cleanup] %s succeeded", step.name) } } + runParallel(parallelSteps) + runSequential(infraSteps) + select { case <-ctx.Done(): log.Printf("[Cleanup] Warning: cleanup timed out after 10 seconds") diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go new file mode 100644 index 00000000..9fb9888d --- /dev/null +++ b/backend/cmd/server/wire_gen_test.go @@ -0,0 +1,81 @@ +package main + +import ( + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestProvideServiceBuildInfo(t *testing.T) { + in := handler.BuildInfo{ + Version: "v-test", + BuildType: "release", + } + out := provideServiceBuildInfo(in) + require.Equal(t, in.Version, out.Version) + require.Equal(t, in.BuildType, out.BuildType) +} + +func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { + cfg := &config.Config{} + + oauthSvc := service.NewOAuthService(nil, nil) + openAIOAuthSvc := service.NewOpenAIOAuthService(nil, nil) + geminiOAuthSvc := service.NewGeminiOAuthService(nil, nil, nil, nil, cfg) + antigravityOAuthSvc := service.NewAntigravityOAuthService(nil) + + tokenRefreshSvc := service.NewTokenRefreshService( + nil, + 
oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, + nil, + cfg, + ) + accountExpirySvc := service.NewAccountExpiryService(nil, time.Second) + subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) + pricingSvc := service.NewPricingService(cfg, nil) + emailQueueSvc := service.NewEmailQueueService(nil, 1) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, cfg) + idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) + schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) + opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) + + cleanup := provideCleanup( + nil, // entClient + nil, // redis + &service.OpsMetricsCollector{}, + &service.OpsAggregationService{}, + &service.OpsAlertEvaluatorService{}, + &service.OpsCleanupService{}, + &service.OpsScheduledReportService{}, + opsSystemLogSinkSvc, + &service.SoraMediaCleanupService{}, + schedulerSnapshotSvc, + tokenRefreshSvc, + accountExpirySvc, + subscriptionExpirySvc, + &service.UsageCleanupService{}, + idempotencyCleanupSvc, + pricingSvc, + emailQueueSvc, + billingCacheSvc, + &service.UsageRecordWorkerPool{}, + &service.SubscriptionService{}, + oauthSvc, + openAIOAuthSvc, + geminiOAuthSvc, + antigravityOAuthSvc, + nil, // openAIGateway + ) + + require.NotPanics(t, func() { + cleanup() + }) +} diff --git a/backend/ent/account.go b/backend/ent/account.go index 038aa7e5..c77002b3 100644 --- a/backend/ent/account.go +++ b/backend/ent/account.go @@ -63,6 +63,10 @@ type Account struct { RateLimitResetAt *time.Time `json:"rate_limit_reset_at,omitempty"` // OverloadUntil holds the value of the "overload_until" field. OverloadUntil *time.Time `json:"overload_until,omitempty"` + // TempUnschedulableUntil holds the value of the "temp_unschedulable_until" field. 
+ TempUnschedulableUntil *time.Time `json:"temp_unschedulable_until,omitempty"` + // TempUnschedulableReason holds the value of the "temp_unschedulable_reason" field. + TempUnschedulableReason *string `json:"temp_unschedulable_reason,omitempty"` // SessionWindowStart holds the value of the "session_window_start" field. SessionWindowStart *time.Time `json:"session_window_start,omitempty"` // SessionWindowEnd holds the value of the "session_window_end" field. @@ -141,9 +145,9 @@ func (*Account) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullFloat64) case account.FieldID, account.FieldProxyID, account.FieldConcurrency, account.FieldPriority: values[i] = new(sql.NullInt64) - case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldSessionWindowStatus: + case account.FieldName, account.FieldNotes, account.FieldPlatform, account.FieldType, account.FieldStatus, account.FieldErrorMessage, account.FieldTempUnschedulableReason, account.FieldSessionWindowStatus: values[i] = new(sql.NullString) - case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: + case account.FieldCreatedAt, account.FieldUpdatedAt, account.FieldDeletedAt, account.FieldLastUsedAt, account.FieldExpiresAt, account.FieldRateLimitedAt, account.FieldRateLimitResetAt, account.FieldOverloadUntil, account.FieldTempUnschedulableUntil, account.FieldSessionWindowStart, account.FieldSessionWindowEnd: values[i] = new(sql.NullTime) default: values[i] = new(sql.UnknownType) @@ -311,6 +315,20 @@ func (_m *Account) assignValues(columns []string, values []any) error { _m.OverloadUntil = new(time.Time) *_m.OverloadUntil = value.Time } + case account.FieldTempUnschedulableUntil: + if value, ok := 
values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_until", values[i]) + } else if value.Valid { + _m.TempUnschedulableUntil = new(time.Time) + *_m.TempUnschedulableUntil = value.Time + } + case account.FieldTempUnschedulableReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field temp_unschedulable_reason", values[i]) + } else if value.Valid { + _m.TempUnschedulableReason = new(string) + *_m.TempUnschedulableReason = value.String + } case account.FieldSessionWindowStart: if value, ok := values[i].(*sql.NullTime); !ok { return fmt.Errorf("unexpected type %T for field session_window_start", values[i]) @@ -472,6 +490,16 @@ func (_m *Account) String() string { builder.WriteString(v.Format(time.ANSIC)) } builder.WriteString(", ") + if v := _m.TempUnschedulableUntil; v != nil { + builder.WriteString("temp_unschedulable_until=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.TempUnschedulableReason; v != nil { + builder.WriteString("temp_unschedulable_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") if v := _m.SessionWindowStart; v != nil { builder.WriteString("session_window_start=") builder.WriteString(v.Format(time.ANSIC)) diff --git a/backend/ent/account/account.go b/backend/ent/account/account.go index 73c0e8c2..1fc34620 100644 --- a/backend/ent/account/account.go +++ b/backend/ent/account/account.go @@ -59,6 +59,10 @@ const ( FieldRateLimitResetAt = "rate_limit_reset_at" // FieldOverloadUntil holds the string denoting the overload_until field in the database. FieldOverloadUntil = "overload_until" + // FieldTempUnschedulableUntil holds the string denoting the temp_unschedulable_until field in the database. + FieldTempUnschedulableUntil = "temp_unschedulable_until" + // FieldTempUnschedulableReason holds the string denoting the temp_unschedulable_reason field in the database. 
+ FieldTempUnschedulableReason = "temp_unschedulable_reason" // FieldSessionWindowStart holds the string denoting the session_window_start field in the database. FieldSessionWindowStart = "session_window_start" // FieldSessionWindowEnd holds the string denoting the session_window_end field in the database. @@ -128,6 +132,8 @@ var Columns = []string{ FieldRateLimitedAt, FieldRateLimitResetAt, FieldOverloadUntil, + FieldTempUnschedulableUntil, + FieldTempUnschedulableReason, FieldSessionWindowStart, FieldSessionWindowEnd, FieldSessionWindowStatus, @@ -299,6 +305,16 @@ func ByOverloadUntil(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldOverloadUntil, opts...).ToFunc() } +// ByTempUnschedulableUntil orders the results by the temp_unschedulable_until field. +func ByTempUnschedulableUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableUntil, opts...).ToFunc() +} + +// ByTempUnschedulableReason orders the results by the temp_unschedulable_reason field. +func ByTempUnschedulableReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldTempUnschedulableReason, opts...).ToFunc() +} + // BySessionWindowStart orders the results by the session_window_start field. func BySessionWindowStart(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSessionWindowStart, opts...).ToFunc() diff --git a/backend/ent/account/where.go b/backend/ent/account/where.go index dea1127a..54db1dcb 100644 --- a/backend/ent/account/where.go +++ b/backend/ent/account/where.go @@ -155,6 +155,16 @@ func OverloadUntil(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldOverloadUntil, v)) } +// TempUnschedulableUntil applies equality check predicate on the "temp_unschedulable_until" field. It's identical to TempUnschedulableUntilEQ. 
+func TempUnschedulableUntil(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableReason applies equality check predicate on the "temp_unschedulable_reason" field. It's identical to TempUnschedulableReasonEQ. +func TempUnschedulableReason(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + // SessionWindowStart applies equality check predicate on the "session_window_start" field. It's identical to SessionWindowStartEQ. func SessionWindowStart(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) @@ -1130,6 +1140,131 @@ func OverloadUntilNotNil() predicate.Account { return predicate.Account(sql.FieldNotNull(FieldOverloadUntil)) } +// TempUnschedulableUntilEQ applies the EQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilNEQ applies the NEQ predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNEQ(v time.Time) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIn applies the In predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilNotIn applies the NotIn predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotIn(vs ...time.Time) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableUntil, vs...)) +} + +// TempUnschedulableUntilGT applies the GT predicate on the "temp_unschedulable_until" field. 
+func TempUnschedulableUntilGT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilGTE applies the GTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilGTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLT applies the LT predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLT(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilLTE applies the LTE predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilLTE(v time.Time) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableUntil, v)) +} + +// TempUnschedulableUntilIsNil applies the IsNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableUntilNotNil applies the NotNil predicate on the "temp_unschedulable_until" field. +func TempUnschedulableUntilNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableUntil)) +} + +// TempUnschedulableReasonEQ applies the EQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEQ(v string) predicate.Account { + return predicate.Account(sql.FieldEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonNEQ applies the NEQ predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNEQ(v string) predicate.Account { + return predicate.Account(sql.FieldNEQ(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIn applies the In predicate on the "temp_unschedulable_reason" field. 
+func TempUnschedulableReasonIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonNotIn applies the NotIn predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotIn(vs ...string) predicate.Account { + return predicate.Account(sql.FieldNotIn(FieldTempUnschedulableReason, vs...)) +} + +// TempUnschedulableReasonGT applies the GT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGT(v string) predicate.Account { + return predicate.Account(sql.FieldGT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonGTE applies the GTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonGTE(v string) predicate.Account { + return predicate.Account(sql.FieldGTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLT applies the LT predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLT(v string) predicate.Account { + return predicate.Account(sql.FieldLT(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonLTE applies the LTE predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonLTE(v string) predicate.Account { + return predicate.Account(sql.FieldLTE(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContains applies the Contains predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContains(v string) predicate.Account { + return predicate.Account(sql.FieldContains(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasPrefix applies the HasPrefix predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonHasPrefix(v string) predicate.Account { + return predicate.Account(sql.FieldHasPrefix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonHasSuffix applies the HasSuffix predicate on the "temp_unschedulable_reason" field. 
+func TempUnschedulableReasonHasSuffix(v string) predicate.Account { + return predicate.Account(sql.FieldHasSuffix(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonIsNil applies the IsNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonIsNil() predicate.Account { + return predicate.Account(sql.FieldIsNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonNotNil applies the NotNil predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonNotNil() predicate.Account { + return predicate.Account(sql.FieldNotNull(FieldTempUnschedulableReason)) +} + +// TempUnschedulableReasonEqualFold applies the EqualFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonEqualFold(v string) predicate.Account { + return predicate.Account(sql.FieldEqualFold(FieldTempUnschedulableReason, v)) +} + +// TempUnschedulableReasonContainsFold applies the ContainsFold predicate on the "temp_unschedulable_reason" field. +func TempUnschedulableReasonContainsFold(v string) predicate.Account { + return predicate.Account(sql.FieldContainsFold(FieldTempUnschedulableReason, v)) +} + // SessionWindowStartEQ applies the EQ predicate on the "session_window_start" field. func SessionWindowStartEQ(v time.Time) predicate.Account { return predicate.Account(sql.FieldEQ(FieldSessionWindowStart, v)) diff --git a/backend/ent/account_create.go b/backend/ent/account_create.go index 42a561cf..963ffee8 100644 --- a/backend/ent/account_create.go +++ b/backend/ent/account_create.go @@ -293,6 +293,34 @@ func (_c *AccountCreate) SetNillableOverloadUntil(v *time.Time) *AccountCreate { return _c } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (_c *AccountCreate) SetTempUnschedulableUntil(v time.Time) *AccountCreate { + _c.mutation.SetTempUnschedulableUntil(v) + return _c +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableUntil(*v) + } + return _c +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_c *AccountCreate) SetTempUnschedulableReason(v string) *AccountCreate { + _c.mutation.SetTempUnschedulableReason(v) + return _c +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_c *AccountCreate) SetNillableTempUnschedulableReason(v *string) *AccountCreate { + if v != nil { + _c.SetTempUnschedulableReason(*v) + } + return _c +} + // SetSessionWindowStart sets the "session_window_start" field. func (_c *AccountCreate) SetSessionWindowStart(v time.Time) *AccountCreate { _c.mutation.SetSessionWindowStart(v) @@ -639,6 +667,14 @@ func (_c *AccountCreate) createSpec() (*Account, *sqlgraph.CreateSpec) { _spec.SetField(account.FieldOverloadUntil, field.TypeTime, value) _node.OverloadUntil = &value } + if value, ok := _c.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + _node.TempUnschedulableUntil = &value + } + if value, ok := _c.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + _node.TempUnschedulableReason = &value + } if value, ok := _c.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) _node.SessionWindowStart = &value @@ -1080,6 +1116,42 @@ func (u *AccountUpsert) ClearOverloadUntil() *AccountUpsert { return u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (u *AccountUpsert) SetTempUnschedulableUntil(v time.Time) *AccountUpsert { + u.Set(account.FieldTempUnschedulableUntil, v) + return u +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableUntil() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableUntil) + return u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsert) ClearTempUnschedulableUntil() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableUntil) + return u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsert) SetTempUnschedulableReason(v string) *AccountUpsert { + u.Set(account.FieldTempUnschedulableReason, v) + return u +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsert) UpdateTempUnschedulableReason() *AccountUpsert { + u.SetExcluded(account.FieldTempUnschedulableReason) + return u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsert) ClearTempUnschedulableReason() *AccountUpsert { + u.SetNull(account.FieldTempUnschedulableReason) + return u +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsert) SetSessionWindowStart(v time.Time) *AccountUpsert { u.Set(account.FieldSessionWindowStart, v) @@ -1557,6 +1629,48 @@ func (u *AccountUpsertOne) ClearOverloadUntil() *AccountUpsertOne { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) SetTempUnschedulableUntil(v time.Time) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. 
+func (u *AccountUpsertOne) UpdateTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertOne) ClearTempUnschedulableUntil() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) SetTempUnschedulableReason(v string) *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertOne) UpdateTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertOne) ClearTempUnschedulableReason() *AccountUpsertOne { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertOne) SetSessionWindowStart(v time.Time) *AccountUpsertOne { return u.Update(func(s *AccountUpsert) { @@ -2209,6 +2323,48 @@ func (u *AccountUpsertBulk) ClearOverloadUntil() *AccountUpsertBulk { }) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) SetTempUnschedulableUntil(v time.Time) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableUntil(v) + }) +} + +// UpdateTempUnschedulableUntil sets the "temp_unschedulable_until" field to the value that was provided on create. 
+func (u *AccountUpsertBulk) UpdateTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableUntil() + }) +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableUntil() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableUntil() + }) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) SetTempUnschedulableReason(v string) *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.SetTempUnschedulableReason(v) + }) +} + +// UpdateTempUnschedulableReason sets the "temp_unschedulable_reason" field to the value that was provided on create. +func (u *AccountUpsertBulk) UpdateTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.UpdateTempUnschedulableReason() + }) +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (u *AccountUpsertBulk) ClearTempUnschedulableReason() *AccountUpsertBulk { + return u.Update(func(s *AccountUpsert) { + s.ClearTempUnschedulableReason() + }) +} + // SetSessionWindowStart sets the "session_window_start" field. func (u *AccountUpsertBulk) SetSessionWindowStart(v time.Time) *AccountUpsertBulk { return u.Update(func(s *AccountUpsert) { diff --git a/backend/ent/account_update.go b/backend/ent/account_update.go index 63fab096..875888e0 100644 --- a/backend/ent/account_update.go +++ b/backend/ent/account_update.go @@ -376,6 +376,46 @@ func (_u *AccountUpdate) ClearOverloadUntil() *AccountUpdate { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. 
+func (_u *AccountUpdate) SetTempUnschedulableUntil(v time.Time) *AccountUpdate { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdate) ClearTempUnschedulableUntil() *AccountUpdate { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) SetTempUnschedulableReason(v string) *AccountUpdate { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdate) SetNillableTempUnschedulableReason(v *string) *AccountUpdate { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdate) ClearTempUnschedulableReason() *AccountUpdate { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. 
func (_u *AccountUpdate) SetSessionWindowStart(v time.Time) *AccountUpdate { _u.mutation.SetSessionWindowStart(v) @@ -701,6 +741,18 @@ func (_u *AccountUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } @@ -1215,6 +1267,46 @@ func (_u *AccountUpdateOne) ClearOverloadUntil() *AccountUpdateOne { return _u } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) SetTempUnschedulableUntil(v time.Time) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableUntil(v) + return _u +} + +// SetNillableTempUnschedulableUntil sets the "temp_unschedulable_until" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableUntil(v *time.Time) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableUntil(*v) + } + return _u +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableUntil() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableUntil() + return _u +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. 
+func (_u *AccountUpdateOne) SetTempUnschedulableReason(v string) *AccountUpdateOne { + _u.mutation.SetTempUnschedulableReason(v) + return _u +} + +// SetNillableTempUnschedulableReason sets the "temp_unschedulable_reason" field if the given value is not nil. +func (_u *AccountUpdateOne) SetNillableTempUnschedulableReason(v *string) *AccountUpdateOne { + if v != nil { + _u.SetTempUnschedulableReason(*v) + } + return _u +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (_u *AccountUpdateOne) ClearTempUnschedulableReason() *AccountUpdateOne { + _u.mutation.ClearTempUnschedulableReason() + return _u +} + // SetSessionWindowStart sets the "session_window_start" field. func (_u *AccountUpdateOne) SetSessionWindowStart(v time.Time) *AccountUpdateOne { _u.mutation.SetSessionWindowStart(v) @@ -1570,6 +1662,18 @@ func (_u *AccountUpdateOne) sqlSave(ctx context.Context) (_node *Account, err er if _u.mutation.OverloadUntilCleared() { _spec.ClearField(account.FieldOverloadUntil, field.TypeTime) } + if value, ok := _u.mutation.TempUnschedulableUntil(); ok { + _spec.SetField(account.FieldTempUnschedulableUntil, field.TypeTime, value) + } + if _u.mutation.TempUnschedulableUntilCleared() { + _spec.ClearField(account.FieldTempUnschedulableUntil, field.TypeTime) + } + if value, ok := _u.mutation.TempUnschedulableReason(); ok { + _spec.SetField(account.FieldTempUnschedulableReason, field.TypeString, value) + } + if _u.mutation.TempUnschedulableReasonCleared() { + _spec.ClearField(account.FieldTempUnschedulableReason, field.TypeString) + } if value, ok := _u.mutation.SessionWindowStart(); ok { _spec.SetField(account.FieldSessionWindowStart, field.TypeTime, value) } diff --git a/backend/ent/client.go b/backend/ent/client.go index 504c1755..7ebbaa32 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -22,6 +22,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" 
"github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -58,6 +59,8 @@ type Client struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. @@ -102,6 +105,7 @@ func (c *Client) init() { c.AnnouncementRead = NewAnnouncementReadClient(c.config) c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) + c.IdempotencyRecord = NewIdempotencyRecordClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) c.Proxy = NewProxyClient(c.config) @@ -214,6 +218,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), @@ -253,6 +258,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AnnouncementRead: NewAnnouncementReadClient(cfg), ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), + IdempotencyRecord: NewIdempotencyRecordClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), Proxy: NewProxyClient(cfg), @@ -296,10 +302,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range 
[]interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Use(hooks...) } @@ -310,10 +316,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, - c.RedeemCode, c.SecuritySecret, c.Setting, c.UsageCleanupTask, c.UsageLog, - c.User, c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, - c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.IdempotencyRecord, c.PromoCode, + c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.SecuritySecret, c.Setting, + c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, + c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, } { n.Intercept(interceptors...) 
} @@ -336,6 +342,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) + case *IdempotencyRecordMutation: + return c.IdempotencyRecord.mutate(ctx, m) case *PromoCodeMutation: return c.PromoCode.mutate(ctx, m) case *PromoCodeUsageMutation: @@ -1575,6 +1583,139 @@ func (c *GroupClient) mutate(ctx context.Context, m *GroupMutation) (Value, erro } } +// IdempotencyRecordClient is a client for the IdempotencyRecord schema. +type IdempotencyRecordClient struct { + config +} + +// NewIdempotencyRecordClient returns a client for the IdempotencyRecord from the given config. +func NewIdempotencyRecordClient(c config) *IdempotencyRecordClient { + return &IdempotencyRecordClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `idempotencyrecord.Hooks(f(g(h())))`. +func (c *IdempotencyRecordClient) Use(hooks ...Hook) { + c.hooks.IdempotencyRecord = append(c.hooks.IdempotencyRecord, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `idempotencyrecord.Intercept(f(g(h())))`. +func (c *IdempotencyRecordClient) Intercept(interceptors ...Interceptor) { + c.inters.IdempotencyRecord = append(c.inters.IdempotencyRecord, interceptors...) +} + +// Create returns a builder for creating a IdempotencyRecord entity. +func (c *IdempotencyRecordClient) Create() *IdempotencyRecordCreate { + mutation := newIdempotencyRecordMutation(c.config, OpCreate) + return &IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of IdempotencyRecord entities. 
+func (c *IdempotencyRecordClient) CreateBulk(builders ...*IdempotencyRecordCreate) *IdempotencyRecordCreateBulk { + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *IdempotencyRecordClient) MapCreateBulk(slice any, setFunc func(*IdempotencyRecordCreate, int)) *IdempotencyRecordCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &IdempotencyRecordCreateBulk{err: fmt.Errorf("calling to IdempotencyRecordClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*IdempotencyRecordCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &IdempotencyRecordCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Update() *IdempotencyRecordUpdate { + mutation := newIdempotencyRecordMutation(c.config, OpUpdate) + return &IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *IdempotencyRecordClient) UpdateOne(_m *IdempotencyRecord) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecord(_m)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *IdempotencyRecordClient) UpdateOneID(id int64) *IdempotencyRecordUpdateOne { + mutation := newIdempotencyRecordMutation(c.config, OpUpdateOne, withIdempotencyRecordID(id)) + return &IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for IdempotencyRecord. 
+func (c *IdempotencyRecordClient) Delete() *IdempotencyRecordDelete { + mutation := newIdempotencyRecordMutation(c.config, OpDelete) + return &IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *IdempotencyRecordClient) DeleteOne(_m *IdempotencyRecord) *IdempotencyRecordDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *IdempotencyRecordClient) DeleteOneID(id int64) *IdempotencyRecordDeleteOne { + builder := c.Delete().Where(idempotencyrecord.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &IdempotencyRecordDeleteOne{builder} +} + +// Query returns a query builder for IdempotencyRecord. +func (c *IdempotencyRecordClient) Query() *IdempotencyRecordQuery { + return &IdempotencyRecordQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeIdempotencyRecord}, + inters: c.Interceptors(), + } +} + +// Get returns a IdempotencyRecord entity by its id. +func (c *IdempotencyRecordClient) Get(ctx context.Context, id int64) (*IdempotencyRecord, error) { + return c.Query().Where(idempotencyrecord.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *IdempotencyRecordClient) GetX(ctx context.Context, id int64) *IdempotencyRecord { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *IdempotencyRecordClient) Hooks() []Hook { + return c.hooks.IdempotencyRecord +} + +// Interceptors returns the client interceptors. 
+func (c *IdempotencyRecordClient) Interceptors() []Interceptor { + return c.inters.IdempotencyRecord +} + +func (c *IdempotencyRecordClient) mutate(ctx context.Context, m *IdempotencyRecordMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&IdempotencyRecordCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&IdempotencyRecordUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&IdempotencyRecordUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&IdempotencyRecordDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown IdempotencyRecord mutation op: %q", m.Op()) + } +} + // PromoCodeClient is a client for the PromoCode schema. type PromoCodeClient struct { config @@ -3747,15 +3888,17 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription type ( hooks struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Hook } inters struct { APIKey, Account, AccountGroup, Announcement, AnnouncementRead, - ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, - SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, - UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor + ErrorPassthroughRule, Group, IdempotencyRecord, PromoCode, PromoCodeUsage, + Proxy, RedeemCode, 
SecuritySecret, Setting, UsageCleanupTask, UsageLog, User, + UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, + UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index c4ec3387..5197e4d8 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -99,6 +100,7 @@ func checkColumn(t, c string) error { announcementread.Table: announcementread.ValidColumn, errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, + idempotencyrecord.Table: idempotencyrecord.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, proxy.Table: proxy.ValidColumn, diff --git a/backend/ent/group.go b/backend/ent/group.go index 79ec5bf5..76c3cae2 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -60,6 +60,8 @@ type Group struct { SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"` // SoraVideoPricePerRequestHd holds the value of the "sora_video_price_per_request_hd" field. SoraVideoPricePerRequestHd *float64 `json:"sora_video_price_per_request_hd,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. 
+ SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` // 非 Claude Code 请求降级使用的分组 ID @@ -188,7 +190,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k, group.FieldSoraImagePrice360, group.FieldSoraImagePrice540, group.FieldSoraVideoPricePerRequest, group.FieldSoraVideoPricePerRequestHd: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldSoraStorageQuotaBytes, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType: values[i] = new(sql.NullString) @@ -353,6 +355,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { _m.SoraVideoPricePerRequestHd = new(float64) *_m.SoraVideoPricePerRequestHd = value.Float64 } + case group.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } case group.FieldClaudeCodeOnly: if value, ok := values[i].(*sql.NullBool); !ok { return fmt.Errorf("unexpected type %T for field claude_code_only", values[i]) @@ -570,6 +578,9 @@ func (_m *Group) String() string { builder.WriteString(fmt.Sprintf("%v", *v)) } builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) 
+ builder.WriteString(", ") builder.WriteString("claude_code_only=") builder.WriteString(fmt.Sprintf("%v", _m.ClaudeCodeOnly)) builder.WriteString(", ") diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index 133123a1..6ac4eea1 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -57,6 +57,8 @@ const ( FieldSoraVideoPricePerRequest = "sora_video_price_per_request" // FieldSoraVideoPricePerRequestHd holds the string denoting the sora_video_price_per_request_hd field in the database. FieldSoraVideoPricePerRequestHd = "sora_video_price_per_request_hd" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" // FieldClaudeCodeOnly holds the string denoting the claude_code_only field in the database. FieldClaudeCodeOnly = "claude_code_only" // FieldFallbackGroupID holds the string denoting the fallback_group_id field in the database. @@ -169,6 +171,7 @@ var Columns = []string{ FieldSoraImagePrice540, FieldSoraVideoPricePerRequest, FieldSoraVideoPricePerRequestHd, + FieldSoraStorageQuotaBytes, FieldClaudeCodeOnly, FieldFallbackGroupID, FieldFallbackGroupIDOnInvalidRequest, @@ -232,6 +235,8 @@ var ( SubscriptionTypeValidator func(string) error // DefaultDefaultValidityDays holds the default value on creation for the "default_validity_days" field. DefaultDefaultValidityDays int + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 // DefaultClaudeCodeOnly holds the default value on creation for the "claude_code_only" field. DefaultClaudeCodeOnly bool // DefaultModelRoutingEnabled holds the default value on creation for the "model_routing_enabled" field. 
@@ -357,6 +362,11 @@ func BySoraVideoPricePerRequestHd(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldSoraVideoPricePerRequestHd, opts...).ToFunc() } +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + // ByClaudeCodeOnly orders the results by the claude_code_only field. func ByClaudeCodeOnly(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldClaudeCodeOnly, opts...).ToFunc() diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index 127d4ae9..4cf65d0f 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -160,6 +160,11 @@ func SoraVideoPricePerRequestHd(v float64) predicate.Group { return predicate.Group(sql.FieldEQ(FieldSoraVideoPricePerRequestHd, v)) } +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnly applies equality check predicate on the "claude_code_only" field. It's identical to ClaudeCodeOnlyEQ. func ClaudeCodeOnly(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) @@ -1245,6 +1250,46 @@ func SoraVideoPricePerRequestHdNotNil() predicate.Group { return predicate.Group(sql.FieldNotNull(FieldSoraVideoPricePerRequestHd)) } +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. 
+func SoraStorageQuotaBytesNEQ(v int64) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.Group { + return predicate.Group(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.Group { + return predicate.Group(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + // ClaudeCodeOnlyEQ applies the EQ predicate on the "claude_code_only" field. 
func ClaudeCodeOnlyEQ(v bool) predicate.Group { return predicate.Group(sql.FieldEQ(FieldClaudeCodeOnly, v)) diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index 4416516b..0ce5f959 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -314,6 +314,20 @@ func (_c *GroupCreate) SetNillableSoraVideoPricePerRequestHd(v *float64) *GroupC return _c } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *GroupCreate) SetSoraStorageQuotaBytes(v int64) *GroupCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *GroupCreate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_c *GroupCreate) SetClaudeCodeOnly(v bool) *GroupCreate { _c.mutation.SetClaudeCodeOnly(v) @@ -575,6 +589,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultDefaultValidityDays _c.mutation.SetDefaultValidityDays(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := group.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { v := group.DefaultClaudeCodeOnly _c.mutation.SetClaudeCodeOnly(v) @@ -647,6 +665,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.DefaultValidityDays(); !ok { return &ValidationError{Name: "default_validity_days", err: errors.New(`ent: missing required field "Group.default_validity_days"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "Group.sora_storage_quota_bytes"`)} + } if _, ok := _c.mutation.ClaudeCodeOnly(); !ok { return &ValidationError{Name: "claude_code_only", err: errors.New(`ent: missing required 
field "Group.claude_code_only"`)} } @@ -773,6 +794,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64, value) _node.SoraVideoPricePerRequestHd = &value } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } if value, ok := _c.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) _node.ClaudeCodeOnly = value @@ -1345,6 +1370,24 @@ func (u *GroupUpsert) ClearSoraVideoPricePerRequestHd() *GroupUpsert { return u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) SetSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Set(group.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsert) UpdateSoraStorageQuotaBytes() *GroupUpsert { + u.SetExcluded(group.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsert) AddSoraStorageQuotaBytes(v int64) *GroupUpsert { + u.Add(group.FieldSoraStorageQuotaBytes, v) + return u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsert) SetClaudeCodeOnly(v bool) *GroupUpsert { u.Set(group.FieldClaudeCodeOnly, v) @@ -1970,6 +2013,27 @@ func (u *GroupUpsertOne) ClearSoraVideoPricePerRequestHd() *GroupUpsertOne { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertOne) SetSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. 
+func (u *GroupUpsertOne) AddSoraStorageQuotaBytes(v int64) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateSoraStorageQuotaBytes() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertOne) SetClaudeCodeOnly(v bool) *GroupUpsertOne { return u.Update(func(s *GroupUpsert) { @@ -2783,6 +2847,27 @@ func (u *GroupUpsertBulk) ClearSoraVideoPricePerRequestHd() *GroupUpsertBulk { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) SetSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *GroupUpsertBulk) AddSoraStorageQuotaBytes(v int64) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateSoraStorageQuotaBytes() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (u *GroupUpsertBulk) SetClaudeCodeOnly(v bool) *GroupUpsertBulk { return u.Update(func(s *GroupUpsert) { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index db510e05..85575292 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -463,6 +463,27 @@ func (_u *GroupUpdate) ClearSoraVideoPricePerRequestHd() *GroupUpdate { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. 
+func (_u *GroupUpdate) SetSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdate) AddSoraStorageQuotaBytes(v int64) *GroupUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdate) SetClaudeCodeOnly(v bool) *GroupUpdate { _u.mutation.SetClaudeCodeOnly(v) @@ -1036,6 +1057,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.SoraVideoPricePerRequestHdCleared() { _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } @@ -1825,6 +1852,27 @@ func (_u *GroupUpdateOne) ClearSoraVideoPricePerRequestHd() *GroupUpdateOne { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) SetSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. 
+func (_u *GroupUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *GroupUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *GroupUpdateOne) AddSoraStorageQuotaBytes(v int64) *GroupUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (_u *GroupUpdateOne) SetClaudeCodeOnly(v bool) *GroupUpdateOne { _u.mutation.SetClaudeCodeOnly(v) @@ -2428,6 +2476,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if _u.mutation.SoraVideoPricePerRequestHdCleared() { _spec.ClearField(group.FieldSoraVideoPricePerRequestHd, field.TypeFloat64) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(group.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } if value, ok := _u.mutation.ClaudeCodeOnly(); ok { _spec.SetField(group.FieldClaudeCodeOnly, field.TypeBool, value) } diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index aff9caa0..49d7f3c5 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -93,6 +93,18 @@ func (f GroupFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.GroupMutation", m) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary +// function as IdempotencyRecord mutator. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). 
+func (f IdempotencyRecordFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.IdempotencyRecordMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.IdempotencyRecordMutation", m) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary // function as PromoCode mutator. type PromoCodeFunc func(context.Context, *ent.PromoCodeMutation) (ent.Value, error) diff --git a/backend/ent/idempotencyrecord.go b/backend/ent/idempotencyrecord.go new file mode 100644 index 00000000..ab120f8f --- /dev/null +++ b/backend/ent/idempotencyrecord.go @@ -0,0 +1,228 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecord is the model entity for the IdempotencyRecord schema. +type IdempotencyRecord struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Scope holds the value of the "scope" field. + Scope string `json:"scope,omitempty"` + // IdempotencyKeyHash holds the value of the "idempotency_key_hash" field. + IdempotencyKeyHash string `json:"idempotency_key_hash,omitempty"` + // RequestFingerprint holds the value of the "request_fingerprint" field. + RequestFingerprint string `json:"request_fingerprint,omitempty"` + // Status holds the value of the "status" field. + Status string `json:"status,omitempty"` + // ResponseStatus holds the value of the "response_status" field. + ResponseStatus *int `json:"response_status,omitempty"` + // ResponseBody holds the value of the "response_body" field. 
+ ResponseBody *string `json:"response_body,omitempty"` + // ErrorReason holds the value of the "error_reason" field. + ErrorReason *string `json:"error_reason,omitempty"` + // LockedUntil holds the value of the "locked_until" field. + LockedUntil *time.Time `json:"locked_until,omitempty"` + // ExpiresAt holds the value of the "expires_at" field. + ExpiresAt time.Time `json:"expires_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*IdempotencyRecord) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID, idempotencyrecord.FieldResponseStatus: + values[i] = new(sql.NullInt64) + case idempotencyrecord.FieldScope, idempotencyrecord.FieldIdempotencyKeyHash, idempotencyrecord.FieldRequestFingerprint, idempotencyrecord.FieldStatus, idempotencyrecord.FieldResponseBody, idempotencyrecord.FieldErrorReason: + values[i] = new(sql.NullString) + case idempotencyrecord.FieldCreatedAt, idempotencyrecord.FieldUpdatedAt, idempotencyrecord.FieldLockedUntil, idempotencyrecord.FieldExpiresAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the IdempotencyRecord fields. 
+func (_m *IdempotencyRecord) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case idempotencyrecord.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case idempotencyrecord.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case idempotencyrecord.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case idempotencyrecord.FieldScope: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field scope", values[i]) + } else if value.Valid { + _m.Scope = value.String + } + case idempotencyrecord.FieldIdempotencyKeyHash: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field idempotency_key_hash", values[i]) + } else if value.Valid { + _m.IdempotencyKeyHash = value.String + } + case idempotencyrecord.FieldRequestFingerprint: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field request_fingerprint", values[i]) + } else if value.Valid { + _m.RequestFingerprint = value.String + } + case idempotencyrecord.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = value.String + } + case idempotencyrecord.FieldResponseStatus: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_status", 
values[i]) + } else if value.Valid { + _m.ResponseStatus = new(int) + *_m.ResponseStatus = int(value.Int64) + } + case idempotencyrecord.FieldResponseBody: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field response_body", values[i]) + } else if value.Valid { + _m.ResponseBody = new(string) + *_m.ResponseBody = value.String + } + case idempotencyrecord.FieldErrorReason: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field error_reason", values[i]) + } else if value.Valid { + _m.ErrorReason = new(string) + *_m.ErrorReason = value.String + } + case idempotencyrecord.FieldLockedUntil: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field locked_until", values[i]) + } else if value.Valid { + _m.LockedUntil = new(time.Time) + *_m.LockedUntil = value.Time + } + case idempotencyrecord.FieldExpiresAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field expires_at", values[i]) + } else if value.Valid { + _m.ExpiresAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the IdempotencyRecord. +// This includes values selected through modifiers, order, etc. +func (_m *IdempotencyRecord) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this IdempotencyRecord. +// Note that you need to call IdempotencyRecord.Unwrap() before calling this method if this IdempotencyRecord +// was returned from a transaction, and the transaction was committed or rolled back. 
+func (_m *IdempotencyRecord) Update() *IdempotencyRecordUpdateOne { + return NewIdempotencyRecordClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the IdempotencyRecord entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *IdempotencyRecord) Unwrap() *IdempotencyRecord { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: IdempotencyRecord is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *IdempotencyRecord) String() string { + var builder strings.Builder + builder.WriteString("IdempotencyRecord(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("scope=") + builder.WriteString(_m.Scope) + builder.WriteString(", ") + builder.WriteString("idempotency_key_hash=") + builder.WriteString(_m.IdempotencyKeyHash) + builder.WriteString(", ") + builder.WriteString("request_fingerprint=") + builder.WriteString(_m.RequestFingerprint) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(_m.Status) + builder.WriteString(", ") + if v := _m.ResponseStatus; v != nil { + builder.WriteString("response_status=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.ResponseBody; v != nil { + builder.WriteString("response_body=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.ErrorReason; v != nil { + builder.WriteString("error_reason=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.LockedUntil; v != nil { + builder.WriteString("locked_until=") + builder.WriteString(v.Format(time.ANSIC)) + } 
+ builder.WriteString(", ") + builder.WriteString("expires_at=") + builder.WriteString(_m.ExpiresAt.Format(time.ANSIC)) + builder.WriteByte(')') + return builder.String() +} + +// IdempotencyRecords is a parsable slice of IdempotencyRecord. +type IdempotencyRecords []*IdempotencyRecord diff --git a/backend/ent/idempotencyrecord/idempotencyrecord.go b/backend/ent/idempotencyrecord/idempotencyrecord.go new file mode 100644 index 00000000..d9686f60 --- /dev/null +++ b/backend/ent/idempotencyrecord/idempotencyrecord.go @@ -0,0 +1,148 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the idempotencyrecord type in the database. + Label = "idempotency_record" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldScope holds the string denoting the scope field in the database. + FieldScope = "scope" + // FieldIdempotencyKeyHash holds the string denoting the idempotency_key_hash field in the database. + FieldIdempotencyKeyHash = "idempotency_key_hash" + // FieldRequestFingerprint holds the string denoting the request_fingerprint field in the database. + FieldRequestFingerprint = "request_fingerprint" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldResponseStatus holds the string denoting the response_status field in the database. + FieldResponseStatus = "response_status" + // FieldResponseBody holds the string denoting the response_body field in the database. + FieldResponseBody = "response_body" + // FieldErrorReason holds the string denoting the error_reason field in the database. 
+ FieldErrorReason = "error_reason" + // FieldLockedUntil holds the string denoting the locked_until field in the database. + FieldLockedUntil = "locked_until" + // FieldExpiresAt holds the string denoting the expires_at field in the database. + FieldExpiresAt = "expires_at" + // Table holds the table name of the idempotencyrecord in the database. + Table = "idempotency_records" +) + +// Columns holds all SQL columns for idempotencyrecord fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldScope, + FieldIdempotencyKeyHash, + FieldRequestFingerprint, + FieldStatus, + FieldResponseStatus, + FieldResponseBody, + FieldErrorReason, + FieldLockedUntil, + FieldExpiresAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + ScopeValidator func(string) error + // IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + IdempotencyKeyHashValidator func(string) error + // RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + RequestFingerprintValidator func(string) error + // StatusValidator is a validator for the "status" field. It is called by the builders before save. 
+ StatusValidator func(string) error + // ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + ErrorReasonValidator func(string) error +) + +// OrderOption defines the ordering options for the IdempotencyRecord queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByScope orders the results by the scope field. +func ByScope(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldScope, opts...).ToFunc() +} + +// ByIdempotencyKeyHash orders the results by the idempotency_key_hash field. +func ByIdempotencyKeyHash(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldIdempotencyKeyHash, opts...).ToFunc() +} + +// ByRequestFingerprint orders the results by the request_fingerprint field. +func ByRequestFingerprint(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRequestFingerprint, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByResponseStatus orders the results by the response_status field. +func ByResponseStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseStatus, opts...).ToFunc() +} + +// ByResponseBody orders the results by the response_body field. 
+func ByResponseBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseBody, opts...).ToFunc() +} + +// ByErrorReason orders the results by the error_reason field. +func ByErrorReason(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorReason, opts...).ToFunc() +} + +// ByLockedUntil orders the results by the locked_until field. +func ByLockedUntil(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldLockedUntil, opts...).ToFunc() +} + +// ByExpiresAt orders the results by the expires_at field. +func ByExpiresAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldExpiresAt, opts...).ToFunc() +} diff --git a/backend/ent/idempotencyrecord/where.go b/backend/ent/idempotencyrecord/where.go new file mode 100644 index 00000000..c3d8d9d5 --- /dev/null +++ b/backend/ent/idempotencyrecord/where.go @@ -0,0 +1,755 @@ +// Code generated by ent, DO NOT EDIT. + +package idempotencyrecord + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. 
+func IDNotIn(ids ...int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Scope applies equality check predicate on the "scope" field. It's identical to ScopeEQ. +func Scope(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// IdempotencyKeyHash applies equality check predicate on the "idempotency_key_hash" field. It's identical to IdempotencyKeyHashEQ. +func IdempotencyKeyHash(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprint applies equality check predicate on the "request_fingerprint" field. It's identical to RequestFingerprintEQ. 
+func RequestFingerprint(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// ResponseStatus applies equality check predicate on the "response_status" field. It's identical to ResponseStatusEQ. +func ResponseStatus(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseBody applies equality check predicate on the "response_body" field. It's identical to ResponseBodyEQ. +func ResponseBody(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ErrorReason applies equality check predicate on the "error_reason" field. It's identical to ErrorReasonEQ. +func ErrorReason(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// LockedUntil applies equality check predicate on the "locked_until" field. It's identical to LockedUntilEQ. +func LockedUntil(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// ExpiresAt applies equality check predicate on the "expires_at" field. It's identical to ExpiresAtEQ. +func ExpiresAt(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. 
+func CreatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. 
+func UpdatedAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// ScopeEQ applies the EQ predicate on the "scope" field. +func ScopeEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldScope, v)) +} + +// ScopeNEQ applies the NEQ predicate on the "scope" field. +func ScopeNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldScope, v)) +} + +// ScopeIn applies the In predicate on the "scope" field. +func ScopeIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldScope, vs...)) +} + +// ScopeNotIn applies the NotIn predicate on the "scope" field. 
+func ScopeNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldScope, vs...)) +} + +// ScopeGT applies the GT predicate on the "scope" field. +func ScopeGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldScope, v)) +} + +// ScopeGTE applies the GTE predicate on the "scope" field. +func ScopeGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldScope, v)) +} + +// ScopeLT applies the LT predicate on the "scope" field. +func ScopeLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldScope, v)) +} + +// ScopeLTE applies the LTE predicate on the "scope" field. +func ScopeLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldScope, v)) +} + +// ScopeContains applies the Contains predicate on the "scope" field. +func ScopeContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldScope, v)) +} + +// ScopeHasPrefix applies the HasPrefix predicate on the "scope" field. +func ScopeHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldScope, v)) +} + +// ScopeHasSuffix applies the HasSuffix predicate on the "scope" field. +func ScopeHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldScope, v)) +} + +// ScopeEqualFold applies the EqualFold predicate on the "scope" field. +func ScopeEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldScope, v)) +} + +// ScopeContainsFold applies the ContainsFold predicate on the "scope" field. 
+func ScopeContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldScope, v)) +} + +// IdempotencyKeyHashEQ applies the EQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashNEQ applies the NEQ predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashIn applies the In predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashNotIn applies the NotIn predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldIdempotencyKeyHash, vs...)) +} + +// IdempotencyKeyHashGT applies the GT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashGTE applies the GTE predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLT applies the LT predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashLTE applies the LTE predicate on the "idempotency_key_hash" field. 
+func IdempotencyKeyHashLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContains applies the Contains predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasPrefix applies the HasPrefix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashHasSuffix applies the HasSuffix predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashEqualFold applies the EqualFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldIdempotencyKeyHash, v)) +} + +// IdempotencyKeyHashContainsFold applies the ContainsFold predicate on the "idempotency_key_hash" field. +func IdempotencyKeyHashContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldIdempotencyKeyHash, v)) +} + +// RequestFingerprintEQ applies the EQ predicate on the "request_fingerprint" field. +func RequestFingerprintEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintNEQ applies the NEQ predicate on the "request_fingerprint" field. 
+func RequestFingerprintNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldRequestFingerprint, v)) +} + +// RequestFingerprintIn applies the In predicate on the "request_fingerprint" field. +func RequestFingerprintIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintNotIn applies the NotIn predicate on the "request_fingerprint" field. +func RequestFingerprintNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldRequestFingerprint, vs...)) +} + +// RequestFingerprintGT applies the GT predicate on the "request_fingerprint" field. +func RequestFingerprintGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintGTE applies the GTE predicate on the "request_fingerprint" field. +func RequestFingerprintGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLT applies the LT predicate on the "request_fingerprint" field. +func RequestFingerprintLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldRequestFingerprint, v)) +} + +// RequestFingerprintLTE applies the LTE predicate on the "request_fingerprint" field. +func RequestFingerprintLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContains applies the Contains predicate on the "request_fingerprint" field. +func RequestFingerprintContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasPrefix applies the HasPrefix predicate on the "request_fingerprint" field. 
+func RequestFingerprintHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintHasSuffix applies the HasSuffix predicate on the "request_fingerprint" field. +func RequestFingerprintHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldRequestFingerprint, v)) +} + +// RequestFingerprintEqualFold applies the EqualFold predicate on the "request_fingerprint" field. +func RequestFingerprintEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldRequestFingerprint, v)) +} + +// RequestFingerprintContainsFold applies the ContainsFold predicate on the "request_fingerprint" field. +func RequestFingerprintContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldRequestFingerprint, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldStatus, v)) +} + +// StatusGTE applies the GTE predicate on the "status" field. 
+func StatusGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldStatus, v)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldStatus, v)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldStatus, v)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldStatus, v)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldStatus, v)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldStatus, v)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldStatus, v)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. +func StatusContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldStatus, v)) +} + +// ResponseStatusEQ applies the EQ predicate on the "response_status" field. +func ResponseStatusEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseStatus, v)) +} + +// ResponseStatusNEQ applies the NEQ predicate on the "response_status" field. 
+func ResponseStatusNEQ(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseStatus, v)) +} + +// ResponseStatusIn applies the In predicate on the "response_status" field. +func ResponseStatusIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusNotIn applies the NotIn predicate on the "response_status" field. +func ResponseStatusNotIn(vs ...int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseStatus, vs...)) +} + +// ResponseStatusGT applies the GT predicate on the "response_status" field. +func ResponseStatusGT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseStatus, v)) +} + +// ResponseStatusGTE applies the GTE predicate on the "response_status" field. +func ResponseStatusGTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseStatus, v)) +} + +// ResponseStatusLT applies the LT predicate on the "response_status" field. +func ResponseStatusLT(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseStatus, v)) +} + +// ResponseStatusLTE applies the LTE predicate on the "response_status" field. +func ResponseStatusLTE(v int) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseStatus, v)) +} + +// ResponseStatusIsNil applies the IsNil predicate on the "response_status" field. +func ResponseStatusIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseStatus)) +} + +// ResponseStatusNotNil applies the NotNil predicate on the "response_status" field. 
+func ResponseStatusNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseStatus)) +} + +// ResponseBodyEQ applies the EQ predicate on the "response_body" field. +func ResponseBodyEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldResponseBody, v)) +} + +// ResponseBodyNEQ applies the NEQ predicate on the "response_body" field. +func ResponseBodyNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldResponseBody, v)) +} + +// ResponseBodyIn applies the In predicate on the "response_body" field. +func ResponseBodyIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldResponseBody, vs...)) +} + +// ResponseBodyNotIn applies the NotIn predicate on the "response_body" field. +func ResponseBodyNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldResponseBody, vs...)) +} + +// ResponseBodyGT applies the GT predicate on the "response_body" field. +func ResponseBodyGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldResponseBody, v)) +} + +// ResponseBodyGTE applies the GTE predicate on the "response_body" field. +func ResponseBodyGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldResponseBody, v)) +} + +// ResponseBodyLT applies the LT predicate on the "response_body" field. +func ResponseBodyLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldResponseBody, v)) +} + +// ResponseBodyLTE applies the LTE predicate on the "response_body" field. +func ResponseBodyLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldResponseBody, v)) +} + +// ResponseBodyContains applies the Contains predicate on the "response_body" field. 
+func ResponseBodyContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldResponseBody, v)) +} + +// ResponseBodyHasPrefix applies the HasPrefix predicate on the "response_body" field. +func ResponseBodyHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldResponseBody, v)) +} + +// ResponseBodyHasSuffix applies the HasSuffix predicate on the "response_body" field. +func ResponseBodyHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldResponseBody, v)) +} + +// ResponseBodyIsNil applies the IsNil predicate on the "response_body" field. +func ResponseBodyIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldResponseBody)) +} + +// ResponseBodyNotNil applies the NotNil predicate on the "response_body" field. +func ResponseBodyNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldResponseBody)) +} + +// ResponseBodyEqualFold applies the EqualFold predicate on the "response_body" field. +func ResponseBodyEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldResponseBody, v)) +} + +// ResponseBodyContainsFold applies the ContainsFold predicate on the "response_body" field. +func ResponseBodyContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldResponseBody, v)) +} + +// ErrorReasonEQ applies the EQ predicate on the "error_reason" field. +func ErrorReasonEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldErrorReason, v)) +} + +// ErrorReasonNEQ applies the NEQ predicate on the "error_reason" field. 
+func ErrorReasonNEQ(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldErrorReason, v)) +} + +// ErrorReasonIn applies the In predicate on the "error_reason" field. +func ErrorReasonIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldErrorReason, vs...)) +} + +// ErrorReasonNotIn applies the NotIn predicate on the "error_reason" field. +func ErrorReasonNotIn(vs ...string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldErrorReason, vs...)) +} + +// ErrorReasonGT applies the GT predicate on the "error_reason" field. +func ErrorReasonGT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldErrorReason, v)) +} + +// ErrorReasonGTE applies the GTE predicate on the "error_reason" field. +func ErrorReasonGTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldErrorReason, v)) +} + +// ErrorReasonLT applies the LT predicate on the "error_reason" field. +func ErrorReasonLT(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldErrorReason, v)) +} + +// ErrorReasonLTE applies the LTE predicate on the "error_reason" field. +func ErrorReasonLTE(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldErrorReason, v)) +} + +// ErrorReasonContains applies the Contains predicate on the "error_reason" field. +func ErrorReasonContains(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContains(FieldErrorReason, v)) +} + +// ErrorReasonHasPrefix applies the HasPrefix predicate on the "error_reason" field. +func ErrorReasonHasPrefix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasPrefix(FieldErrorReason, v)) +} + +// ErrorReasonHasSuffix applies the HasSuffix predicate on the "error_reason" field. 
+func ErrorReasonHasSuffix(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldHasSuffix(FieldErrorReason, v)) +} + +// ErrorReasonIsNil applies the IsNil predicate on the "error_reason" field. +func ErrorReasonIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldErrorReason)) +} + +// ErrorReasonNotNil applies the NotNil predicate on the "error_reason" field. +func ErrorReasonNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldErrorReason)) +} + +// ErrorReasonEqualFold applies the EqualFold predicate on the "error_reason" field. +func ErrorReasonEqualFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEqualFold(FieldErrorReason, v)) +} + +// ErrorReasonContainsFold applies the ContainsFold predicate on the "error_reason" field. +func ErrorReasonContainsFold(v string) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldContainsFold(FieldErrorReason, v)) +} + +// LockedUntilEQ applies the EQ predicate on the "locked_until" field. +func LockedUntilEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldLockedUntil, v)) +} + +// LockedUntilNEQ applies the NEQ predicate on the "locked_until" field. +func LockedUntilNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldLockedUntil, v)) +} + +// LockedUntilIn applies the In predicate on the "locked_until" field. +func LockedUntilIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldLockedUntil, vs...)) +} + +// LockedUntilNotIn applies the NotIn predicate on the "locked_until" field. 
+func LockedUntilNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldLockedUntil, vs...)) +} + +// LockedUntilGT applies the GT predicate on the "locked_until" field. +func LockedUntilGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldLockedUntil, v)) +} + +// LockedUntilGTE applies the GTE predicate on the "locked_until" field. +func LockedUntilGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldLockedUntil, v)) +} + +// LockedUntilLT applies the LT predicate on the "locked_until" field. +func LockedUntilLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldLockedUntil, v)) +} + +// LockedUntilLTE applies the LTE predicate on the "locked_until" field. +func LockedUntilLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldLockedUntil, v)) +} + +// LockedUntilIsNil applies the IsNil predicate on the "locked_until" field. +func LockedUntilIsNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIsNull(FieldLockedUntil)) +} + +// LockedUntilNotNil applies the NotNil predicate on the "locked_until" field. +func LockedUntilNotNil() predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotNull(FieldLockedUntil)) +} + +// ExpiresAtEQ applies the EQ predicate on the "expires_at" field. +func ExpiresAtEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldEQ(FieldExpiresAt, v)) +} + +// ExpiresAtNEQ applies the NEQ predicate on the "expires_at" field. +func ExpiresAtNEQ(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNEQ(FieldExpiresAt, v)) +} + +// ExpiresAtIn applies the In predicate on the "expires_at" field. 
+func ExpiresAtIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtNotIn applies the NotIn predicate on the "expires_at" field. +func ExpiresAtNotIn(vs ...time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldNotIn(FieldExpiresAt, vs...)) +} + +// ExpiresAtGT applies the GT predicate on the "expires_at" field. +func ExpiresAtGT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGT(FieldExpiresAt, v)) +} + +// ExpiresAtGTE applies the GTE predicate on the "expires_at" field. +func ExpiresAtGTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldGTE(FieldExpiresAt, v)) +} + +// ExpiresAtLT applies the LT predicate on the "expires_at" field. +func ExpiresAtLT(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLT(FieldExpiresAt, v)) +} + +// ExpiresAtLTE applies the LTE predicate on the "expires_at" field. +func ExpiresAtLTE(v time.Time) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.FieldLTE(FieldExpiresAt, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. 
+func Not(p predicate.IdempotencyRecord) predicate.IdempotencyRecord { + return predicate.IdempotencyRecord(sql.NotPredicates(p)) +} diff --git a/backend/ent/idempotencyrecord_create.go b/backend/ent/idempotencyrecord_create.go new file mode 100644 index 00000000..bf4deaf2 --- /dev/null +++ b/backend/ent/idempotencyrecord_create.go @@ -0,0 +1,1132 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" +) + +// IdempotencyRecordCreate is the builder for creating a IdempotencyRecord entity. +type IdempotencyRecordCreate struct { + config + mutation *IdempotencyRecordMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *IdempotencyRecordCreate) SetCreatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableCreatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *IdempotencyRecordCreate) SetUpdatedAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableUpdatedAt(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetScope sets the "scope" field. +func (_c *IdempotencyRecordCreate) SetScope(v string) *IdempotencyRecordCreate { + _c.mutation.SetScope(v) + return _c +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (_c *IdempotencyRecordCreate) SetIdempotencyKeyHash(v string) *IdempotencyRecordCreate { + _c.mutation.SetIdempotencyKeyHash(v) + return _c +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_c *IdempotencyRecordCreate) SetRequestFingerprint(v string) *IdempotencyRecordCreate { + _c.mutation.SetRequestFingerprint(v) + return _c +} + +// SetStatus sets the "status" field. +func (_c *IdempotencyRecordCreate) SetStatus(v string) *IdempotencyRecordCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetResponseStatus sets the "response_status" field. +func (_c *IdempotencyRecordCreate) SetResponseStatus(v int) *IdempotencyRecordCreate { + _c.mutation.SetResponseStatus(v) + return _c +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseStatus(v *int) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseStatus(*v) + } + return _c +} + +// SetResponseBody sets the "response_body" field. +func (_c *IdempotencyRecordCreate) SetResponseBody(v string) *IdempotencyRecordCreate { + _c.mutation.SetResponseBody(v) + return _c +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableResponseBody(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetResponseBody(*v) + } + return _c +} + +// SetErrorReason sets the "error_reason" field. +func (_c *IdempotencyRecordCreate) SetErrorReason(v string) *IdempotencyRecordCreate { + _c.mutation.SetErrorReason(v) + return _c +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableErrorReason(v *string) *IdempotencyRecordCreate { + if v != nil { + _c.SetErrorReason(*v) + } + return _c +} + +// SetLockedUntil sets the "locked_until" field. 
+func (_c *IdempotencyRecordCreate) SetLockedUntil(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetLockedUntil(v) + return _c +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_c *IdempotencyRecordCreate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordCreate { + if v != nil { + _c.SetLockedUntil(*v) + } + return _c +} + +// SetExpiresAt sets the "expires_at" field. +func (_c *IdempotencyRecordCreate) SetExpiresAt(v time.Time) *IdempotencyRecordCreate { + _c.mutation.SetExpiresAt(v) + return _c +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_c *IdempotencyRecordCreate) Mutation() *IdempotencyRecordMutation { + return _c.mutation +} + +// Save creates the IdempotencyRecord in the database. +func (_c *IdempotencyRecordCreate) Save(ctx context.Context) (*IdempotencyRecord, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *IdempotencyRecordCreate) SaveX(ctx context.Context) *IdempotencyRecord { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *IdempotencyRecordCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *IdempotencyRecordCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *IdempotencyRecordCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := idempotencyrecord.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
func (_c *IdempotencyRecordCreate) check() error {
	if _, ok := _c.mutation.CreatedAt(); !ok {
		return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "IdempotencyRecord.created_at"`)}
	}
	if _, ok := _c.mutation.UpdatedAt(); !ok {
		return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "IdempotencyRecord.updated_at"`)}
	}
	if _, ok := _c.mutation.Scope(); !ok {
		return &ValidationError{Name: "scope", err: errors.New(`ent: missing required field "IdempotencyRecord.scope"`)}
	}
	if v, ok := _c.mutation.Scope(); ok {
		if err := idempotencyrecord.ScopeValidator(v); err != nil {
			return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)}
		}
	}
	if _, ok := _c.mutation.IdempotencyKeyHash(); !ok {
		return &ValidationError{Name: "idempotency_key_hash", err: errors.New(`ent: missing required field "IdempotencyRecord.idempotency_key_hash"`)}
	}
	if v, ok := _c.mutation.IdempotencyKeyHash(); ok {
		if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil {
			return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)}
		}
	}
	if _, ok := _c.mutation.RequestFingerprint(); !ok {
		return &ValidationError{Name: "request_fingerprint", err: errors.New(`ent: missing required field "IdempotencyRecord.request_fingerprint"`)}
	}
	if v, ok := _c.mutation.RequestFingerprint(); ok {
		if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil {
			return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)}
		}
	}
	if _, ok := _c.mutation.Status(); !ok {
		return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "IdempotencyRecord.status"`)}
	}
	if v, ok := _c.mutation.Status(); ok {
		if err := idempotencyrecord.StatusValidator(v); err != nil {
			return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)}
		}
	}
	// error_reason is optional: it is validated only when the caller set it.
	if v, ok := _c.mutation.ErrorReason(); ok {
		if err := idempotencyrecord.ErrorReasonValidator(v); err != nil {
			return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)}
		}
	}
	if _, ok := _c.mutation.ExpiresAt(); !ok {
		return &ValidationError{Name: "expires_at", err: errors.New(`ent: missing required field "IdempotencyRecord.expires_at"`)}
	}
	return nil
}

// sqlSave validates the builder, performs the INSERT via sqlgraph, and
// reads the database-assigned ID back into the returned node.
func (_c *IdempotencyRecordCreate) sqlSave(ctx context.Context) (*IdempotencyRecord, error) {
	if err := _c.check(); err != nil {
		return nil, err
	}
	_node, _spec := _c.createSpec()
	if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil {
		// Surface unique/foreign-key violations as ent's ConstraintError type.
		if sqlgraph.IsConstraintError(err) {
			err = &ConstraintError{msg: err.Error(), wrap: err}
		}
		return nil, err
	}
	// The driver stores the auto-assigned key in _spec.ID.Value after the insert.
	id := _spec.ID.Value.(int64)
	_node.ID = int64(id)
	_c.mutation.id = &_node.ID
	_c.mutation.done = true
	return _node, nil
}

// createSpec builds the in-memory node and the sqlgraph insert spec from the
// mutation; only fields that were actually set on the mutation are written.
func (_c *IdempotencyRecordCreate) createSpec() (*IdempotencyRecord, *sqlgraph.CreateSpec) {
	var (
		_node = &IdempotencyRecord{config: _c.config}
		_spec = sqlgraph.NewCreateSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64))
	)
	_spec.OnConflict = _c.conflict
	if value, ok := _c.mutation.CreatedAt(); ok {
		_spec.SetField(idempotencyrecord.FieldCreatedAt, field.TypeTime, value)
		_node.CreatedAt = value
	}
	if value, ok := _c.mutation.UpdatedAt(); ok {
		_spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value)
		_node.UpdatedAt = value
	}
	if value, ok := _c.mutation.Scope(); ok {
		_spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value)
		_node.Scope = value
	}
	if value, ok := _c.mutation.IdempotencyKeyHash(); ok {
		_spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value)
		_node.IdempotencyKeyHash = value
	}
	if value, ok := _c.mutation.RequestFingerprint(); ok {
		_spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value)
		_node.RequestFingerprint = value
	}
	if value, ok := _c.mutation.Status(); ok {
		_spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value)
		_node.Status = value
	}
	if value, ok := _c.mutation.ResponseStatus(); ok {
		_spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value)
		// Nullable column: the node stores a pointer.
		_node.ResponseStatus = &value
	}
	if value, ok := _c.mutation.ResponseBody(); ok {
		_spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value)
		_node.ResponseBody = &value
	}
	if value, ok := _c.mutation.ErrorReason(); ok {
		_spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value)
		_node.ErrorReason = &value
	}
	if value, ok := _c.mutation.LockedUntil(); ok {
		_spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value)
		_node.LockedUntil = &value
	}
	if value, ok := _c.mutation.ExpiresAt(); ok {
		_spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value)
		_node.ExpiresAt = value
	}
	return _node, _spec
}

// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
//	client.IdempotencyRecord.Create().
//		SetCreatedAt(v).
//		OnConflict(
//			// Update the row with the new values
//			// that was proposed for insertion.
//			sql.ResolveWithNewValues(),
//		).
//		// Override some of the fields with custom
//		// update values.
//		Update(func(u *ent.IdempotencyRecordUpsert) {
//			SetCreatedAt(v+v).
//		}).
//	Exec(ctx)
func (_c *IdempotencyRecordCreate) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertOne {
	_c.conflict = opts
	return &IdempotencyRecordUpsertOne{
		create: _c,
	}
}

// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target. Using this option is equivalent to using:
//
//	client.IdempotencyRecord.Create().
//		OnConflict(sql.ConflictColumns(columns...)).
//		Exec(ctx)
func (_c *IdempotencyRecordCreate) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertOne {
	_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
	return &IdempotencyRecordUpsertOne{
		create: _c,
	}
}

type (
	// IdempotencyRecordUpsertOne is the builder for "upsert"-ing
	// one IdempotencyRecord node.
	IdempotencyRecordUpsertOne struct {
		create *IdempotencyRecordCreate
	}

	// IdempotencyRecordUpsert is the "OnConflict" setter.
	IdempotencyRecordUpsert struct {
		*sql.UpdateSet
	}
)

// SetUpdatedAt sets the "updated_at" field.
func (u *IdempotencyRecordUpsert) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsert {
	u.Set(idempotencyrecord.FieldUpdatedAt, v)
	return u
}

// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
func (u *IdempotencyRecordUpsert) UpdateUpdatedAt() *IdempotencyRecordUpsert {
	u.SetExcluded(idempotencyrecord.FieldUpdatedAt)
	return u
}

// SetScope sets the "scope" field.
func (u *IdempotencyRecordUpsert) SetScope(v string) *IdempotencyRecordUpsert {
	u.Set(idempotencyrecord.FieldScope, v)
	return u
}

// UpdateScope sets the "scope" field to the value that was provided on create.
func (u *IdempotencyRecordUpsert) UpdateScope() *IdempotencyRecordUpsert {
	u.SetExcluded(idempotencyrecord.FieldScope)
	return u
}

// SetIdempotencyKeyHash sets the "idempotency_key_hash" field.
+func (u *IdempotencyRecordUpsert) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldIdempotencyKeyHash, v) + return u +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldIdempotencyKeyHash) + return u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsert) SetRequestFingerprint(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldRequestFingerprint, v) + return u +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateRequestFingerprint() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldRequestFingerprint) + return u +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsert) SetStatus(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldStatus, v) + return u +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldStatus) + return u +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsert) SetResponseStatus(v int) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseStatus() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseStatus) + return u +} + +// AddResponseStatus adds v to the "response_status" field. 
+func (u *IdempotencyRecordUpsert) AddResponseStatus(v int) *IdempotencyRecordUpsert { + u.Add(idempotencyrecord.FieldResponseStatus, v) + return u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsert) ClearResponseStatus() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseStatus) + return u +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsert) SetResponseBody(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldResponseBody, v) + return u +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateResponseBody() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldResponseBody) + return u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsert) ClearResponseBody() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldResponseBody) + return u +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsert) SetErrorReason(v string) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldErrorReason, v) + return u +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateErrorReason() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldErrorReason) + return u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsert) ClearErrorReason() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldErrorReason) + return u +} + +// SetLockedUntil sets the "locked_until" field. +func (u *IdempotencyRecordUpsert) SetLockedUntil(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldLockedUntil, v) + return u +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsert) UpdateLockedUntil() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldLockedUntil) + return u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsert) ClearLockedUntil() *IdempotencyRecordUpsert { + u.SetNull(idempotencyrecord.FieldLockedUntil) + return u +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsert) SetExpiresAt(v time.Time) *IdempotencyRecordUpsert { + u.Set(idempotencyrecord.FieldExpiresAt, v) + return u +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsert) UpdateExpiresAt() *IdempotencyRecordUpsert { + u.SetExcluded(idempotencyrecord.FieldExpiresAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) UpdateNewValues() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertOne) Ignore() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. 
+func (u *IdempotencyRecordUpsertOne) DoNothing() *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the IdempotencyRecordCreate.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertOne) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateUpdatedAt() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertOne) SetScope(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateScope() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsertOne) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (u *IdempotencyRecordUpsertOne) SetRequestFingerprint(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateRequestFingerprint() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertOne) SetStatus(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertOne) SetResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertOne) AddResponseStatus(v int) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. 
+func (u *IdempotencyRecordUpsertOne) UpdateResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseStatus() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertOne) SetResponseBody(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertOne) ClearResponseBody() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) SetErrorReason(v string) *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertOne) UpdateErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertOne) ClearErrorReason() *IdempotencyRecordUpsertOne { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. 
func (u *IdempotencyRecordUpsertOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertOne {
	return u.Update(func(s *IdempotencyRecordUpsert) {
		s.SetLockedUntil(v)
	})
}

// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create.
func (u *IdempotencyRecordUpsertOne) UpdateLockedUntil() *IdempotencyRecordUpsertOne {
	return u.Update(func(s *IdempotencyRecordUpsert) {
		s.UpdateLockedUntil()
	})
}

// ClearLockedUntil clears the value of the "locked_until" field.
func (u *IdempotencyRecordUpsertOne) ClearLockedUntil() *IdempotencyRecordUpsertOne {
	return u.Update(func(s *IdempotencyRecordUpsert) {
		s.ClearLockedUntil()
	})
}

// SetExpiresAt sets the "expires_at" field.
func (u *IdempotencyRecordUpsertOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertOne {
	return u.Update(func(s *IdempotencyRecordUpsert) {
		s.SetExpiresAt(v)
	})
}

// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create.
func (u *IdempotencyRecordUpsertOne) UpdateExpiresAt() *IdempotencyRecordUpsertOne {
	return u.Update(func(s *IdempotencyRecordUpsert) {
		s.UpdateExpiresAt()
	})
}

// Exec executes the query. It fails when no conflict options were configured
// via OnConflict / OnConflictColumns.
func (u *IdempotencyRecordUpsertOne) Exec(ctx context.Context) error {
	if len(u.create.conflict) == 0 {
		return errors.New("ent: missing options for IdempotencyRecordCreate.OnConflict")
	}
	return u.create.Exec(ctx)
}

// ExecX is like Exec, but panics if an error occurs.
func (u *IdempotencyRecordUpsertOne) ExecX(ctx context.Context) {
	if err := u.create.Exec(ctx); err != nil {
		panic(err)
	}
}

// ID executes the UPSERT query and returns the inserted/updated ID.
func (u *IdempotencyRecordUpsertOne) ID(ctx context.Context) (id int64, err error) {
	node, err := u.create.Save(ctx)
	if err != nil {
		return id, err
	}
	return node.ID, nil
}

// IDX is like ID, but panics if an error occurs.
func (u *IdempotencyRecordUpsertOne) IDX(ctx context.Context) int64 {
	id, err := u.ID(ctx)
	if err != nil {
		panic(err)
	}
	return id
}

// IdempotencyRecordCreateBulk is the builder for creating many IdempotencyRecord entities in bulk.
type IdempotencyRecordCreateBulk struct {
	config
	err      error
	builders []*IdempotencyRecordCreate
	conflict []sql.ConflictOption
}

// Save creates the IdempotencyRecord entities in the database.
// Each builder's hooks are chained so that the actual batch INSERT runs once,
// inside the innermost mutator of the last builder.
func (_c *IdempotencyRecordCreateBulk) Save(ctx context.Context) ([]*IdempotencyRecord, error) {
	if _c.err != nil {
		return nil, _c.err
	}
	specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
	nodes := make([]*IdempotencyRecord, len(_c.builders))
	mutators := make([]Mutator, len(_c.builders))
	for i := range _c.builders {
		func(i int, root context.Context) {
			builder := _c.builders[i]
			builder.defaults()
			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
				mutation, ok := m.(*IdempotencyRecordMutation)
				if !ok {
					return nil, fmt.Errorf("unexpected mutation type %T", m)
				}
				if err := builder.check(); err != nil {
					return nil, err
				}
				builder.mutation = mutation
				var err error
				nodes[i], specs[i] = builder.createSpec()
				if i < len(mutators)-1 {
					// Not the last builder: delegate to the next builder's mutator chain.
					_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
				} else {
					spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
					spec.OnConflict = _c.conflict
					// Invoke the actual operation on the latest mutation in the chain.
					if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
						if sqlgraph.IsConstraintError(err) {
							err = &ConstraintError{msg: err.Error(), wrap: err}
						}
					}
				}
				if err != nil {
					return nil, err
				}
				mutation.id = &nodes[i].ID
				if specs[i].ID.Value != nil {
					id := specs[i].ID.Value.(int64)
					nodes[i].ID = int64(id)
				}
				mutation.done = true
				return nodes[i], nil
			})
			// Wrap the mutator with this builder's hooks, outermost hook first.
			for i := len(builder.hooks) - 1; i >= 0; i-- {
				mut = builder.hooks[i](mut)
			}
			mutators[i] = mut
		}(i, ctx)
	}
	if len(mutators) > 0 {
		if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
			return nil, err
		}
	}
	return nodes, nil
}

// SaveX is like Save, but panics if an error occurs.
func (_c *IdempotencyRecordCreateBulk) SaveX(ctx context.Context) []*IdempotencyRecord {
	v, err := _c.Save(ctx)
	if err != nil {
		panic(err)
	}
	return v
}

// Exec executes the query.
func (_c *IdempotencyRecordCreateBulk) Exec(ctx context.Context) error {
	_, err := _c.Save(ctx)
	return err
}

// ExecX is like Exec, but panics if an error occurs.
func (_c *IdempotencyRecordCreateBulk) ExecX(ctx context.Context) {
	if err := _c.Exec(ctx); err != nil {
		panic(err)
	}
}

// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
// of the `INSERT` statement. For example:
//
//	client.IdempotencyRecord.CreateBulk(builders...).
//		OnConflict(
//			// Update the row with the new values
//			// that was proposed for insertion.
//			sql.ResolveWithNewValues(),
//		).
//		// Override some of the fields with custom
//		// update values.
//		Update(func(u *ent.IdempotencyRecordUpsert) {
//			SetCreatedAt(v+v).
//		}).
//		Exec(ctx)
func (_c *IdempotencyRecordCreateBulk) OnConflict(opts ...sql.ConflictOption) *IdempotencyRecordUpsertBulk {
	_c.conflict = opts
	return &IdempotencyRecordUpsertBulk{
		create: _c,
	}
}

// OnConflictColumns calls `OnConflict` and configures the columns
// as conflict target.
Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *IdempotencyRecordCreateBulk) OnConflictColumns(columns ...string) *IdempotencyRecordUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &IdempotencyRecordUpsertBulk{ + create: _c, + } +} + +// IdempotencyRecordUpsertBulk is the builder for "upsert"-ing +// a bulk of IdempotencyRecord nodes. +type IdempotencyRecordUpsertBulk struct { + create *IdempotencyRecordCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) UpdateNewValues() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(idempotencyrecord.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.IdempotencyRecord.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *IdempotencyRecordUpsertBulk) Ignore() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *IdempotencyRecordUpsertBulk) DoNothing() *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. 
See the IdempotencyRecordCreateBulk.OnConflict +// documentation for more info. +func (u *IdempotencyRecordUpsertBulk) Update(set func(*IdempotencyRecordUpsert)) *IdempotencyRecordUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&IdempotencyRecordUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *IdempotencyRecordUpsertBulk) SetUpdatedAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateUpdatedAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetScope sets the "scope" field. +func (u *IdempotencyRecordUpsertBulk) SetScope(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetScope(v) + }) +} + +// UpdateScope sets the "scope" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateScope() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateScope() + }) +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (u *IdempotencyRecordUpsertBulk) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetIdempotencyKeyHash(v) + }) +} + +// UpdateIdempotencyKeyHash sets the "idempotency_key_hash" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateIdempotencyKeyHash() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateIdempotencyKeyHash() + }) +} + +// SetRequestFingerprint sets the "request_fingerprint" field. 
+func (u *IdempotencyRecordUpsertBulk) SetRequestFingerprint(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetRequestFingerprint(v) + }) +} + +// UpdateRequestFingerprint sets the "request_fingerprint" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateRequestFingerprint() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateRequestFingerprint() + }) +} + +// SetStatus sets the "status" field. +func (u *IdempotencyRecordUpsertBulk) SetStatus(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateStatus() + }) +} + +// SetResponseStatus sets the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseStatus(v) + }) +} + +// AddResponseStatus adds v to the "response_status" field. +func (u *IdempotencyRecordUpsertBulk) AddResponseStatus(v int) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.AddResponseStatus(v) + }) +} + +// UpdateResponseStatus sets the "response_status" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseStatus() + }) +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (u *IdempotencyRecordUpsertBulk) ClearResponseStatus() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseStatus() + }) +} + +// SetResponseBody sets the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) SetResponseBody(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetResponseBody(v) + }) +} + +// UpdateResponseBody sets the "response_body" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateResponseBody() + }) +} + +// ClearResponseBody clears the value of the "response_body" field. +func (u *IdempotencyRecordUpsertBulk) ClearResponseBody() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearResponseBody() + }) +} + +// SetErrorReason sets the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) SetErrorReason(v string) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetErrorReason(v) + }) +} + +// UpdateErrorReason sets the "error_reason" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateErrorReason() + }) +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (u *IdempotencyRecordUpsertBulk) ClearErrorReason() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearErrorReason() + }) +} + +// SetLockedUntil sets the "locked_until" field. 
+func (u *IdempotencyRecordUpsertBulk) SetLockedUntil(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetLockedUntil(v) + }) +} + +// UpdateLockedUntil sets the "locked_until" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateLockedUntil() + }) +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (u *IdempotencyRecordUpsertBulk) ClearLockedUntil() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.ClearLockedUntil() + }) +} + +// SetExpiresAt sets the "expires_at" field. +func (u *IdempotencyRecordUpsertBulk) SetExpiresAt(v time.Time) *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.SetExpiresAt(v) + }) +} + +// UpdateExpiresAt sets the "expires_at" field to the value that was provided on create. +func (u *IdempotencyRecordUpsertBulk) UpdateExpiresAt() *IdempotencyRecordUpsertBulk { + return u.Update(func(s *IdempotencyRecordUpsert) { + s.UpdateExpiresAt() + }) +} + +// Exec executes the query. +func (u *IdempotencyRecordUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the IdempotencyRecordCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for IdempotencyRecordCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (u *IdempotencyRecordUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_delete.go b/backend/ent/idempotencyrecord_delete.go new file mode 100644 index 00000000..f5c87559 --- /dev/null +++ b/backend/ent/idempotencyrecord_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordDelete is the builder for deleting a IdempotencyRecord entity. +type IdempotencyRecordDelete struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDelete) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *IdempotencyRecordDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_d *IdempotencyRecordDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *IdempotencyRecordDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(idempotencyrecord.Table, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// IdempotencyRecordDeleteOne is the builder for deleting a single IdempotencyRecord entity. +type IdempotencyRecordDeleteOne struct { + _d *IdempotencyRecordDelete +} + +// Where appends a list predicates to the IdempotencyRecordDelete builder. +func (_d *IdempotencyRecordDeleteOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *IdempotencyRecordDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{idempotencyrecord.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *IdempotencyRecordDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/idempotencyrecord_query.go b/backend/ent/idempotencyrecord_query.go new file mode 100644 index 00000000..fbba4dfa --- /dev/null +++ b/backend/ent/idempotencyrecord_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. 
+ +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordQuery is the builder for querying IdempotencyRecord entities. +type IdempotencyRecordQuery struct { + config + ctx *QueryContext + order []idempotencyrecord.OrderOption + inters []Interceptor + predicates []predicate.IdempotencyRecord + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the IdempotencyRecordQuery builder. +func (_q *IdempotencyRecordQuery) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *IdempotencyRecordQuery) Limit(limit int) *IdempotencyRecordQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *IdempotencyRecordQuery) Offset(offset int) *IdempotencyRecordQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *IdempotencyRecordQuery) Unique(unique bool) *IdempotencyRecordQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *IdempotencyRecordQuery) Order(o ...idempotencyrecord.OrderOption) *IdempotencyRecordQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first IdempotencyRecord entity from the query. +// Returns a *NotFoundError when no IdempotencyRecord was found. 
+func (_q *IdempotencyRecordQuery) First(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{idempotencyrecord.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstX(ctx context.Context) *IdempotencyRecord { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first IdempotencyRecord ID from the query. +// Returns a *NotFoundError when no IdempotencyRecord ID was found. +func (_q *IdempotencyRecordQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{idempotencyrecord.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single IdempotencyRecord entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one IdempotencyRecord entity is found. +// Returns a *NotFoundError when no IdempotencyRecord entities are found. +func (_q *IdempotencyRecordQuery) Only(ctx context.Context) (*IdempotencyRecord, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{idempotencyrecord.Label} + default: + return nil, &NotSingularError{idempotencyrecord.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. 
+func (_q *IdempotencyRecordQuery) OnlyX(ctx context.Context) *IdempotencyRecord { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only IdempotencyRecord ID in the query. +// Returns a *NotSingularError when more than one IdempotencyRecord ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *IdempotencyRecordQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{idempotencyrecord.Label} + default: + err = &NotSingularError{idempotencyrecord.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of IdempotencyRecords. +func (_q *IdempotencyRecordQuery) All(ctx context.Context) ([]*IdempotencyRecord, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*IdempotencyRecord, *IdempotencyRecordQuery]() + return withInterceptors[[]*IdempotencyRecord](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) AllX(ctx context.Context) []*IdempotencyRecord { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of IdempotencyRecord IDs. 
+func (_q *IdempotencyRecordQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(idempotencyrecord.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *IdempotencyRecordQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*IdempotencyRecordQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *IdempotencyRecordQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *IdempotencyRecordQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the IdempotencyRecordQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *IdempotencyRecordQuery) Clone() *IdempotencyRecordQuery { + if _q == nil { + return nil + } + return &IdempotencyRecordQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]idempotencyrecord.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.IdempotencyRecord{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// GroupBy(idempotencyrecord.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) GroupBy(field string, fields ...string) *IdempotencyRecordGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &IdempotencyRecordGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = idempotencyrecord.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.IdempotencyRecord.Query(). +// Select(idempotencyrecord.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *IdempotencyRecordQuery) Select(fields ...string) *IdempotencyRecordSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &IdempotencyRecordSelect{IdempotencyRecordQuery: _q} + sbuild.label = idempotencyrecord.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a IdempotencyRecordSelect configured with the given aggregations. 
+func (_q *IdempotencyRecordQuery) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *IdempotencyRecordQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !idempotencyrecord.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *IdempotencyRecordQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*IdempotencyRecord, error) { + var ( + nodes = []*IdempotencyRecord{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*IdempotencyRecord).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &IdempotencyRecord{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *IdempotencyRecordQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *IdempotencyRecordQuery) querySpec() *sqlgraph.QuerySpec { + _spec := 
sqlgraph.NewQuerySpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for i := range fields { + if fields[i] != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *IdempotencyRecordQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(idempotencyrecord.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = idempotencyrecord.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. 
+ selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *IdempotencyRecordQuery) ForUpdate(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *IdempotencyRecordQuery) ForShare(opts ...sql.LockOption) *IdempotencyRecordQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// IdempotencyRecordGroupBy is the group-by builder for IdempotencyRecord entities. +type IdempotencyRecordGroupBy struct { + selector + build *IdempotencyRecordQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *IdempotencyRecordGroupBy) Aggregate(fns ...AggregateFunc) *IdempotencyRecordGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_g *IdempotencyRecordGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *IdempotencyRecordGroupBy) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// IdempotencyRecordSelect is the builder for selecting fields of IdempotencyRecord entities. +type IdempotencyRecordSelect struct { + *IdempotencyRecordQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *IdempotencyRecordSelect) Aggregate(fns ...AggregateFunc) *IdempotencyRecordSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. 
+func (_s *IdempotencyRecordSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*IdempotencyRecordQuery, *IdempotencyRecordSelect](ctx, _s.IdempotencyRecordQuery, _s, _s.inters, v) +} + +func (_s *IdempotencyRecordSelect) sqlScan(ctx context.Context, root *IdempotencyRecordQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/idempotencyrecord_update.go b/backend/ent/idempotencyrecord_update.go new file mode 100644 index 00000000..f839e5c0 --- /dev/null +++ b/backend/ent/idempotencyrecord_update.go @@ -0,0 +1,676 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// IdempotencyRecordUpdate is the builder for updating IdempotencyRecord entities. +type IdempotencyRecordUpdate struct { + config + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdate) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. 
+func (_u *IdempotencyRecordUpdate) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdate) SetScope(v string) *IdempotencyRecordUpdate { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableScope(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdate) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdate { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdate) SetRequestFingerprint(v string) *IdempotencyRecordUpdate { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdate) SetStatus(v string) *IdempotencyRecordUpdate { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableStatus(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. 
+func (_u *IdempotencyRecordUpdate) SetResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdate) AddResponseStatus(v int) *IdempotencyRecordUpdate { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (_u *IdempotencyRecordUpdate) ClearResponseStatus() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdate) SetResponseBody(v string) *IdempotencyRecordUpdate { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableResponseBody(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdate) ClearResponseBody() *IdempotencyRecordUpdate { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdate) SetErrorReason(v string) *IdempotencyRecordUpdate { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableErrorReason(v *string) *IdempotencyRecordUpdate { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. 
+func (_u *IdempotencyRecordUpdate) ClearErrorReason() *IdempotencyRecordUpdate { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdate) SetLockedUntil(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (_u *IdempotencyRecordUpdate) ClearLockedUntil() *IdempotencyRecordUpdate { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdate) SetExpiresAt(v time.Time) *IdempotencyRecordUpdate { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdate) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdate { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdate) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *IdempotencyRecordUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. 
+func (_u *IdempotencyRecordUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *IdempotencyRecordUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdate) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + 
return nil +} + +func (_u *IdempotencyRecordUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + 
_spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// IdempotencyRecordUpdateOne is the builder for updating a single IdempotencyRecord entity. +type IdempotencyRecordUpdateOne struct { + config + fields []string + hooks []Hook + mutation *IdempotencyRecordMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *IdempotencyRecordUpdateOne) SetUpdatedAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetScope sets the "scope" field. +func (_u *IdempotencyRecordUpdateOne) SetScope(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetScope(v) + return _u +} + +// SetNillableScope sets the "scope" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableScope(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetScope(*v) + } + return _u +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. +func (_u *IdempotencyRecordUpdateOne) SetIdempotencyKeyHash(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetIdempotencyKeyHash(v) + return _u +} + +// SetNillableIdempotencyKeyHash sets the "idempotency_key_hash" field if the given value is not nil. 
+func (_u *IdempotencyRecordUpdateOne) SetNillableIdempotencyKeyHash(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetIdempotencyKeyHash(*v) + } + return _u +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (_u *IdempotencyRecordUpdateOne) SetRequestFingerprint(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetRequestFingerprint(v) + return _u +} + +// SetNillableRequestFingerprint sets the "request_fingerprint" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableRequestFingerprint(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetRequestFingerprint(*v) + } + return _u +} + +// SetStatus sets the "status" field. +func (_u *IdempotencyRecordUpdateOne) SetStatus(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableStatus(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetResponseStatus sets the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.ResetResponseStatus() + _u.mutation.SetResponseStatus(v) + return _u +} + +// SetNillableResponseStatus sets the "response_status" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseStatus(v *int) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseStatus(*v) + } + return _u +} + +// AddResponseStatus adds value to the "response_status" field. +func (_u *IdempotencyRecordUpdateOne) AddResponseStatus(v int) *IdempotencyRecordUpdateOne { + _u.mutation.AddResponseStatus(v) + return _u +} + +// ClearResponseStatus clears the value of the "response_status" field. 
+func (_u *IdempotencyRecordUpdateOne) ClearResponseStatus() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseStatus() + return _u +} + +// SetResponseBody sets the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) SetResponseBody(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetResponseBody(v) + return _u +} + +// SetNillableResponseBody sets the "response_body" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableResponseBody(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetResponseBody(*v) + } + return _u +} + +// ClearResponseBody clears the value of the "response_body" field. +func (_u *IdempotencyRecordUpdateOne) ClearResponseBody() *IdempotencyRecordUpdateOne { + _u.mutation.ClearResponseBody() + return _u +} + +// SetErrorReason sets the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) SetErrorReason(v string) *IdempotencyRecordUpdateOne { + _u.mutation.SetErrorReason(v) + return _u +} + +// SetNillableErrorReason sets the "error_reason" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableErrorReason(v *string) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetErrorReason(*v) + } + return _u +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (_u *IdempotencyRecordUpdateOne) ClearErrorReason() *IdempotencyRecordUpdateOne { + _u.mutation.ClearErrorReason() + return _u +} + +// SetLockedUntil sets the "locked_until" field. +func (_u *IdempotencyRecordUpdateOne) SetLockedUntil(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetLockedUntil(v) + return _u +} + +// SetNillableLockedUntil sets the "locked_until" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableLockedUntil(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetLockedUntil(*v) + } + return _u +} + +// ClearLockedUntil clears the value of the "locked_until" field. 
+func (_u *IdempotencyRecordUpdateOne) ClearLockedUntil() *IdempotencyRecordUpdateOne { + _u.mutation.ClearLockedUntil() + return _u +} + +// SetExpiresAt sets the "expires_at" field. +func (_u *IdempotencyRecordUpdateOne) SetExpiresAt(v time.Time) *IdempotencyRecordUpdateOne { + _u.mutation.SetExpiresAt(v) + return _u +} + +// SetNillableExpiresAt sets the "expires_at" field if the given value is not nil. +func (_u *IdempotencyRecordUpdateOne) SetNillableExpiresAt(v *time.Time) *IdempotencyRecordUpdateOne { + if v != nil { + _u.SetExpiresAt(*v) + } + return _u +} + +// Mutation returns the IdempotencyRecordMutation object of the builder. +func (_u *IdempotencyRecordUpdateOne) Mutation() *IdempotencyRecordMutation { + return _u.mutation +} + +// Where appends a list predicates to the IdempotencyRecordUpdate builder. +func (_u *IdempotencyRecordUpdateOne) Where(ps ...predicate.IdempotencyRecord) *IdempotencyRecordUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *IdempotencyRecordUpdateOne) Select(field string, fields ...string) *IdempotencyRecordUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated IdempotencyRecord entity. +func (_u *IdempotencyRecordUpdateOne) Save(ctx context.Context) (*IdempotencyRecord, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *IdempotencyRecordUpdateOne) SaveX(ctx context.Context) *IdempotencyRecord { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *IdempotencyRecordUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. 
+func (_u *IdempotencyRecordUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *IdempotencyRecordUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := idempotencyrecord.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *IdempotencyRecordUpdateOne) check() error { + if v, ok := _u.mutation.Scope(); ok { + if err := idempotencyrecord.ScopeValidator(v); err != nil { + return &ValidationError{Name: "scope", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.scope": %w`, err)} + } + } + if v, ok := _u.mutation.IdempotencyKeyHash(); ok { + if err := idempotencyrecord.IdempotencyKeyHashValidator(v); err != nil { + return &ValidationError{Name: "idempotency_key_hash", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.idempotency_key_hash": %w`, err)} + } + } + if v, ok := _u.mutation.RequestFingerprint(); ok { + if err := idempotencyrecord.RequestFingerprintValidator(v); err != nil { + return &ValidationError{Name: "request_fingerprint", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.request_fingerprint": %w`, err)} + } + } + if v, ok := _u.mutation.Status(); ok { + if err := idempotencyrecord.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.status": %w`, err)} + } + } + if v, ok := _u.mutation.ErrorReason(); ok { + if err := idempotencyrecord.ErrorReasonValidator(v); err != nil { + return &ValidationError{Name: "error_reason", err: fmt.Errorf(`ent: validator failed for field "IdempotencyRecord.error_reason": %w`, err)} + } + } + return nil +} + +func (_u *IdempotencyRecordUpdateOne) sqlSave(ctx context.Context) (_node *IdempotencyRecord, err error) { + if err := _u.check(); err != nil { + 
return _node, err + } + _spec := sqlgraph.NewUpdateSpec(idempotencyrecord.Table, idempotencyrecord.Columns, sqlgraph.NewFieldSpec(idempotencyrecord.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "IdempotencyRecord.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, idempotencyrecord.FieldID) + for _, f := range fields { + if !idempotencyrecord.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != idempotencyrecord.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(idempotencyrecord.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Scope(); ok { + _spec.SetField(idempotencyrecord.FieldScope, field.TypeString, value) + } + if value, ok := _u.mutation.IdempotencyKeyHash(); ok { + _spec.SetField(idempotencyrecord.FieldIdempotencyKeyHash, field.TypeString, value) + } + if value, ok := _u.mutation.RequestFingerprint(); ok { + _spec.SetField(idempotencyrecord.FieldRequestFingerprint, field.TypeString, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(idempotencyrecord.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.ResponseStatus(); ok { + _spec.SetField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseStatus(); ok { + _spec.AddField(idempotencyrecord.FieldResponseStatus, field.TypeInt, value) + } + if _u.mutation.ResponseStatusCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseStatus, 
field.TypeInt) + } + if value, ok := _u.mutation.ResponseBody(); ok { + _spec.SetField(idempotencyrecord.FieldResponseBody, field.TypeString, value) + } + if _u.mutation.ResponseBodyCleared() { + _spec.ClearField(idempotencyrecord.FieldResponseBody, field.TypeString) + } + if value, ok := _u.mutation.ErrorReason(); ok { + _spec.SetField(idempotencyrecord.FieldErrorReason, field.TypeString, value) + } + if _u.mutation.ErrorReasonCleared() { + _spec.ClearField(idempotencyrecord.FieldErrorReason, field.TypeString) + } + if value, ok := _u.mutation.LockedUntil(); ok { + _spec.SetField(idempotencyrecord.FieldLockedUntil, field.TypeTime, value) + } + if _u.mutation.LockedUntilCleared() { + _spec.ClearField(idempotencyrecord.FieldLockedUntil, field.TypeTime) + } + if value, ok := _u.mutation.ExpiresAt(); ok { + _spec.SetField(idempotencyrecord.FieldExpiresAt, field.TypeTime, value) + } + _node = &IdempotencyRecord{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{idempotencyrecord.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index 290fb163..e7746402 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -276,6 +277,33 @@ func (f TraverseGroup) Traverse(ctx context.Context, q ent.Query) 
error { return fmt.Errorf("unexpected query type %T. expect *ent.GroupQuery", q) } +// The IdempotencyRecordFunc type is an adapter to allow the use of ordinary function as a Querier. +type IdempotencyRecordFunc func(context.Context, *ent.IdempotencyRecordQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f IdempotencyRecordFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + +// The TraverseIdempotencyRecord type is an adapter to allow the use of ordinary function as Traverser. +type TraverseIdempotencyRecord func(context.Context, *ent.IdempotencyRecordQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseIdempotencyRecord) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseIdempotencyRecord) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.IdempotencyRecordQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.IdempotencyRecordQuery", q) +} + // The PromoCodeFunc type is an adapter to allow the use of ordinary function as a Querier. 
type PromoCodeFunc func(context.Context, *ent.PromoCodeQuery) (ent.Value, error) @@ -644,6 +672,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil + case *ent.IdempotencyRecordQuery: + return &query[*ent.IdempotencyRecordQuery, predicate.IdempotencyRecord, idempotencyrecord.OrderOption]{typ: ent.TypeIdempotencyRecord, tq: q}, nil case *ent.PromoCodeQuery: return &query[*ent.PromoCodeQuery, predicate.PromoCode, promocode.OrderOption]{typ: ent.TypePromoCode, tq: q}, nil case *ent.PromoCodeUsageQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index aba00d4f..769dddce 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -108,6 +108,8 @@ var ( {Name: "rate_limited_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "rate_limit_reset_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "overload_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_until", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "temp_unschedulable_reason", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "session_window_start", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_end", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "session_window_status", Type: field.TypeString, Nullable: true, Size: 20}, @@ -121,7 +123,7 @@ var ( ForeignKeys: 
[]*schema.ForeignKey{ { Symbol: "accounts_proxies_proxy", - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, RefColumns: []*schema.Column{ProxiesColumns[0]}, OnDelete: schema.SetNull, }, @@ -145,7 +147,7 @@ var ( { Name: "account_proxy_id", Unique: false, - Columns: []*schema.Column{AccountsColumns[25]}, + Columns: []*schema.Column{AccountsColumns[27]}, }, { Name: "account_priority", @@ -177,6 +179,16 @@ var ( Unique: false, Columns: []*schema.Column{AccountsColumns[21]}, }, + { + Name: "account_platform_priority", + Unique: false, + Columns: []*schema.Column{AccountsColumns[6], AccountsColumns[11]}, + }, + { + Name: "account_priority_status", + Unique: false, + Columns: []*schema.Column{AccountsColumns[11], AccountsColumns[13]}, + }, { Name: "account_deleted_at", Unique: false, @@ -376,6 +388,7 @@ var ( {Name: "sora_image_price_540", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "sora_video_price_per_request", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "sora_video_price_per_request_hd", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, {Name: "claude_code_only", Type: field.TypeBool, Default: false}, {Name: "fallback_group_id", Type: field.TypeInt64, Nullable: true}, {Name: "fallback_group_id_on_invalid_request", Type: field.TypeInt64, Nullable: true}, @@ -419,7 +432,45 @@ var ( { Name: "group_sort_order", Unique: false, - Columns: []*schema.Column{GroupsColumns[29]}, + Columns: []*schema.Column{GroupsColumns[30]}, + }, + }, + } + // IdempotencyRecordsColumns holds the columns for the "idempotency_records" table. 
+ IdempotencyRecordsColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "scope", Type: field.TypeString, Size: 128}, + {Name: "idempotency_key_hash", Type: field.TypeString, Size: 64}, + {Name: "request_fingerprint", Type: field.TypeString, Size: 64}, + {Name: "status", Type: field.TypeString, Size: 32}, + {Name: "response_status", Type: field.TypeInt, Nullable: true}, + {Name: "response_body", Type: field.TypeString, Nullable: true}, + {Name: "error_reason", Type: field.TypeString, Nullable: true, Size: 128}, + {Name: "locked_until", Type: field.TypeTime, Nullable: true}, + {Name: "expires_at", Type: field.TypeTime}, + } + // IdempotencyRecordsTable holds the schema information for the "idempotency_records" table. + IdempotencyRecordsTable = &schema.Table{ + Name: "idempotency_records", + Columns: IdempotencyRecordsColumns, + PrimaryKey: []*schema.Column{IdempotencyRecordsColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "idempotencyrecord_scope_idempotency_key_hash", + Unique: true, + Columns: []*schema.Column{IdempotencyRecordsColumns[3], IdempotencyRecordsColumns[4]}, + }, + { + Name: "idempotencyrecord_expires_at", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[11]}, + }, + { + Name: "idempotencyrecord_status_locked_until", + Unique: false, + Columns: []*schema.Column{IdempotencyRecordsColumns[6], IdempotencyRecordsColumns[10]}, }, }, } @@ -771,6 +822,11 @@ var ( Unique: false, Columns: []*schema.Column{UsageLogsColumns[28], UsageLogsColumns[27]}, }, + { + Name: "usagelog_group_id_created_at", + Unique: false, + Columns: []*schema.Column{UsageLogsColumns[30], UsageLogsColumns[27]}, + }, }, } // UsersColumns holds the columns for the "users" table. 
@@ -790,6 +846,8 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, + {Name: "sora_storage_quota_bytes", Type: field.TypeInt64, Default: 0}, + {Name: "sora_storage_used_bytes", Type: field.TypeInt64, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ @@ -995,6 +1053,11 @@ var ( Unique: false, Columns: []*schema.Column{UserSubscriptionsColumns[5]}, }, + { + Name: "usersubscription_user_id_status_expires_at", + Unique: false, + Columns: []*schema.Column{UserSubscriptionsColumns[16], UserSubscriptionsColumns[6], UserSubscriptionsColumns[5]}, + }, { Name: "usersubscription_assigned_by", Unique: false, @@ -1021,6 +1084,7 @@ var ( AnnouncementReadsTable, ErrorPassthroughRulesTable, GroupsTable, + IdempotencyRecordsTable, PromoCodesTable, PromoCodeUsagesTable, ProxiesTable, @@ -1066,6 +1130,9 @@ func init() { GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } + IdempotencyRecordsTable.Annotation = &entsql.Annotation{ + Table: "idempotency_records", + } PromoCodesTable.Annotation = &entsql.Annotation{ Table: "promo_codes", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 7d5bf180..823cd389 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -19,6 +19,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -52,6 +53,7 @@ const ( TypeAnnouncementRead = "AnnouncementRead" TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" + TypeIdempotencyRecord = 
"IdempotencyRecord" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" TypeProxy = "Proxy" @@ -1503,48 +1505,50 @@ func (m *APIKeyMutation) ResetEdge(name string) error { // AccountMutation represents an operation that mutates the Account nodes in the graph. type AccountMutation struct { config - op Op - typ string - id *int64 - created_at *time.Time - updated_at *time.Time - deleted_at *time.Time - name *string - notes *string - platform *string - _type *string - credentials *map[string]interface{} - extra *map[string]interface{} - concurrency *int - addconcurrency *int - priority *int - addpriority *int - rate_multiplier *float64 - addrate_multiplier *float64 - status *string - error_message *string - last_used_at *time.Time - expires_at *time.Time - auto_pause_on_expired *bool - schedulable *bool - rate_limited_at *time.Time - rate_limit_reset_at *time.Time - overload_until *time.Time - session_window_start *time.Time - session_window_end *time.Time - session_window_status *string - clearedFields map[string]struct{} - groups map[int64]struct{} - removedgroups map[int64]struct{} - clearedgroups bool - proxy *int64 - clearedproxy bool - usage_logs map[int64]struct{} - removedusage_logs map[int64]struct{} - clearedusage_logs bool - done bool - oldValue func(context.Context) (*Account, error) - predicates []predicate.Account + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + deleted_at *time.Time + name *string + notes *string + platform *string + _type *string + credentials *map[string]interface{} + extra *map[string]interface{} + concurrency *int + addconcurrency *int + priority *int + addpriority *int + rate_multiplier *float64 + addrate_multiplier *float64 + status *string + error_message *string + last_used_at *time.Time + expires_at *time.Time + auto_pause_on_expired *bool + schedulable *bool + rate_limited_at *time.Time + rate_limit_reset_at *time.Time + overload_until *time.Time + temp_unschedulable_until 
*time.Time + temp_unschedulable_reason *string + session_window_start *time.Time + session_window_end *time.Time + session_window_status *string + clearedFields map[string]struct{} + groups map[int64]struct{} + removedgroups map[int64]struct{} + clearedgroups bool + proxy *int64 + clearedproxy bool + usage_logs map[int64]struct{} + removedusage_logs map[int64]struct{} + clearedusage_logs bool + done bool + oldValue func(context.Context) (*Account, error) + predicates []predicate.Account } var _ ent.Mutation = (*AccountMutation)(nil) @@ -2614,6 +2618,104 @@ func (m *AccountMutation) ResetOverloadUntil() { delete(m.clearedFields, account.FieldOverloadUntil) } +// SetTempUnschedulableUntil sets the "temp_unschedulable_until" field. +func (m *AccountMutation) SetTempUnschedulableUntil(t time.Time) { + m.temp_unschedulable_until = &t +} + +// TempUnschedulableUntil returns the value of the "temp_unschedulable_until" field in the mutation. +func (m *AccountMutation) TempUnschedulableUntil() (r time.Time, exists bool) { + v := m.temp_unschedulable_until + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableUntil returns the old "temp_unschedulable_until" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *AccountMutation) OldTempUnschedulableUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableUntil: %w", err) + } + return oldValue.TempUnschedulableUntil, nil +} + +// ClearTempUnschedulableUntil clears the value of the "temp_unschedulable_until" field. +func (m *AccountMutation) ClearTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + m.clearedFields[account.FieldTempUnschedulableUntil] = struct{}{} +} + +// TempUnschedulableUntilCleared returns if the "temp_unschedulable_until" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableUntilCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableUntil] + return ok +} + +// ResetTempUnschedulableUntil resets all changes to the "temp_unschedulable_until" field. +func (m *AccountMutation) ResetTempUnschedulableUntil() { + m.temp_unschedulable_until = nil + delete(m.clearedFields, account.FieldTempUnschedulableUntil) +} + +// SetTempUnschedulableReason sets the "temp_unschedulable_reason" field. +func (m *AccountMutation) SetTempUnschedulableReason(s string) { + m.temp_unschedulable_reason = &s +} + +// TempUnschedulableReason returns the value of the "temp_unschedulable_reason" field in the mutation. +func (m *AccountMutation) TempUnschedulableReason() (r string, exists bool) { + v := m.temp_unschedulable_reason + if v == nil { + return + } + return *v, true +} + +// OldTempUnschedulableReason returns the old "temp_unschedulable_reason" field's value of the Account entity. +// If the Account object wasn't provided to the builder, the object is fetched from the database. 
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *AccountMutation) OldTempUnschedulableReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldTempUnschedulableReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldTempUnschedulableReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldTempUnschedulableReason: %w", err) + } + return oldValue.TempUnschedulableReason, nil +} + +// ClearTempUnschedulableReason clears the value of the "temp_unschedulable_reason" field. +func (m *AccountMutation) ClearTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + m.clearedFields[account.FieldTempUnschedulableReason] = struct{}{} +} + +// TempUnschedulableReasonCleared returns if the "temp_unschedulable_reason" field was cleared in this mutation. +func (m *AccountMutation) TempUnschedulableReasonCleared() bool { + _, ok := m.clearedFields[account.FieldTempUnschedulableReason] + return ok +} + +// ResetTempUnschedulableReason resets all changes to the "temp_unschedulable_reason" field. +func (m *AccountMutation) ResetTempUnschedulableReason() { + m.temp_unschedulable_reason = nil + delete(m.clearedFields, account.FieldTempUnschedulableReason) +} + // SetSessionWindowStart sets the "session_window_start" field. func (m *AccountMutation) SetSessionWindowStart(t time.Time) { m.session_window_start = &t @@ -2930,7 +3032,7 @@ func (m *AccountMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *AccountMutation) Fields() []string { - fields := make([]string, 0, 25) + fields := make([]string, 0, 27) if m.created_at != nil { fields = append(fields, account.FieldCreatedAt) } @@ -2997,6 +3099,12 @@ func (m *AccountMutation) Fields() []string { if m.overload_until != nil { fields = append(fields, account.FieldOverloadUntil) } + if m.temp_unschedulable_until != nil { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.temp_unschedulable_reason != nil { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.session_window_start != nil { fields = append(fields, account.FieldSessionWindowStart) } @@ -3058,6 +3166,10 @@ func (m *AccountMutation) Field(name string) (ent.Value, bool) { return m.RateLimitResetAt() case account.FieldOverloadUntil: return m.OverloadUntil() + case account.FieldTempUnschedulableUntil: + return m.TempUnschedulableUntil() + case account.FieldTempUnschedulableReason: + return m.TempUnschedulableReason() case account.FieldSessionWindowStart: return m.SessionWindowStart() case account.FieldSessionWindowEnd: @@ -3117,6 +3229,10 @@ func (m *AccountMutation) OldField(ctx context.Context, name string) (ent.Value, return m.OldRateLimitResetAt(ctx) case account.FieldOverloadUntil: return m.OldOverloadUntil(ctx) + case account.FieldTempUnschedulableUntil: + return m.OldTempUnschedulableUntil(ctx) + case account.FieldTempUnschedulableReason: + return m.OldTempUnschedulableReason(ctx) case account.FieldSessionWindowStart: return m.OldSessionWindowStart(ctx) case account.FieldSessionWindowEnd: @@ -3286,6 +3402,20 @@ func (m *AccountMutation) SetField(name string, value ent.Value) error { } m.SetOverloadUntil(v) return nil + case account.FieldTempUnschedulableUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableUntil(v) + return nil + case account.FieldTempUnschedulableReason: + v, ok := value.(string) + if !ok { + 
return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetTempUnschedulableReason(v) + return nil case account.FieldSessionWindowStart: v, ok := value.(time.Time) if !ok { @@ -3403,6 +3533,12 @@ func (m *AccountMutation) ClearedFields() []string { if m.FieldCleared(account.FieldOverloadUntil) { fields = append(fields, account.FieldOverloadUntil) } + if m.FieldCleared(account.FieldTempUnschedulableUntil) { + fields = append(fields, account.FieldTempUnschedulableUntil) + } + if m.FieldCleared(account.FieldTempUnschedulableReason) { + fields = append(fields, account.FieldTempUnschedulableReason) + } if m.FieldCleared(account.FieldSessionWindowStart) { fields = append(fields, account.FieldSessionWindowStart) } @@ -3453,6 +3589,12 @@ func (m *AccountMutation) ClearField(name string) error { case account.FieldOverloadUntil: m.ClearOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ClearTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ClearTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ClearSessionWindowStart() return nil @@ -3536,6 +3678,12 @@ func (m *AccountMutation) ResetField(name string) error { case account.FieldOverloadUntil: m.ResetOverloadUntil() return nil + case account.FieldTempUnschedulableUntil: + m.ResetTempUnschedulableUntil() + return nil + case account.FieldTempUnschedulableReason: + m.ResetTempUnschedulableReason() + return nil case account.FieldSessionWindowStart: m.ResetSessionWindowStart() return nil @@ -7186,6 +7334,8 @@ type GroupMutation struct { addsora_video_price_per_request *float64 sora_video_price_per_request_hd *float64 addsora_video_price_per_request_hd *float64 + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 claude_code_only *bool fallback_group_id *int64 addfallback_group_id *int64 @@ -8482,6 +8632,62 @@ func (m *GroupMutation) ResetSoraVideoPricePerRequestHd() { delete(m.clearedFields, 
group.FieldSoraVideoPricePerRequestHd) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *GroupMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. +func (m *GroupMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. 
+func (m *GroupMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *GroupMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + // SetClaudeCodeOnly sets the "claude_code_only" field. func (m *GroupMutation) SetClaudeCodeOnly(b bool) { m.claude_code_only = &b @@ -9244,7 +9450,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 29) + fields := make([]string, 0, 30) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -9308,6 +9514,9 @@ func (m *GroupMutation) Fields() []string { if m.sora_video_price_per_request_hd != nil { fields = append(fields, group.FieldSoraVideoPricePerRequestHd) } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.claude_code_only != nil { fields = append(fields, group.FieldClaudeCodeOnly) } @@ -9382,6 +9591,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.SoraVideoPricePerRequest() case group.FieldSoraVideoPricePerRequestHd: return m.SoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() case group.FieldClaudeCodeOnly: return m.ClaudeCodeOnly() case group.FieldFallbackGroupID: @@ -9449,6 +9660,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldSoraVideoPricePerRequest(ctx) case group.FieldSoraVideoPricePerRequestHd: return m.OldSoraVideoPricePerRequestHd(ctx) + case group.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) case group.FieldClaudeCodeOnly: return m.OldClaudeCodeOnly(ctx) case 
group.FieldFallbackGroupID: @@ -9621,6 +9834,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetSoraVideoPricePerRequestHd(v) return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil case group.FieldClaudeCodeOnly: v, ok := value.(bool) if !ok { @@ -9721,6 +9941,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addsora_video_price_per_request_hd != nil { fields = append(fields, group.FieldSoraVideoPricePerRequestHd) } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, group.FieldSoraStorageQuotaBytes) + } if m.addfallback_group_id != nil { fields = append(fields, group.FieldFallbackGroupID) } @@ -9762,6 +9985,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedSoraVideoPricePerRequest() case group.FieldSoraVideoPricePerRequestHd: return m.AddedSoraVideoPricePerRequestHd() + case group.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() case group.FieldFallbackGroupID: return m.AddedFallbackGroupID() case group.FieldFallbackGroupIDOnInvalidRequest: @@ -9861,6 +10086,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddSoraVideoPricePerRequestHd(v) return nil + case group.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil case group.FieldFallbackGroupID: v, ok := value.(int64) if !ok { @@ -10065,6 +10297,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldSoraVideoPricePerRequestHd: m.ResetSoraVideoPricePerRequestHd() return nil + case group.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil case group.FieldClaudeCodeOnly: m.ResetClaudeCodeOnly() return nil @@ -10307,6 +10542,988 @@ func (m *GroupMutation) 
ResetEdge(name string) error { return fmt.Errorf("unknown Group edge %s", name) } +// IdempotencyRecordMutation represents an operation that mutates the IdempotencyRecord nodes in the graph. +type IdempotencyRecordMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + scope *string + idempotency_key_hash *string + request_fingerprint *string + status *string + response_status *int + addresponse_status *int + response_body *string + error_reason *string + locked_until *time.Time + expires_at *time.Time + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*IdempotencyRecord, error) + predicates []predicate.IdempotencyRecord +} + +var _ ent.Mutation = (*IdempotencyRecordMutation)(nil) + +// idempotencyrecordOption allows management of the mutation configuration using functional options. +type idempotencyrecordOption func(*IdempotencyRecordMutation) + +// newIdempotencyRecordMutation creates new mutation for the IdempotencyRecord entity. +func newIdempotencyRecordMutation(c config, op Op, opts ...idempotencyrecordOption) *IdempotencyRecordMutation { + m := &IdempotencyRecordMutation{ + config: c, + op: op, + typ: TypeIdempotencyRecord, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withIdempotencyRecordID sets the ID field of the mutation. +func withIdempotencyRecordID(id int64) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + var ( + err error + once sync.Once + value *IdempotencyRecord + ) + m.oldValue = func(ctx context.Context) (*IdempotencyRecord, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().IdempotencyRecord.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withIdempotencyRecord sets the old IdempotencyRecord of the mutation. 
+func withIdempotencyRecord(node *IdempotencyRecord) idempotencyrecordOption { + return func(m *IdempotencyRecordMutation) { + m.oldValue = func(context.Context) (*IdempotencyRecord, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m IdempotencyRecordMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m IdempotencyRecordMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *IdempotencyRecordMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *IdempotencyRecordMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().IdempotencyRecord.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. 
+func (m *IdempotencyRecordMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *IdempotencyRecordMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *IdempotencyRecordMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *IdempotencyRecordMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *IdempotencyRecordMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *IdempotencyRecordMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetScope sets the "scope" field. +func (m *IdempotencyRecordMutation) SetScope(s string) { + m.scope = &s +} + +// Scope returns the value of the "scope" field in the mutation. +func (m *IdempotencyRecordMutation) Scope() (r string, exists bool) { + v := m.scope + if v == nil { + return + } + return *v, true +} + +// OldScope returns the old "scope" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldScope(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldScope is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldScope requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldScope: %w", err) + } + return oldValue.Scope, nil +} + +// ResetScope resets all changes to the "scope" field. +func (m *IdempotencyRecordMutation) ResetScope() { + m.scope = nil +} + +// SetIdempotencyKeyHash sets the "idempotency_key_hash" field. 
+func (m *IdempotencyRecordMutation) SetIdempotencyKeyHash(s string) { + m.idempotency_key_hash = &s +} + +// IdempotencyKeyHash returns the value of the "idempotency_key_hash" field in the mutation. +func (m *IdempotencyRecordMutation) IdempotencyKeyHash() (r string, exists bool) { + v := m.idempotency_key_hash + if v == nil { + return + } + return *v, true +} + +// OldIdempotencyKeyHash returns the old "idempotency_key_hash" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldIdempotencyKeyHash(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldIdempotencyKeyHash is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldIdempotencyKeyHash requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldIdempotencyKeyHash: %w", err) + } + return oldValue.IdempotencyKeyHash, nil +} + +// ResetIdempotencyKeyHash resets all changes to the "idempotency_key_hash" field. +func (m *IdempotencyRecordMutation) ResetIdempotencyKeyHash() { + m.idempotency_key_hash = nil +} + +// SetRequestFingerprint sets the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) SetRequestFingerprint(s string) { + m.request_fingerprint = &s +} + +// RequestFingerprint returns the value of the "request_fingerprint" field in the mutation. +func (m *IdempotencyRecordMutation) RequestFingerprint() (r string, exists bool) { + v := m.request_fingerprint + if v == nil { + return + } + return *v, true +} + +// OldRequestFingerprint returns the old "request_fingerprint" field's value of the IdempotencyRecord entity. 
+// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldRequestFingerprint(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRequestFingerprint is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRequestFingerprint requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRequestFingerprint: %w", err) + } + return oldValue.RequestFingerprint, nil +} + +// ResetRequestFingerprint resets all changes to the "request_fingerprint" field. +func (m *IdempotencyRecordMutation) ResetRequestFingerprint() { + m.request_fingerprint = nil +} + +// SetStatus sets the "status" field. +func (m *IdempotencyRecordMutation) SetStatus(s string) { + m.status = &s +} + +// Status returns the value of the "status" field in the mutation. +func (m *IdempotencyRecordMutation) Status() (r string, exists bool) { + v := m.status + if v == nil { + return + } + return *v, true +} + +// OldStatus returns the old "status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldStatus(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldStatus: %w", err) + } + return oldValue.Status, nil +} + +// ResetStatus resets all changes to the "status" field. +func (m *IdempotencyRecordMutation) ResetStatus() { + m.status = nil +} + +// SetResponseStatus sets the "response_status" field. +func (m *IdempotencyRecordMutation) SetResponseStatus(i int) { + m.response_status = &i + m.addresponse_status = nil +} + +// ResponseStatus returns the value of the "response_status" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseStatus() (r int, exists bool) { + v := m.response_status + if v == nil { + return + } + return *v, true +} + +// OldResponseStatus returns the old "response_status" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *IdempotencyRecordMutation) OldResponseStatus(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseStatus is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseStatus requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseStatus: %w", err) + } + return oldValue.ResponseStatus, nil +} + +// AddResponseStatus adds i to the "response_status" field. 
+func (m *IdempotencyRecordMutation) AddResponseStatus(i int) { + if m.addresponse_status != nil { + *m.addresponse_status += i + } else { + m.addresponse_status = &i + } +} + +// AddedResponseStatus returns the value that was added to the "response_status" field in this mutation. +func (m *IdempotencyRecordMutation) AddedResponseStatus() (r int, exists bool) { + v := m.addresponse_status + if v == nil { + return + } + return *v, true +} + +// ClearResponseStatus clears the value of the "response_status" field. +func (m *IdempotencyRecordMutation) ClearResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + m.clearedFields[idempotencyrecord.FieldResponseStatus] = struct{}{} +} + +// ResponseStatusCleared returns if the "response_status" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseStatusCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseStatus] + return ok +} + +// ResetResponseStatus resets all changes to the "response_status" field. +func (m *IdempotencyRecordMutation) ResetResponseStatus() { + m.response_status = nil + m.addresponse_status = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseStatus) +} + +// SetResponseBody sets the "response_body" field. +func (m *IdempotencyRecordMutation) SetResponseBody(s string) { + m.response_body = &s +} + +// ResponseBody returns the value of the "response_body" field in the mutation. +func (m *IdempotencyRecordMutation) ResponseBody() (r string, exists bool) { + v := m.response_body + if v == nil { + return + } + return *v, true +} + +// OldResponseBody returns the old "response_body" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldResponseBody(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseBody: %w", err) + } + return oldValue.ResponseBody, nil +} + +// ClearResponseBody clears the value of the "response_body" field. +func (m *IdempotencyRecordMutation) ClearResponseBody() { + m.response_body = nil + m.clearedFields[idempotencyrecord.FieldResponseBody] = struct{}{} +} + +// ResponseBodyCleared returns if the "response_body" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ResponseBodyCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldResponseBody] + return ok +} + +// ResetResponseBody resets all changes to the "response_body" field. +func (m *IdempotencyRecordMutation) ResetResponseBody() { + m.response_body = nil + delete(m.clearedFields, idempotencyrecord.FieldResponseBody) +} + +// SetErrorReason sets the "error_reason" field. +func (m *IdempotencyRecordMutation) SetErrorReason(s string) { + m.error_reason = &s +} + +// ErrorReason returns the value of the "error_reason" field in the mutation. +func (m *IdempotencyRecordMutation) ErrorReason() (r string, exists bool) { + v := m.error_reason + if v == nil { + return + } + return *v, true +} + +// OldErrorReason returns the old "error_reason" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldErrorReason(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorReason is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorReason requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorReason: %w", err) + } + return oldValue.ErrorReason, nil +} + +// ClearErrorReason clears the value of the "error_reason" field. +func (m *IdempotencyRecordMutation) ClearErrorReason() { + m.error_reason = nil + m.clearedFields[idempotencyrecord.FieldErrorReason] = struct{}{} +} + +// ErrorReasonCleared returns if the "error_reason" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) ErrorReasonCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldErrorReason] + return ok +} + +// ResetErrorReason resets all changes to the "error_reason" field. +func (m *IdempotencyRecordMutation) ResetErrorReason() { + m.error_reason = nil + delete(m.clearedFields, idempotencyrecord.FieldErrorReason) +} + +// SetLockedUntil sets the "locked_until" field. +func (m *IdempotencyRecordMutation) SetLockedUntil(t time.Time) { + m.locked_until = &t +} + +// LockedUntil returns the value of the "locked_until" field in the mutation. +func (m *IdempotencyRecordMutation) LockedUntil() (r time.Time, exists bool) { + v := m.locked_until + if v == nil { + return + } + return *v, true +} + +// OldLockedUntil returns the old "locked_until" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldLockedUntil(ctx context.Context) (v *time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldLockedUntil is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldLockedUntil requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldLockedUntil: %w", err) + } + return oldValue.LockedUntil, nil +} + +// ClearLockedUntil clears the value of the "locked_until" field. +func (m *IdempotencyRecordMutation) ClearLockedUntil() { + m.locked_until = nil + m.clearedFields[idempotencyrecord.FieldLockedUntil] = struct{}{} +} + +// LockedUntilCleared returns if the "locked_until" field was cleared in this mutation. +func (m *IdempotencyRecordMutation) LockedUntilCleared() bool { + _, ok := m.clearedFields[idempotencyrecord.FieldLockedUntil] + return ok +} + +// ResetLockedUntil resets all changes to the "locked_until" field. +func (m *IdempotencyRecordMutation) ResetLockedUntil() { + m.locked_until = nil + delete(m.clearedFields, idempotencyrecord.FieldLockedUntil) +} + +// SetExpiresAt sets the "expires_at" field. +func (m *IdempotencyRecordMutation) SetExpiresAt(t time.Time) { + m.expires_at = &t +} + +// ExpiresAt returns the value of the "expires_at" field in the mutation. +func (m *IdempotencyRecordMutation) ExpiresAt() (r time.Time, exists bool) { + v := m.expires_at + if v == nil { + return + } + return *v, true +} + +// OldExpiresAt returns the old "expires_at" field's value of the IdempotencyRecord entity. +// If the IdempotencyRecord object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *IdempotencyRecordMutation) OldExpiresAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldExpiresAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldExpiresAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldExpiresAt: %w", err) + } + return oldValue.ExpiresAt, nil +} + +// ResetExpiresAt resets all changes to the "expires_at" field. +func (m *IdempotencyRecordMutation) ResetExpiresAt() { + m.expires_at = nil +} + +// Where appends a list predicates to the IdempotencyRecordMutation builder. +func (m *IdempotencyRecordMutation) Where(ps ...predicate.IdempotencyRecord) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the IdempotencyRecordMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *IdempotencyRecordMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.IdempotencyRecord, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *IdempotencyRecordMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *IdempotencyRecordMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (IdempotencyRecord). +func (m *IdempotencyRecordMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). 
+func (m *IdempotencyRecordMutation) Fields() []string { + fields := make([]string, 0, 11) + if m.created_at != nil { + fields = append(fields, idempotencyrecord.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, idempotencyrecord.FieldUpdatedAt) + } + if m.scope != nil { + fields = append(fields, idempotencyrecord.FieldScope) + } + if m.idempotency_key_hash != nil { + fields = append(fields, idempotencyrecord.FieldIdempotencyKeyHash) + } + if m.request_fingerprint != nil { + fields = append(fields, idempotencyrecord.FieldRequestFingerprint) + } + if m.status != nil { + fields = append(fields, idempotencyrecord.FieldStatus) + } + if m.response_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.response_body != nil { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.error_reason != nil { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.locked_until != nil { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + if m.expires_at != nil { + fields = append(fields, idempotencyrecord.FieldExpiresAt) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. 
+func (m *IdempotencyRecordMutation) Field(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.CreatedAt() + case idempotencyrecord.FieldUpdatedAt: + return m.UpdatedAt() + case idempotencyrecord.FieldScope: + return m.Scope() + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.IdempotencyKeyHash() + case idempotencyrecord.FieldRequestFingerprint: + return m.RequestFingerprint() + case idempotencyrecord.FieldStatus: + return m.Status() + case idempotencyrecord.FieldResponseStatus: + return m.ResponseStatus() + case idempotencyrecord.FieldResponseBody: + return m.ResponseBody() + case idempotencyrecord.FieldErrorReason: + return m.ErrorReason() + case idempotencyrecord.FieldLockedUntil: + return m.LockedUntil() + case idempotencyrecord.FieldExpiresAt: + return m.ExpiresAt() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. 
+func (m *IdempotencyRecordMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case idempotencyrecord.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case idempotencyrecord.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case idempotencyrecord.FieldScope: + return m.OldScope(ctx) + case idempotencyrecord.FieldIdempotencyKeyHash: + return m.OldIdempotencyKeyHash(ctx) + case idempotencyrecord.FieldRequestFingerprint: + return m.OldRequestFingerprint(ctx) + case idempotencyrecord.FieldStatus: + return m.OldStatus(ctx) + case idempotencyrecord.FieldResponseStatus: + return m.OldResponseStatus(ctx) + case idempotencyrecord.FieldResponseBody: + return m.OldResponseBody(ctx) + case idempotencyrecord.FieldErrorReason: + return m.OldErrorReason(ctx) + case idempotencyrecord.FieldLockedUntil: + return m.OldLockedUntil(ctx) + case idempotencyrecord.FieldExpiresAt: + return m.OldExpiresAt(ctx) + } + return nil, fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. 
+func (m *IdempotencyRecordMutation) SetField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case idempotencyrecord.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case idempotencyrecord.FieldScope: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetScope(v) + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetIdempotencyKeyHash(v) + return nil + case idempotencyrecord.FieldRequestFingerprint: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRequestFingerprint(v) + return nil + case idempotencyrecord.FieldStatus: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetStatus(v) + return nil + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseStatus(v) + return nil + case idempotencyrecord.FieldResponseBody: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseBody(v) + return nil + case idempotencyrecord.FieldErrorReason: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorReason(v) + return nil + case idempotencyrecord.FieldLockedUntil: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetLockedUntil(v) + return nil + case 
idempotencyrecord.FieldExpiresAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetExpiresAt(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *IdempotencyRecordMutation) AddedFields() []string { + var fields []string + if m.addresponse_status != nil { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *IdempotencyRecordMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case idempotencyrecord.FieldResponseStatus: + return m.AddedResponseStatus() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *IdempotencyRecordMutation) AddField(name string, value ent.Value) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseStatus(v) + return nil + } + return fmt.Errorf("unknown IdempotencyRecord numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. 
+func (m *IdempotencyRecordMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(idempotencyrecord.FieldResponseStatus) { + fields = append(fields, idempotencyrecord.FieldResponseStatus) + } + if m.FieldCleared(idempotencyrecord.FieldResponseBody) { + fields = append(fields, idempotencyrecord.FieldResponseBody) + } + if m.FieldCleared(idempotencyrecord.FieldErrorReason) { + fields = append(fields, idempotencyrecord.FieldErrorReason) + } + if m.FieldCleared(idempotencyrecord.FieldLockedUntil) { + fields = append(fields, idempotencyrecord.FieldLockedUntil) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *IdempotencyRecordMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearField(name string) error { + switch name { + case idempotencyrecord.FieldResponseStatus: + m.ClearResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ClearResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ClearErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ClearLockedUntil() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *IdempotencyRecordMutation) ResetField(name string) error { + switch name { + case idempotencyrecord.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case idempotencyrecord.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case idempotencyrecord.FieldScope: + m.ResetScope() + return nil + case idempotencyrecord.FieldIdempotencyKeyHash: + m.ResetIdempotencyKeyHash() + return nil + case idempotencyrecord.FieldRequestFingerprint: + m.ResetRequestFingerprint() + return nil + case idempotencyrecord.FieldStatus: + m.ResetStatus() + return nil + case idempotencyrecord.FieldResponseStatus: + m.ResetResponseStatus() + return nil + case idempotencyrecord.FieldResponseBody: + m.ResetResponseBody() + return nil + case idempotencyrecord.FieldErrorReason: + m.ResetErrorReason() + return nil + case idempotencyrecord.FieldLockedUntil: + m.ResetLockedUntil() + return nil + case idempotencyrecord.FieldExpiresAt: + m.ResetExpiresAt() + return nil + } + return fmt.Errorf("unknown IdempotencyRecord field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *IdempotencyRecordMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *IdempotencyRecordMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *IdempotencyRecordMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *IdempotencyRecordMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. 
+func (m *IdempotencyRecordMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *IdempotencyRecordMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *IdempotencyRecordMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown IdempotencyRecord edge %s", name) +} + // PromoCodeMutation represents an operation that mutates the PromoCode nodes in the graph. type PromoCodeMutation struct { config @@ -19038,6 +20255,10 @@ type UserMutation struct { totp_secret_encrypted *string totp_enabled *bool totp_enabled_at *time.Time + sora_storage_quota_bytes *int64 + addsora_storage_quota_bytes *int64 + sora_storage_used_bytes *int64 + addsora_storage_used_bytes *int64 clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -19752,6 +20973,118 @@ func (m *UserMutation) ResetTotpEnabledAt() { delete(m.clearedFields, user.FieldTotpEnabledAt) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (m *UserMutation) SetSoraStorageQuotaBytes(i int64) { + m.sora_storage_quota_bytes = &i + m.addsora_storage_quota_bytes = nil +} + +// SoraStorageQuotaBytes returns the value of the "sora_storage_quota_bytes" field in the mutation. 
+func (m *UserMutation) SoraStorageQuotaBytes() (r int64, exists bool) { + v := m.sora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageQuotaBytes returns the old "sora_storage_quota_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageQuotaBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageQuotaBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageQuotaBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageQuotaBytes: %w", err) + } + return oldValue.SoraStorageQuotaBytes, nil +} + +// AddSoraStorageQuotaBytes adds i to the "sora_storage_quota_bytes" field. +func (m *UserMutation) AddSoraStorageQuotaBytes(i int64) { + if m.addsora_storage_quota_bytes != nil { + *m.addsora_storage_quota_bytes += i + } else { + m.addsora_storage_quota_bytes = &i + } +} + +// AddedSoraStorageQuotaBytes returns the value that was added to the "sora_storage_quota_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageQuotaBytes() (r int64, exists bool) { + v := m.addsora_storage_quota_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageQuotaBytes resets all changes to the "sora_storage_quota_bytes" field. +func (m *UserMutation) ResetSoraStorageQuotaBytes() { + m.sora_storage_quota_bytes = nil + m.addsora_storage_quota_bytes = nil +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. 
+func (m *UserMutation) SetSoraStorageUsedBytes(i int64) { + m.sora_storage_used_bytes = &i + m.addsora_storage_used_bytes = nil +} + +// SoraStorageUsedBytes returns the value of the "sora_storage_used_bytes" field in the mutation. +func (m *UserMutation) SoraStorageUsedBytes() (r int64, exists bool) { + v := m.sora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// OldSoraStorageUsedBytes returns the old "sora_storage_used_bytes" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldSoraStorageUsedBytes(ctx context.Context) (v int64, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldSoraStorageUsedBytes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldSoraStorageUsedBytes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldSoraStorageUsedBytes: %w", err) + } + return oldValue.SoraStorageUsedBytes, nil +} + +// AddSoraStorageUsedBytes adds i to the "sora_storage_used_bytes" field. +func (m *UserMutation) AddSoraStorageUsedBytes(i int64) { + if m.addsora_storage_used_bytes != nil { + *m.addsora_storage_used_bytes += i + } else { + m.addsora_storage_used_bytes = &i + } +} + +// AddedSoraStorageUsedBytes returns the value that was added to the "sora_storage_used_bytes" field in this mutation. +func (m *UserMutation) AddedSoraStorageUsedBytes() (r int64, exists bool) { + v := m.addsora_storage_used_bytes + if v == nil { + return + } + return *v, true +} + +// ResetSoraStorageUsedBytes resets all changes to the "sora_storage_used_bytes" field. 
+func (m *UserMutation) ResetSoraStorageUsedBytes() { + m.sora_storage_used_bytes = nil + m.addsora_storage_used_bytes = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -20272,7 +21605,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 14) + fields := make([]string, 0, 16) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -20315,6 +21648,12 @@ func (m *UserMutation) Fields() []string { if m.totp_enabled_at != nil { fields = append(fields, user.FieldTotpEnabledAt) } + if m.sora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.sora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -20351,6 +21690,10 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.TotpEnabled() case user.FieldTotpEnabledAt: return m.TotpEnabledAt() + case user.FieldSoraStorageQuotaBytes: + return m.SoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.SoraStorageUsedBytes() } return nil, false } @@ -20388,6 +21731,10 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldTotpEnabled(ctx) case user.FieldTotpEnabledAt: return m.OldTotpEnabledAt(ctx) + case user.FieldSoraStorageQuotaBytes: + return m.OldSoraStorageQuotaBytes(ctx) + case user.FieldSoraStorageUsedBytes: + return m.OldSoraStorageUsedBytes(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -20495,6 +21842,20 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotpEnabledAt(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", 
value, name) + } + m.SetSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -20509,6 +21870,12 @@ func (m *UserMutation) AddedFields() []string { if m.addconcurrency != nil { fields = append(fields, user.FieldConcurrency) } + if m.addsora_storage_quota_bytes != nil { + fields = append(fields, user.FieldSoraStorageQuotaBytes) + } + if m.addsora_storage_used_bytes != nil { + fields = append(fields, user.FieldSoraStorageUsedBytes) + } return fields } @@ -20521,6 +21888,10 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalance() case user.FieldConcurrency: return m.AddedConcurrency() + case user.FieldSoraStorageQuotaBytes: + return m.AddedSoraStorageQuotaBytes() + case user.FieldSoraStorageUsedBytes: + return m.AddedSoraStorageUsedBytes() } return nil, false } @@ -20544,6 +21915,20 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddConcurrency(v) return nil + case user.FieldSoraStorageQuotaBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageQuotaBytes(v) + return nil + case user.FieldSoraStorageUsedBytes: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddSoraStorageUsedBytes(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -20634,6 +22019,12 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotpEnabledAt: m.ResetTotpEnabledAt() return nil + case user.FieldSoraStorageQuotaBytes: + m.ResetSoraStorageQuotaBytes() + return nil + case user.FieldSoraStorageUsedBytes: + m.ResetSoraStorageUsedBytes() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git 
a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 584b9606..89d933fc 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -27,6 +27,9 @@ type ErrorPassthroughRule func(*sql.Selector) // Group is the predicate function for group builders. type Group func(*sql.Selector) +// IdempotencyRecord is the predicate function for idempotencyrecord builders. +type IdempotencyRecord func(*sql.Selector) + // PromoCode is the predicate function for promocode builders. type PromoCode func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index ff3f8f26..65531aae 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -12,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/apikey" "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/ent/idempotencyrecord" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" "github.com/Wei-Shaw/sub2api/ent/proxy" @@ -209,7 +210,7 @@ func init() { // account.DefaultSchedulable holds the default value on creation for the schedulable field. account.DefaultSchedulable = accountDescSchedulable.Default.(bool) // accountDescSessionWindowStatus is the schema descriptor for session_window_status field. - accountDescSessionWindowStatus := accountFields[21].Descriptor() + accountDescSessionWindowStatus := accountFields[23].Descriptor() // account.SessionWindowStatusValidator is a validator for the "session_window_status" field. It is called by the builders before save. 
account.SessionWindowStatusValidator = accountDescSessionWindowStatus.Validators[0].(func(string) error) accountgroupFields := schema.AccountGroup{}.Fields() @@ -398,26 +399,65 @@ func init() { groupDescDefaultValidityDays := groupFields[10].Descriptor() // group.DefaultDefaultValidityDays holds the default value on creation for the default_validity_days field. group.DefaultDefaultValidityDays = groupDescDefaultValidityDays.Default.(int) + // groupDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + groupDescSoraStorageQuotaBytes := groupFields[18].Descriptor() + // group.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. + group.DefaultSoraStorageQuotaBytes = groupDescSoraStorageQuotaBytes.Default.(int64) // groupDescClaudeCodeOnly is the schema descriptor for claude_code_only field. - groupDescClaudeCodeOnly := groupFields[18].Descriptor() + groupDescClaudeCodeOnly := groupFields[19].Descriptor() // group.DefaultClaudeCodeOnly holds the default value on creation for the claude_code_only field. group.DefaultClaudeCodeOnly = groupDescClaudeCodeOnly.Default.(bool) // groupDescModelRoutingEnabled is the schema descriptor for model_routing_enabled field. - groupDescModelRoutingEnabled := groupFields[22].Descriptor() + groupDescModelRoutingEnabled := groupFields[23].Descriptor() // group.DefaultModelRoutingEnabled holds the default value on creation for the model_routing_enabled field. group.DefaultModelRoutingEnabled = groupDescModelRoutingEnabled.Default.(bool) // groupDescMcpXMLInject is the schema descriptor for mcp_xml_inject field. - groupDescMcpXMLInject := groupFields[23].Descriptor() + groupDescMcpXMLInject := groupFields[24].Descriptor() // group.DefaultMcpXMLInject holds the default value on creation for the mcp_xml_inject field. 
group.DefaultMcpXMLInject = groupDescMcpXMLInject.Default.(bool) // groupDescSupportedModelScopes is the schema descriptor for supported_model_scopes field. - groupDescSupportedModelScopes := groupFields[24].Descriptor() + groupDescSupportedModelScopes := groupFields[25].Descriptor() // group.DefaultSupportedModelScopes holds the default value on creation for the supported_model_scopes field. group.DefaultSupportedModelScopes = groupDescSupportedModelScopes.Default.([]string) // groupDescSortOrder is the schema descriptor for sort_order field. - groupDescSortOrder := groupFields[25].Descriptor() + groupDescSortOrder := groupFields[26].Descriptor() // group.DefaultSortOrder holds the default value on creation for the sort_order field. group.DefaultSortOrder = groupDescSortOrder.Default.(int) + idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() + idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() + _ = idempotencyrecordMixinFields0 + idempotencyrecordFields := schema.IdempotencyRecord{}.Fields() + _ = idempotencyrecordFields + // idempotencyrecordDescCreatedAt is the schema descriptor for created_at field. + idempotencyrecordDescCreatedAt := idempotencyrecordMixinFields0[0].Descriptor() + // idempotencyrecord.DefaultCreatedAt holds the default value on creation for the created_at field. + idempotencyrecord.DefaultCreatedAt = idempotencyrecordDescCreatedAt.Default.(func() time.Time) + // idempotencyrecordDescUpdatedAt is the schema descriptor for updated_at field. + idempotencyrecordDescUpdatedAt := idempotencyrecordMixinFields0[1].Descriptor() + // idempotencyrecord.DefaultUpdatedAt holds the default value on creation for the updated_at field. + idempotencyrecord.DefaultUpdatedAt = idempotencyrecordDescUpdatedAt.Default.(func() time.Time) + // idempotencyrecord.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. 
+ idempotencyrecord.UpdateDefaultUpdatedAt = idempotencyrecordDescUpdatedAt.UpdateDefault.(func() time.Time) + // idempotencyrecordDescScope is the schema descriptor for scope field. + idempotencyrecordDescScope := idempotencyrecordFields[0].Descriptor() + // idempotencyrecord.ScopeValidator is a validator for the "scope" field. It is called by the builders before save. + idempotencyrecord.ScopeValidator = idempotencyrecordDescScope.Validators[0].(func(string) error) + // idempotencyrecordDescIdempotencyKeyHash is the schema descriptor for idempotency_key_hash field. + idempotencyrecordDescIdempotencyKeyHash := idempotencyrecordFields[1].Descriptor() + // idempotencyrecord.IdempotencyKeyHashValidator is a validator for the "idempotency_key_hash" field. It is called by the builders before save. + idempotencyrecord.IdempotencyKeyHashValidator = idempotencyrecordDescIdempotencyKeyHash.Validators[0].(func(string) error) + // idempotencyrecordDescRequestFingerprint is the schema descriptor for request_fingerprint field. + idempotencyrecordDescRequestFingerprint := idempotencyrecordFields[2].Descriptor() + // idempotencyrecord.RequestFingerprintValidator is a validator for the "request_fingerprint" field. It is called by the builders before save. + idempotencyrecord.RequestFingerprintValidator = idempotencyrecordDescRequestFingerprint.Validators[0].(func(string) error) + // idempotencyrecordDescStatus is the schema descriptor for status field. + idempotencyrecordDescStatus := idempotencyrecordFields[3].Descriptor() + // idempotencyrecord.StatusValidator is a validator for the "status" field. It is called by the builders before save. + idempotencyrecord.StatusValidator = idempotencyrecordDescStatus.Validators[0].(func(string) error) + // idempotencyrecordDescErrorReason is the schema descriptor for error_reason field. 
+ idempotencyrecordDescErrorReason := idempotencyrecordFields[6].Descriptor() + // idempotencyrecord.ErrorReasonValidator is a validator for the "error_reason" field. It is called by the builders before save. + idempotencyrecord.ErrorReasonValidator = idempotencyrecordDescErrorReason.Validators[0].(func(string) error) promocodeFields := schema.PromoCode{}.Fields() _ = promocodeFields // promocodeDescCode is the schema descriptor for code field. @@ -918,6 +958,14 @@ func init() { userDescTotpEnabled := userFields[9].Descriptor() // user.DefaultTotpEnabled holds the default value on creation for the totp_enabled field. user.DefaultTotpEnabled = userDescTotpEnabled.Default.(bool) + // userDescSoraStorageQuotaBytes is the schema descriptor for sora_storage_quota_bytes field. + userDescSoraStorageQuotaBytes := userFields[11].Descriptor() + // user.DefaultSoraStorageQuotaBytes holds the default value on creation for the sora_storage_quota_bytes field. + user.DefaultSoraStorageQuotaBytes = userDescSoraStorageQuotaBytes.Default.(int64) + // userDescSoraStorageUsedBytes is the schema descriptor for sora_storage_used_bytes field. + userDescSoraStorageUsedBytes := userFields[12].Descriptor() + // user.DefaultSoraStorageUsedBytes holds the default value on creation for the sora_storage_used_bytes field. + user.DefaultSoraStorageUsedBytes = userDescSoraStorageUsedBytes.Default.(int64) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/account.go b/backend/ent/schema/account.go index 1cfecc2d..443f9e09 100644 --- a/backend/ent/schema/account.go +++ b/backend/ent/schema/account.go @@ -164,6 +164,19 @@ func (Account) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + // temp_unschedulable_until: 临时不可调度状态解除时间 + // 当命中临时不可调度规则时设置,在此时间前调度器应跳过该账号 + field.Time("temp_unschedulable_until"). 
+ Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "timestamptz"}), + + // temp_unschedulable_reason: 临时不可调度原因,便于排障审计 + field.String("temp_unschedulable_reason"). + Optional(). + Nillable(). + SchemaType(map[string]string{dialect.Postgres: "text"}), + // session_window_*: 会话窗口相关字段 // 用于管理某些需要会话时间窗口的 API(如 Claude Pro) field.Time("session_window_start"). @@ -213,6 +226,9 @@ func (Account) Indexes() []ent.Index { index.Fields("rate_limited_at"), // 筛选速率限制账户 index.Fields("rate_limit_reset_at"), // 筛选速率限制解除时间 index.Fields("overload_until"), // 筛选过载账户 - index.Fields("deleted_at"), // 软删除查询优化 + // 调度热路径复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("platform", "priority"), + index.Fields("priority", "status"), + index.Fields("deleted_at"), // 软删除查询优化 } } diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index fddf23ce..3fcf8674 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -105,6 +105,10 @@ func (Group) Fields() []ent.Field { Nillable(). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}), + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). diff --git a/backend/ent/schema/usage_log.go b/backend/ent/schema/usage_log.go index ffcae840..dcca1a0a 100644 --- a/backend/ent/schema/usage_log.go +++ b/backend/ent/schema/usage_log.go @@ -179,5 +179,7 @@ func (UsageLog) Indexes() []ent.Index { // 复合索引用于时间范围查询 index.Fields("user_id", "created_at"), index.Fields("api_key_id", "created_at"), + // 分组维度时间范围查询(线上由 SQL 迁移创建 group_id IS NOT NULL 的部分索引) + index.Fields("group_id", "created_at"), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index d443ef45..0a3b5d9e 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -72,6 +72,12 @@ func (User) Fields() []ent.Field { field.Time("totp_enabled_at"). Optional(). 
Nillable(), + + // Sora 存储配额 + field.Int64("sora_storage_quota_bytes"). + Default(0), + field.Int64("sora_storage_used_bytes"). + Default(0), } } diff --git a/backend/ent/schema/user_subscription.go b/backend/ent/schema/user_subscription.go index fa13612b..a81850b1 100644 --- a/backend/ent/schema/user_subscription.go +++ b/backend/ent/schema/user_subscription.go @@ -108,6 +108,8 @@ func (UserSubscription) Indexes() []ent.Index { index.Fields("group_id"), index.Fields("status"), index.Fields("expires_at"), + // 活跃订阅查询复合索引(线上由 SQL 迁移创建部分索引,schema 仅用于模型可读性对齐) + index.Fields("user_id", "status", "expires_at"), index.Fields("assigned_by"), // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重新订阅 // 见迁移文件 016_soft_delete_partial_unique_indexes.sql diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 4fbe9bb4..cd3b2296 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -28,6 +28,8 @@ type Tx struct { ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient + // IdempotencyRecord is the client for interacting with the IdempotencyRecord builders. + IdempotencyRecord *IdempotencyRecordClient // PromoCode is the client for interacting with the PromoCode builders. PromoCode *PromoCodeClient // PromoCodeUsage is the client for interacting with the PromoCodeUsage builders. 
@@ -192,6 +194,7 @@ func (tx *Tx) init() { tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) + tx.IdempotencyRecord = NewIdempotencyRecordClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) tx.Proxy = NewProxyClient(tx.config) diff --git a/backend/ent/user.go b/backend/ent/user.go index 2435aa1b..b3f933f6 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -45,6 +45,10 @@ type User struct { TotpEnabled bool `json:"totp_enabled,omitempty"` // TotpEnabledAt holds the value of the "totp_enabled_at" field. TotpEnabledAt *time.Time `json:"totp_enabled_at,omitempty"` + // SoraStorageQuotaBytes holds the value of the "sora_storage_quota_bytes" field. + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes,omitempty"` + // SoraStorageUsedBytes holds the value of the "sora_storage_used_bytes" field. + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. 
Edges UserEdges `json:"edges"` @@ -177,7 +181,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldSoraStorageQuotaBytes, user.FieldSoraStorageUsedBytes: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted: values[i] = new(sql.NullString) @@ -291,6 +295,18 @@ func (_m *User) assignValues(columns []string, values []any) error { _m.TotpEnabledAt = new(time.Time) *_m.TotpEnabledAt = value.Time } + case user.FieldSoraStorageQuotaBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_quota_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageQuotaBytes = value.Int64 + } + case user.FieldSoraStorageUsedBytes: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field sora_storage_used_bytes", values[i]) + } else if value.Valid { + _m.SoraStorageUsedBytes = value.Int64 + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -424,6 +440,12 @@ func (_m *User) String() string { builder.WriteString("totp_enabled_at=") builder.WriteString(v.Format(time.ANSIC)) } + builder.WriteString(", ") + builder.WriteString("sora_storage_quota_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageQuotaBytes)) + builder.WriteString(", ") + builder.WriteString("sora_storage_used_bytes=") + builder.WriteString(fmt.Sprintf("%v", _m.SoraStorageUsedBytes)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index ae9418ff..155b9160 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -43,6 +43,10 @@ const ( FieldTotpEnabled = "totp_enabled" // FieldTotpEnabledAt holds 
the string denoting the totp_enabled_at field in the database. FieldTotpEnabledAt = "totp_enabled_at" + // FieldSoraStorageQuotaBytes holds the string denoting the sora_storage_quota_bytes field in the database. + FieldSoraStorageQuotaBytes = "sora_storage_quota_bytes" + // FieldSoraStorageUsedBytes holds the string denoting the sora_storage_used_bytes field in the database. + FieldSoraStorageUsedBytes = "sora_storage_used_bytes" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -152,6 +156,8 @@ var Columns = []string{ FieldTotpSecretEncrypted, FieldTotpEnabled, FieldTotpEnabledAt, + FieldSoraStorageQuotaBytes, + FieldSoraStorageUsedBytes, } var ( @@ -208,6 +214,10 @@ var ( DefaultNotes string // DefaultTotpEnabled holds the default value on creation for the "totp_enabled" field. DefaultTotpEnabled bool + // DefaultSoraStorageQuotaBytes holds the default value on creation for the "sora_storage_quota_bytes" field. + DefaultSoraStorageQuotaBytes int64 + // DefaultSoraStorageUsedBytes holds the default value on creation for the "sora_storage_used_bytes" field. + DefaultSoraStorageUsedBytes int64 ) // OrderOption defines the ordering options for the User queries. @@ -288,6 +298,16 @@ func ByTotpEnabledAt(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotpEnabledAt, opts...).ToFunc() } +// BySoraStorageQuotaBytes orders the results by the sora_storage_quota_bytes field. +func BySoraStorageQuotaBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageQuotaBytes, opts...).ToFunc() +} + +// BySoraStorageUsedBytes orders the results by the sora_storage_used_bytes field. +func BySoraStorageUsedBytes(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldSoraStorageUsedBytes, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. 
func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index 1de61037..e26afcf3 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -125,6 +125,16 @@ func TotpEnabledAt(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldTotpEnabledAt, v)) } +// SoraStorageQuotaBytes applies equality check predicate on the "sora_storage_quota_bytes" field. It's identical to SoraStorageQuotaBytesEQ. +func SoraStorageQuotaBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytes applies equality check predicate on the "sora_storage_used_bytes" field. It's identical to SoraStorageUsedBytesEQ. +func SoraStorageUsedBytes(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -860,6 +870,86 @@ func TotpEnabledAtNotNil() predicate.User { return predicate.User(sql.FieldNotNull(FieldTotpEnabledAt)) } +// SoraStorageQuotaBytesEQ applies the EQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesNEQ applies the NEQ predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesIn applies the In predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesNotIn applies the NotIn predicate on the "sora_storage_quota_bytes" field. 
+func SoraStorageQuotaBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageQuotaBytes, vs...)) +} + +// SoraStorageQuotaBytesGT applies the GT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesGTE applies the GTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLT applies the LT predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageQuotaBytesLTE applies the LTE predicate on the "sora_storage_quota_bytes" field. +func SoraStorageQuotaBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageQuotaBytes, v)) +} + +// SoraStorageUsedBytesEQ applies the EQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesEQ(v int64) predicate.User { + return predicate.User(sql.FieldEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesNEQ applies the NEQ predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesNEQ(v int64) predicate.User { + return predicate.User(sql.FieldNEQ(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesIn applies the In predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesNotIn applies the NotIn predicate on the "sora_storage_used_bytes" field. 
+func SoraStorageUsedBytesNotIn(vs ...int64) predicate.User { + return predicate.User(sql.FieldNotIn(FieldSoraStorageUsedBytes, vs...)) +} + +// SoraStorageUsedBytesGT applies the GT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGT(v int64) predicate.User { + return predicate.User(sql.FieldGT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesGTE applies the GTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesGTE(v int64) predicate.User { + return predicate.User(sql.FieldGTE(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLT applies the LT predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLT(v int64) predicate.User { + return predicate.User(sql.FieldLT(FieldSoraStorageUsedBytes, v)) +} + +// SoraStorageUsedBytesLTE applies the LTE predicate on the "sora_storage_used_bytes" field. +func SoraStorageUsedBytesLTE(v int64) predicate.User { + return predicate.User(sql.FieldLTE(FieldSoraStorageUsedBytes, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index f862a580..df0c6bcc 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -210,6 +210,34 @@ func (_c *UserCreate) SetNillableTotpEnabledAt(v *time.Time) *UserCreate { return _c } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_c *UserCreate) SetSoraStorageQuotaBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageQuotaBytes(v) + return _c +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageQuotaBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageQuotaBytes(*v) + } + return _c +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. 
+func (_c *UserCreate) SetSoraStorageUsedBytes(v int64) *UserCreate { + _c.mutation.SetSoraStorageUsedBytes(v) + return _c +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_c *UserCreate) SetNillableSoraStorageUsedBytes(v *int64) *UserCreate { + if v != nil { + _c.SetSoraStorageUsedBytes(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -424,6 +452,14 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotpEnabled _c.mutation.SetTotpEnabled(v) } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + v := user.DefaultSoraStorageQuotaBytes + _c.mutation.SetSoraStorageQuotaBytes(v) + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + v := user.DefaultSoraStorageUsedBytes + _c.mutation.SetSoraStorageUsedBytes(v) + } return nil } @@ -487,6 +523,12 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotpEnabled(); !ok { return &ValidationError{Name: "totp_enabled", err: errors.New(`ent: missing required field "User.totp_enabled"`)} } + if _, ok := _c.mutation.SoraStorageQuotaBytes(); !ok { + return &ValidationError{Name: "sora_storage_quota_bytes", err: errors.New(`ent: missing required field "User.sora_storage_quota_bytes"`)} + } + if _, ok := _c.mutation.SoraStorageUsedBytes(); !ok { + return &ValidationError{Name: "sora_storage_used_bytes", err: errors.New(`ent: missing required field "User.sora_storage_used_bytes"`)} + } return nil } @@ -570,6 +612,14 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotpEnabledAt, field.TypeTime, value) _node.TotpEnabledAt = &value } + if value, ok := _c.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + _node.SoraStorageQuotaBytes = value + } + if value, ok := 
_c.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + _node.SoraStorageUsedBytes = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -956,6 +1006,42 @@ func (u *UserUpsert) ClearTotpEnabledAt() *UserUpsert { return u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsert) SetSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageQuotaBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageQuotaBytes) + return u +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsert) AddSoraStorageQuotaBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageQuotaBytes, v) + return u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsert) SetSoraStorageUsedBytes(v int64) *UserUpsert { + u.Set(user.FieldSoraStorageUsedBytes, v) + return u +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsert) UpdateSoraStorageUsedBytes() *UserUpsert { + u.SetExcluded(user.FieldSoraStorageUsedBytes) + return u +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsert) AddSoraStorageUsedBytes(v int64) *UserUpsert { + u.Add(user.FieldSoraStorageUsedBytes, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1218,6 +1304,48 @@ func (u *UserUpsertOne) ClearTotpEnabledAt() *UserUpsertOne { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. 
+func (u *UserUpsertOne) SetSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. +func (u *UserUpsertOne) AddSoraStorageQuotaBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageQuotaBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) SetSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertOne) AddSoraStorageUsedBytes(v int64) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateSoraStorageUsedBytes() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -1646,6 +1774,48 @@ func (u *UserUpsertBulk) ClearTotpEnabledAt() *UserUpsertBulk { }) } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageQuotaBytes(v) + }) +} + +// AddSoraStorageQuotaBytes adds v to the "sora_storage_quota_bytes" field. 
+func (u *UserUpsertBulk) AddSoraStorageQuotaBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageQuotaBytes(v) + }) +} + +// UpdateSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSoraStorageQuotaBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageQuotaBytes() + }) +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) SetSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetSoraStorageUsedBytes(v) + }) +} + +// AddSoraStorageUsedBytes adds v to the "sora_storage_used_bytes" field. +func (u *UserUpsertBulk) AddSoraStorageUsedBytes(v int64) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddSoraStorageUsedBytes(v) + }) +} + +// UpdateSoraStorageUsedBytes sets the "sora_storage_used_bytes" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateSoraStorageUsedBytes() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateSoraStorageUsedBytes() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 80222c92..f71f0cad 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -242,6 +242,48 @@ func (_u *UserUpdate) ClearTotpEnabledAt() *UserUpdate { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *UserUpdate) SetSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. 
+func (_u *UserUpdate) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *UserUpdate) AddSoraStorageQuotaBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. +func (_u *UserUpdate) SetSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdate) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdate { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdate) AddSoraStorageUsedBytes(v int64) *UserUpdate { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -709,6 +751,18 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1352,6 +1406,48 @@ func (_u *UserUpdateOne) ClearTotpEnabledAt() *UserUpdateOne { return _u } +// SetSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field. +func (_u *UserUpdateOne) SetSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageQuotaBytes() + _u.mutation.SetSoraStorageQuotaBytes(v) + return _u +} + +// SetNillableSoraStorageQuotaBytes sets the "sora_storage_quota_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageQuotaBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageQuotaBytes(*v) + } + return _u +} + +// AddSoraStorageQuotaBytes adds value to the "sora_storage_quota_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageQuotaBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageQuotaBytes(v) + return _u +} + +// SetSoraStorageUsedBytes sets the "sora_storage_used_bytes" field. 
+func (_u *UserUpdateOne) SetSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.ResetSoraStorageUsedBytes() + _u.mutation.SetSoraStorageUsedBytes(v) + return _u +} + +// SetNillableSoraStorageUsedBytes sets the "sora_storage_used_bytes" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableSoraStorageUsedBytes(v *int64) *UserUpdateOne { + if v != nil { + _u.SetSoraStorageUsedBytes(*v) + } + return _u +} + +// AddSoraStorageUsedBytes adds value to the "sora_storage_used_bytes" field. +func (_u *UserUpdateOne) AddSoraStorageUsedBytes(v int64) *UserUpdateOne { + _u.mutation.AddSoraStorageUsedBytes(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -1849,6 +1945,18 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if _u.mutation.TotpEnabledAtCleared() { _spec.ClearField(user.FieldTotpEnabledAt, field.TypeTime) } + if value, ok := _u.mutation.SoraStorageQuotaBytes(); ok { + _spec.SetField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageQuotaBytes(); ok { + _spec.AddField(user.FieldSoraStorageQuotaBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.SoraStorageUsedBytes(); ok { + _spec.SetField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedSoraStorageUsedBytes(); ok { + _spec.AddField(user.FieldSoraStorageUsedBytes, field.TypeInt64, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index 0adddadf..08c4e26f 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -7,7 +7,11 @@ require ( github.com/DATA-DOG/go-sqlmock v1.5.2 github.com/DouDOU-start/go-sora2api v1.1.0 github.com/alitto/pond/v2 v2.6.2 + github.com/aws/aws-sdk-go-v2/config v1.32.10 + 
github.com/aws/aws-sdk-go-v2/credentials v1.19.10 + github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 github.com/cespare/xxhash/v2 v2.3.0 + github.com/coder/websocket v1.8.14 github.com/dgraph-io/ristretto v0.2.0 github.com/gin-gonic/gin v1.9.1 github.com/golang-jwt/jwt/v5 v5.2.2 @@ -34,6 +38,8 @@ require ( golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 golang.org/x/term v0.40.0 + google.golang.org/grpc v1.75.1 + google.golang.org/protobuf v1.36.10 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 @@ -47,6 +53,22 @@ require ( github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect github.com/apparentlymart/go-textseg/v15 v15.0.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.41.2 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18 // indirect + github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 // indirect + github.com/aws/smithy-go v1.24.1 // indirect github.com/bdandy/go-errors v1.2.2 // indirect github.com/bdandy/go-socks4 v1.2.3 // indirect github.com/bmatcuk/doublestar v1.3.4 // indirect @@ -146,7 +168,6 @@ 
require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect go.opentelemetry.io/otel v1.37.0 // indirect go.opentelemetry.io/otel/metric v1.37.0 // indirect - go.opentelemetry.io/otel/sdk v1.37.0 // indirect go.opentelemetry.io/otel/trace v1.37.0 // indirect go.uber.org/atomic v1.10.0 // indirect go.uber.org/automaxprocs v1.6.0 // indirect @@ -156,8 +177,7 @@ require ( golang.org/x/mod v0.32.0 // indirect golang.org/x/sys v0.41.0 // indirect golang.org/x/text v0.34.0 // indirect - google.golang.org/grpc v1.75.1 // indirect - google.golang.org/protobuf v1.36.10 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20250929231259-57b25ae835d4 // indirect gopkg.in/ini.v1 v1.67.0 // indirect modernc.org/libc v1.67.6 // indirect modernc.org/mathutil v1.7.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index efe6c145..98914a83 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -22,6 +22,44 @@ github.com/andybalholm/brotli v1.2.0 h1:ukwgCxwYrmACq68yiUqwIWnGY0cTPox/M94sVwTo github.com/andybalholm/brotli v1.2.0/go.mod h1:rzTDkvFWvIrjDXZHkuS16NPggd91W3kUSvPlQ1pLaKY= github.com/apparentlymart/go-textseg/v15 v15.0.0 h1:uYvfpb3DyLSCGWnctWKGj857c6ew1u1fNQOlOtuGxQY= github.com/apparentlymart/go-textseg/v15 v15.0.0/go.mod h1:K8XmNZdhEBkdlyDdvbmmsvpAG721bKi0joRfFdHIWJ4= +github.com/aws/aws-sdk-go-v2 v1.41.2 h1:LuT2rzqNQsauaGkPK/7813XxcZ3o3yePY0Iy891T2ls= +github.com/aws/aws-sdk-go-v2 v1.41.2/go.mod h1:IvvlAZQXvTXznUPfRVfryiG1fbzE2NGK6m9u39YQ+S4= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5 h1:zWFmPmgw4sveAYi1mRqG+E/g0461cJ5M4bJ8/nc6d3Q= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.5/go.mod h1:nVUlMLVV8ycXSb7mSkcNu9e3v/1TJq2RTlrPwhYWr5c= +github.com/aws/aws-sdk-go-v2/config v1.32.10 h1:9DMthfO6XWZYLfzZglAgW5Fyou2nRI5CuV44sTedKBI= +github.com/aws/aws-sdk-go-v2/config v1.32.10/go.mod h1:2rUIOnA2JaiqYmSKYmRJlcMWy6qTj1vuRFscppSBMcw= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10 
h1:EEhmEUFCE1Yhl7vDhNOI5OCL/iKMdkkYFTRpZXNw7m8= +github.com/aws/aws-sdk-go-v2/credentials v1.19.10/go.mod h1:RnnlFCAlxQCkN2Q379B67USkBMu1PipEEiibzYN5UTE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18 h1:Ii4s+Sq3yDfaMLpjrJsqD6SmG/Wq/P5L/hw2qa78UAY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.18/go.mod h1:6x81qnY++ovptLE6nWQeWrpXxbnlIex+4H4eYYGcqfc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18 h1:F43zk1vemYIqPAwhjTjYIz0irU2EY7sOb/F5eJ3HuyM= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.18/go.mod h1:w1jdlZXrGKaJcNoL+Nnrj+k5wlpGXqnNrKoP22HvAug= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18 h1:xCeWVjj0ki0l3nruoyP2slHsGArMxeiiaoPN5QZH6YQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.18/go.mod h1:r/eLGuGCBw6l36ZRWiw6PaZwPXb6YOj+i/7MizNl5/k= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18 h1:eZioDaZGJ0tMM4gzmkNIO2aAoQd+je7Ug7TkvAzlmkU= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.18/go.mod h1:CCXwUKAJdoWr6/NcxZ+zsiPr6oH/Q5aTooRGYieAyj4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5 h1:CeY9LUdur+Dxoeldqoun6y4WtJ3RQtzk0JMP2gfUay0= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.5/go.mod h1:AZLZf2fMaahW5s/wMRciu1sYbdsikT/UHwbUjOdEVTc= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10 h1:fJvQ5mIBVfKtiyx0AHY6HeWcRX5LGANLpq8SVR+Uazs= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.10/go.mod h1:Kzm5e6OmNH8VMkgK9t+ry5jEih4Y8whqs+1hrkxim1I= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18 h1:LTRCYFlnnKFlKsyIQxKhJuDuA3ZkrDQMRYm6rXiHlLY= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.18/go.mod h1:XhwkgGG6bHSd00nO/mexWTcTjgd6PjuvWQMqSn2UaEk= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared 
v1.19.18 h1:/A/xDuZAVD2BpsS2fftFRo/NoEKQJ8YTnJDEHBy2Gtg= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.18/go.mod h1:hWe9b4f+djUQGmyiGEeOnZv69dtMSgpDRIvNMvuvzvY= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2 h1:M1A9AjcFwlxTLuf0Faj88L8Iqw0n/AJHjpZTQzMMsSc= +github.com/aws/aws-sdk-go-v2/service/s3 v1.96.2/go.mod h1:KsdTV6Q9WKUZm2mNJnUFmIoXfZux91M3sr/a4REX8e0= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6 h1:MzORe+J94I+hYu2a6XmV5yC9huoTv8NRcCrUNedDypQ= +github.com/aws/aws-sdk-go-v2/service/signin v1.0.6/go.mod h1:hXzcHLARD7GeWnifd8j9RWqtfIgxj4/cAtIVIK7hg8g= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11 h1:7oGD8KPfBOJGXiCoRKrrrQkbvCp8N++u36hrLMPey6o= +github.com/aws/aws-sdk-go-v2/service/sso v1.30.11/go.mod h1:0DO9B5EUJQlIDif+XJRWCljZRKsAFKh3gpFz7UnDtOo= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15 h1:edCcNp9eGIUDUCrzoCu1jWAXLGFIizeqkdkKgRlJwWc= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.15/go.mod h1:lyRQKED9xWfgkYC/wmmYfv7iVIM68Z5OQ88ZdcV1QbU= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7 h1:NITQpgo9A5NrDZ57uOWj+abvXSb83BbyggcUBVksN7c= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.7/go.mod h1:sks5UWBhEuWYDPdwlnRFn1w7xWdH29Jcpe+/PJQefEs= +github.com/aws/smithy-go v1.24.1 h1:VbyeNfmYkWoxMVpGUAbQumkODcYmfMRfZ8yQiH30SK0= +github.com/aws/smithy-go v1.24.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= github.com/bdandy/go-errors v1.2.2 h1:WdFv/oukjTJCLa79UfkGmwX7ZxONAihKu4V0mLIs11Q= github.com/bdandy/go-errors v1.2.2/go.mod h1:NkYHl4Fey9oRRdbB1CoC6e84tuqQHiqrOcZpqFEkBxM= github.com/bdandy/go-socks4 v1.2.3 h1:Q6Y2heY1GRjCtHbmlKfnwrKVU/k81LS8mRGLRlmDlic= @@ -56,6 +94,12 @@ github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x 
v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs= +github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA= +github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U= +github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= @@ -127,6 +171,8 @@ github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= @@ -190,6 +236,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty 
v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -223,6 +271,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -274,6 +324,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod 
h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -344,6 +396,8 @@ go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/Wgbsd go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= @@ -399,6 +453,8 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4/go.mod h1:NnuHhy+bxcg30o7FnVAZbXsPHUDQ9qKWAQKCD7VxFtk= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index de3251b6..4f6fea37 100644 --- 
a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -364,6 +364,8 @@ type GatewayConfig struct { // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` + // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) + OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` // HTTP 上游连接池配置(性能优化:支持高并发场景调优) // MaxIdleConns: 所有主机的最大空闲连接总数 @@ -450,6 +452,101 @@ type GatewayConfig struct { ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` } +// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 +// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 +type GatewayOpenAIWSConfig struct { + // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) + ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` + // IngressModeDefault: ingress 默认模式(off/shared/dedicated) + IngressModeDefault string `mapstructure:"ingress_mode_default"` + // Enabled: 全局总开关(默认 true) + Enabled bool `mapstructure:"enabled"` + // OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS + OAuthEnabled bool `mapstructure:"oauth_enabled"` + // APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS + APIKeyEnabled bool `mapstructure:"apikey_enabled"` + // ForceHTTP: 全局强制 HTTP(用于紧急回滚) + ForceHTTP bool `mapstructure:"force_http"` + // AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false) + AllowStoreRecovery bool `mapstructure:"allow_store_recovery"` + // IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true) + IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"` + // StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off) + // - strict: 强制新建连接(隔离优先) + // - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中) + // - off: 不强制新建连接(复用优先) + StoreDisabledConnMode string 
`mapstructure:"store_disabled_conn_mode"` + // StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离) + // 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。 + StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` + // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false) + PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` + + // Feature 开关:v2 优先于 v1 + ResponsesWebsockets bool `mapstructure:"responses_websockets"` + ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"` + + // 连接池参数 + MaxConnsPerAccount int `mapstructure:"max_conns_per_account"` + MinIdlePerAccount int `mapstructure:"min_idle_per_account"` + MaxIdlePerAccount int `mapstructure:"max_idle_per_account"` + // DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限 + DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"` + // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor)) + OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` + // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor)) + APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` + QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` + // EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数) + EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` + // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发 + EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"` + // PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒) + PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"` + // FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却 + FallbackCooldownSeconds int 
`mapstructure:"fallback_cooldown_seconds"` + // RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避 + RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"` + // RetryBackoffMaxMS: WS 重试最大退避(毫秒) + RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"` + // RetryJitterRatio: WS 重试退避抖动比例(0-1) + RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"` + // RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制 + RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"` + // PayloadLogSampleRate: payload_schema 日志采样率(0-1) + PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` + + // 账号调度与粘连参数 + LBTopK int `mapstructure:"lb_top_k"` + // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL + StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"` + // SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key” + SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"` + // SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL) + SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"` + // MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接 + MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"` + // StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL + StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"` + // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) + StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` + + SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` +} + +// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。 +type GatewayOpenAIWSSchedulerScoreWeights struct { + Priority float64 `mapstructure:"priority"` + Load float64 `mapstructure:"load"` + Queue float64 `mapstructure:"queue"` + ErrorRate float64 `mapstructure:"error_rate"` + TTFT float64 `mapstructure:"ttft"` +} + // GatewayUsageRecordConfig 使用量记录异步队列配置 type 
GatewayUsageRecordConfig struct { // WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限) @@ -886,6 +983,12 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) + // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 + // 新键未配置(<=0)时回退旧键;新键优先。 + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) if cfg.Totp.EncryptionKey == "" { @@ -945,7 +1048,7 @@ func setDefaults() { viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) - viper.SetDefault("server.max_request_body_size", int64(100*1024*1024)) + viper.SetDefault("server.max_request_body_size", int64(256*1024*1024)) // H2C 默认配置 viper.SetDefault("server.h2c.enabled", false) viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 @@ -1088,9 +1191,9 @@ func setDefaults() { // RateLimit viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) - // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据的配置 - viper.SetDefault("pricing.remote_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256") + // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) + viper.SetDefault("pricing.remote_url", 
"https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256") viper.SetDefault("pricing.data_dir", "./data") viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") viper.SetDefault("pricing.update_interval_hours", 24) @@ -1157,9 +1260,55 @@ func setDefaults() { viper.SetDefault("gateway.max_account_switches_gemini", 3) viper.SetDefault("gateway.force_codex_cli", false) viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) + // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) + viper.SetDefault("gateway.openai_ws.enabled", true) + viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "shared") + viper.SetDefault("gateway.openai_ws.oauth_enabled", true) + viper.SetDefault("gateway.openai_ws.apikey_enabled", true) + viper.SetDefault("gateway.openai_ws.force_http", false) + viper.SetDefault("gateway.openai_ws.allow_store_recovery", false) + viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true) + viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") + viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) + viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) + viper.SetDefault("gateway.openai_ws.responses_websockets", false) + viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) + viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128) + viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4) + viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12) + viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", 
true) + viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10) + viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900) + viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120) + viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7) + viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64) + viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1) + viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10) + viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300) + viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30) + viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120) + viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000) + viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) + viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) + viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) + viper.SetDefault("gateway.openai_ws.lb_top_k", 7) + viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) + viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) + viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true) + viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) 
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) viper.SetDefault("gateway.antigravity_extra_retries", 10) - viper.SetDefault("gateway.max_body_size", int64(100*1024*1024)) + viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) viper.SetDefault("gateway.gemini_debug_response_headers", false) @@ -1747,6 +1896,118 @@ func (c *Config) Validate() error { (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") } + // 兼容旧键 sticky_previous_response_ttl_seconds + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 { + return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative") + } + if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount { + return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account") + } + if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount { + return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account") + } + if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 { + return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 { + return 
fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive") + } + if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive") + } + if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 { + return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]") + } + if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 { + return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive") + } + if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive") + } + if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 { + return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative") + } + if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 { + return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative") + } + if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative") + } + if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 && + c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS { + return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms") + } + if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || 
c.Gateway.OpenAIWS.RetryJitterRatio > 1 { + return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]") + } + if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 { + return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative") + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { + switch mode { + case "off", "shared", "dedicated": + default: + return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|shared|dedicated") + } + } + if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { + switch mode { + case "strict", "adaptive", "off": + default: + return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off") + } + } + if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 { + return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]") + } + if c.Gateway.OpenAIWS.LBTopK <= 0 { + return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive") + } + if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 { + return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive") + } + if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 { + return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative") + } + if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") + } + weightSum := 
c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load + + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue + + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate + + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT + if weightSum <= 0 { + return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero") + } if c.Gateway.MaxLineSize < 0 { return fmt.Errorf("gateway.max_line_size must be non-negative") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index b0402a3b..e3b592e2 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/spf13/viper" + "github.com/stretchr/testify/require" ) func resetViperWithJWTSecret(t *testing.T) { @@ -75,6 +76,103 @@ func TestLoadDefaultSchedulingConfig(t *testing.T) { } } +func TestLoadDefaultOpenAIWSConfig(t *testing.T) { + resetViperWithJWTSecret(t) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.Gateway.OpenAIWS.Enabled { + t.Fatalf("Gateway.OpenAIWS.Enabled = false, want true") + } + if !cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsocketsV2 = false, want true") + } + if cfg.Gateway.OpenAIWS.ResponsesWebsockets { + t.Fatalf("Gateway.OpenAIWS.ResponsesWebsockets = true, want false") + } + if !cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled { + t.Fatalf("Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.OAuthMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor != 1.0 { + t.Fatalf("Gateway.OpenAIWS.APIKeyMaxConnsFactor = %v, want 1.0", cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor) + } + if cfg.Gateway.OpenAIWS.StickySessionTTLSeconds != 3600 { + 
t.Fatalf("Gateway.OpenAIWS.StickySessionTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) + } + if !cfg.Gateway.OpenAIWS.SessionHashReadOldFallback { + t.Fatalf("Gateway.OpenAIWS.SessionHashReadOldFallback = false, want true") + } + if !cfg.Gateway.OpenAIWS.SessionHashDualWriteOld { + t.Fatalf("Gateway.OpenAIWS.SessionHashDualWriteOld = false, want true") + } + if !cfg.Gateway.OpenAIWS.MetadataBridgeEnabled { + t.Fatalf("Gateway.OpenAIWS.MetadataBridgeEnabled = false, want true") + } + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 3600 { + t.Fatalf("Gateway.OpenAIWS.StickyResponseIDTTLSeconds = %d, want 3600", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } + if cfg.Gateway.OpenAIWS.FallbackCooldownSeconds != 30 { + t.Fatalf("Gateway.OpenAIWS.FallbackCooldownSeconds = %d, want 30", cfg.Gateway.OpenAIWS.FallbackCooldownSeconds) + } + if cfg.Gateway.OpenAIWS.EventFlushBatchSize != 1 { + t.Fatalf("Gateway.OpenAIWS.EventFlushBatchSize = %d, want 1", cfg.Gateway.OpenAIWS.EventFlushBatchSize) + } + if cfg.Gateway.OpenAIWS.EventFlushIntervalMS != 10 { + t.Fatalf("Gateway.OpenAIWS.EventFlushIntervalMS = %d, want 10", cfg.Gateway.OpenAIWS.EventFlushIntervalMS) + } + if cfg.Gateway.OpenAIWS.PrewarmCooldownMS != 300 { + t.Fatalf("Gateway.OpenAIWS.PrewarmCooldownMS = %d, want 300", cfg.Gateway.OpenAIWS.PrewarmCooldownMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffInitialMS != 120 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffInitialMS = %d, want 120", cfg.Gateway.OpenAIWS.RetryBackoffInitialMS) + } + if cfg.Gateway.OpenAIWS.RetryBackoffMaxMS != 2000 { + t.Fatalf("Gateway.OpenAIWS.RetryBackoffMaxMS = %d, want 2000", cfg.Gateway.OpenAIWS.RetryBackoffMaxMS) + } + if cfg.Gateway.OpenAIWS.RetryJitterRatio != 0.2 { + t.Fatalf("Gateway.OpenAIWS.RetryJitterRatio = %v, want 0.2", cfg.Gateway.OpenAIWS.RetryJitterRatio) + } + if cfg.Gateway.OpenAIWS.RetryTotalBudgetMS != 5000 { + t.Fatalf("Gateway.OpenAIWS.RetryTotalBudgetMS = %d, want 5000", 
cfg.Gateway.OpenAIWS.RetryTotalBudgetMS) + } + if cfg.Gateway.OpenAIWS.PayloadLogSampleRate != 0.2 { + t.Fatalf("Gateway.OpenAIWS.PayloadLogSampleRate = %v, want 0.2", cfg.Gateway.OpenAIWS.PayloadLogSampleRate) + } + if !cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledForceNewConn = false, want true") + } + if cfg.Gateway.OpenAIWS.StoreDisabledConnMode != "strict" { + t.Fatalf("Gateway.OpenAIWS.StoreDisabledConnMode = %q, want %q", cfg.Gateway.OpenAIWS.StoreDisabledConnMode, "strict") + } + if cfg.Gateway.OpenAIWS.ModeRouterV2Enabled { + t.Fatalf("Gateway.OpenAIWS.ModeRouterV2Enabled = true, want false") + } + if cfg.Gateway.OpenAIWS.IngressModeDefault != "shared" { + t.Fatalf("Gateway.OpenAIWS.IngressModeDefault = %q, want %q", cfg.Gateway.OpenAIWS.IngressModeDefault, "shared") + } +} + +func TestLoadOpenAIWSStickyTTLCompatibility(t *testing.T) { + resetViperWithJWTSecret(t) + t.Setenv("GATEWAY_OPENAI_WS_STICKY_RESPONSE_ID_TTL_SECONDS", "0") + t.Setenv("GATEWAY_OPENAI_WS_STICKY_PREVIOUS_RESPONSE_TTL_SECONDS", "7200") + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds != 7200 { + t.Fatalf("StickyResponseIDTTLSeconds = %d, want 7200", cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + } +} + func TestLoadDefaultIdempotencyConfig(t *testing.T) { resetViperWithJWTSecret(t) @@ -993,6 +1091,16 @@ func TestValidateConfigErrors(t *testing.T) { mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 }, wantErr: "gateway.stream_keepalive_interval", }, + { + name: "gateway openai ws oauth max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.oauth_max_conns_factor", + }, + { + name: "gateway openai ws apikey max conns factor", + mutate: func(c *Config) { c.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0 }, + wantErr: "gateway.openai_ws.apikey_max_conns_factor", + }, { name: 
"gateway stream data interval range", mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, @@ -1174,6 +1282,165 @@ func TestValidateConfigErrors(t *testing.T) { } } +func TestValidateConfig_OpenAIWSRules(t *testing.T) { + buildValid := func(t *testing.T) *Config { + t.Helper() + resetViperWithJWTSecret(t) + cfg, err := Load() + require.NoError(t, err) + return cfg + } + + t.Run("sticky response id ttl 兼容旧键回填", func(t *testing.T) { + cfg := buildValid(t) + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 7200 + + require.NoError(t, cfg.Validate()) + require.Equal(t, 7200, cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds) + }) + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "max_conns_per_account 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxConnsPerAccount = 0 }, + wantErr: "gateway.openai_ws.max_conns_per_account", + }, + { + name: "min_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MinIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.min_idle_per_account", + }, + { + name: "max_idle_per_account 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.MaxIdlePerAccount = -1 }, + wantErr: "gateway.openai_ws.max_idle_per_account", + }, + { + name: "min_idle_per_account 不能大于 max_idle_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MinIdlePerAccount = 3 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + }, + wantErr: "gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account", + }, + { + name: "max_idle_per_account 不能大于 max_conns_per_account", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + c.Gateway.OpenAIWS.MinIdlePerAccount = 1 + c.Gateway.OpenAIWS.MaxIdlePerAccount = 3 + }, + wantErr: "gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account", + }, + { + name: "dial_timeout_seconds 必须为正数", + mutate: func(c *Config) { 
c.Gateway.OpenAIWS.DialTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.dial_timeout_seconds", + }, + { + name: "read_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.ReadTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.read_timeout_seconds", + }, + { + name: "write_timeout_seconds 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.WriteTimeoutSeconds = 0 }, + wantErr: "gateway.openai_ws.write_timeout_seconds", + }, + { + name: "pool_target_utilization 必须在 (0,1]", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PoolTargetUtilization = 0 }, + wantErr: "gateway.openai_ws.pool_target_utilization", + }, + { + name: "queue_limit_per_conn 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.QueueLimitPerConn = 0 }, + wantErr: "gateway.openai_ws.queue_limit_per_conn", + }, + { + name: "fallback_cooldown_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.FallbackCooldownSeconds = -1 }, + wantErr: "gateway.openai_ws.fallback_cooldown_seconds", + }, + { + name: "store_disabled_conn_mode 必须为 strict|adaptive|off", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StoreDisabledConnMode = "invalid" }, + wantErr: "gateway.openai_ws.store_disabled_conn_mode", + }, + { + name: "ingress_mode_default 必须为 off|shared|dedicated", + mutate: func(c *Config) { c.Gateway.OpenAIWS.IngressModeDefault = "invalid" }, + wantErr: "gateway.openai_ws.ingress_mode_default", + }, + { + name: "payload_log_sample_rate 必须在 [0,1] 范围内", + mutate: func(c *Config) { c.Gateway.OpenAIWS.PayloadLogSampleRate = 1.2 }, + wantErr: "gateway.openai_ws.payload_log_sample_rate", + }, + { + name: "retry_total_budget_ms 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.RetryTotalBudgetMS = -1 }, + wantErr: "gateway.openai_ws.retry_total_budget_ms", + }, + { + name: "lb_top_k 必须为正数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.LBTopK = 0 }, + wantErr: "gateway.openai_ws.lb_top_k", + }, + { + name: "sticky_session_ttl_seconds 必须为正数", + mutate: func(c *Config) { 
c.Gateway.OpenAIWS.StickySessionTTLSeconds = 0 }, + wantErr: "gateway.openai_ws.sticky_session_ttl_seconds", + }, + { + name: "sticky_response_id_ttl_seconds 必须为正数", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 0 + c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = 0 + }, + wantErr: "gateway.openai_ws.sticky_response_id_ttl_seconds", + }, + { + name: "sticky_previous_response_ttl_seconds 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds = -1 }, + wantErr: "gateway.openai_ws.sticky_previous_response_ttl_seconds", + }, + { + name: "scheduler_score_weights 不能为负数", + mutate: func(c *Config) { c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = -0.1 }, + wantErr: "gateway.openai_ws.scheduler_score_weights.* must be non-negative", + }, + { + name: "scheduler_score_weights 不能全为 0", + mutate: func(c *Config) { + c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0 + c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0 + }, + wantErr: "gateway.openai_ws.scheduler_score_weights must not all be zero", + }, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + cfg := buildValid(t) + tc.mutate(cfg) + + err := cfg.Validate() + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + }) + } +} + func TestValidateConfig_AutoScaleDisabledIgnoreAutoScaleFields(t *testing.T) { resetViperWithJWTSecret(t) cfg, err := Load() diff --git a/backend/internal/domain/constants.go b/backend/internal/domain/constants.go index d56dfa86..d7bb50fc 100644 --- a/backend/internal/domain/constants.go +++ b/backend/internal/domain/constants.go @@ -104,6 +104,9 @@ var DefaultAntigravityModelMapping = map[string]string{ "gemini-3.1-flash-image": "gemini-3.1-flash-image", // Gemini 3.1 image preview 映射 
"gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + // Gemini 3 image 兼容映射(向 3.1 image 迁移) + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", // 其他官方模型 "gpt-oss-120b-medium": "gpt-oss-120b-medium", "tab_flash_lite_preview": "tab_flash_lite_preview", diff --git a/backend/internal/domain/constants_test.go b/backend/internal/domain/constants_test.go new file mode 100644 index 00000000..29605ac6 --- /dev/null +++ b/backend/internal/domain/constants_test.go @@ -0,0 +1,24 @@ +package domain + +import "testing" + +func TestDefaultAntigravityModelMapping_ImageCompatibilityAliases(t *testing.T) { + t.Parallel() + + cases := map[string]string{ + "gemini-3.1-flash-image": "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", + "gemini-3-pro-image": "gemini-3.1-flash-image", + "gemini-3-pro-image-preview": "gemini-3.1-flash-image", + } + + for from, want := range cases { + got, ok := DefaultAntigravityModelMapping[from] + if !ok { + t.Fatalf("expected mapping for %q to exist", from) + } + if got != want { + t.Fatalf("unexpected mapping for %q: got %q want %q", from, got, want) + } + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 0a012b8f..9b732f9c 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -1337,6 +1337,34 @@ func (h *AccountHandler) GetTodayStats(c *gin.Context) { response.Success(c, stats) } +// BatchTodayStatsRequest 批量今日统计请求体。 +type BatchTodayStatsRequest struct { + AccountIDs []int64 `json:"account_ids" binding:"required"` +} + +// GetBatchTodayStats 批量获取多个账号的今日统计。 +// POST /api/v1/admin/accounts/today-stats/batch +func (h *AccountHandler) GetBatchTodayStats(c *gin.Context) { + var req BatchTodayStatsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + 
return + } + + if len(req.AccountIDs) == 0 { + response.Success(c, gin.H{"stats": map[string]any{}}) + return + } + + stats, err := h.accountUsageService.GetTodayStatsBatch(c.Request.Context(), req.AccountIDs) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"stats": stats}) +} + // SetSchedulableRequest represents the request body for setting schedulable status type SetSchedulableRequest struct { Schedulable bool `json:"schedulable"` diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index fab66c04..7e318592 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -3,6 +3,7 @@ package admin import ( "errors" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/response" @@ -186,7 +187,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) { // GetUsageTrend handles getting usage trend data // GET /api/v1/admin/dashboard/trend -// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { startTime, endTime := parseTimeRange(c) granularity := c.DefaultQuery("granularity", "day") @@ -194,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 var model string + var requestType *int16 var stream *bool var billingType *int8 @@ -220,9 +222,20 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { if modelStr := c.Query("model"); modelStr != "" { model = modelStr } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := 
strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -235,7 +248,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { } } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) + trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return @@ -251,12 +264,13 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // GetModelStats handles getting model usage statistics // GET /api/v1/admin/dashboard/models -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, request_type, stream, billing_type func (h *DashboardHandler) GetModelStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 + var requestType *int16 var stream *bool var billingType *int8 @@ -280,9 +294,20 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { groupID = id } } - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + 
parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { if streamVal, err := strconv.ParseBool(streamStr); err == nil { stream = &streamVal + } else { + response.BadRequest(c, "Invalid stream value, use true or false") + return } } if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { @@ -295,7 +320,7 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { } } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) + stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return diff --git a/backend/internal/handler/admin/dashboard_handler_request_type_test.go b/backend/internal/handler/admin/dashboard_handler_request_type_test.go new file mode 100644 index 00000000..72af6b45 --- /dev/null +++ b/backend/internal/handler/admin/dashboard_handler_request_type_test.go @@ -0,0 +1,132 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type dashboardUsageRepoCapture struct { + service.UsageLogRepository + trendRequestType *int16 + trendStream *bool + modelRequestType *int16 + modelStream *bool +} + +func (s *dashboardUsageRepoCapture) GetUsageTrendWithFilters( + ctx context.Context, + startTime, endTime time.Time, + granularity string, + userID, apiKeyID, accountID, groupID int64, + model string, + requestType *int16, + stream *bool, + billingType *int8, +) 
([]usagestats.TrendDataPoint, error) { + s.trendRequestType = requestType + s.trendStream = stream + return []usagestats.TrendDataPoint{}, nil +} + +func (s *dashboardUsageRepoCapture) GetModelStatsWithFilters( + ctx context.Context, + startTime, endTime time.Time, + userID, apiKeyID, accountID, groupID int64, + requestType *int16, + stream *bool, + billingType *int8, +) ([]usagestats.ModelStat, error) { + s.modelRequestType = requestType + s.modelStream = stream + return []usagestats.ModelStat{}, nil +} + +func newDashboardRequestTypeTestRouter(repo *dashboardUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + dashboardSvc := service.NewDashboardService(repo, nil, nil, nil) + handler := NewDashboardHandler(dashboardSvc, nil) + router := gin.New() + router.GET("/admin/dashboard/trend", handler.GetUsageTrend) + router.GET("/admin/dashboard/models", handler.GetModelStats) + return router +} + +func TestDashboardTrendRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.trendRequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.trendRequestType) + require.Nil(t, repo.trendStream) +} + +func TestDashboardTrendInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/trend?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardTrendInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, 
"/admin/dashboard/trend?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsRequestTypePriority(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=sync&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.modelRequestType) + require.Equal(t, int16(service.RequestTypeSync), *repo.modelRequestType) + require.Nil(t, repo.modelStream) +} + +func TestDashboardModelStatsInvalidRequestType(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDashboardModelStatsInvalidStream(t *testing.T) { + repo := &dashboardUsageRepoCapture{} + router := newDashboardRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/dashboard/models?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/data_management_handler.go b/backend/internal/handler/admin/data_management_handler.go new file mode 100644 index 00000000..69a0b5b5 --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler.go @@ -0,0 +1,523 @@ +package admin + +import ( + "strconv" + "strings" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + + 
"github.com/gin-gonic/gin" +) + +type DataManagementHandler struct { + dataManagementService *service.DataManagementService +} + +func NewDataManagementHandler(dataManagementService *service.DataManagementService) *DataManagementHandler { + return &DataManagementHandler{dataManagementService: dataManagementService} +} + +type TestS3ConnectionRequest struct { + Endpoint string `json:"endpoint"` + Region string `json:"region" binding:"required"` + Bucket string `json:"bucket" binding:"required"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type CreateBackupJobRequest struct { + BackupType string `json:"backup_type" binding:"required,oneof=postgres redis full"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id"` + PostgresID string `json:"postgres_profile_id"` + RedisID string `json:"redis_profile_id"` + IdempotencyKey string `json:"idempotency_key"` +} + +type CreateSourceProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` + SetActive bool `json:"set_active"` +} + +type UpdateSourceProfileRequest struct { + Name string `json:"name" binding:"required"` + Config service.DataManagementSourceConfig `json:"config" binding:"required"` +} + +type CreateS3ProfileRequest struct { + ProfileID string `json:"profile_id" binding:"required"` + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` + SetActive 
bool `json:"set_active"` +} + +type UpdateS3ProfileRequest struct { + Name string `json:"name" binding:"required"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +func (h *DataManagementHandler) GetAgentHealth(c *gin.Context) { + health := h.getAgentHealth(c) + payload := gin.H{ + "enabled": health.Enabled, + "reason": health.Reason, + "socket_path": health.SocketPath, + } + if health.Agent != nil { + payload["agent"] = gin.H{ + "status": health.Agent.Status, + "version": health.Agent.Version, + "uptime_seconds": health.Agent.UptimeSeconds, + } + } + response.Success(c, payload) +} + +func (h *DataManagementHandler) GetConfig(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.GetConfig(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) UpdateConfig(c *gin.Context) { + var req service.DataManagementConfig + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + cfg, err := h.dataManagementService.UpdateConfig(c.Request.Context(), req) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, cfg) +} + +func (h *DataManagementHandler) TestS3(c *gin.Context) { + var req TestS3ConnectionRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + result, err := h.dataManagementService.ValidateS3(c.Request.Context(), service.DataManagementS3Config{ + Enabled: true, + Endpoint: 
req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"ok": result.OK, "message": result.Message}) +} + +func (h *DataManagementHandler) CreateBackupJob(c *gin.Context) { + var req CreateBackupJobRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + req.IdempotencyKey = normalizeBackupIdempotencyKey(c.GetHeader("X-Idempotency-Key"), req.IdempotencyKey) + if !h.requireAgentEnabled(c) { + return + } + + triggeredBy := "admin:unknown" + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok { + triggeredBy = "admin:" + strconv.FormatInt(subject.UserID, 10) + } + job, err := h.dataManagementService.CreateBackupJob(c.Request.Context(), service.DataManagementCreateBackupJobInput{ + BackupType: req.BackupType, + UploadToS3: req.UploadToS3, + S3ProfileID: req.S3ProfileID, + PostgresID: req.PostgresID, + RedisID: req.RedisID, + TriggeredBy: triggeredBy, + IdempotencyKey: req.IdempotencyKey, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"job_id": job.JobID, "status": job.Status}) +} + +func (h *DataManagementHandler) ListSourceProfiles(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType == "" { + response.BadRequest(c, "Invalid source_type") + return + } + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + if !h.requireAgentEnabled(c) { + return + } + items, err := h.dataManagementService.ListSourceProfiles(c.Request.Context(), sourceType) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h 
*DataManagementHandler) CreateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + + var req CreateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.CreateSourceProfile(c.Request.Context(), service.DataManagementCreateSourceProfileInput{ + SourceType: sourceType, + ProfileID: req.ProfileID, + Name: req.Name, + Config: req.Config, + SetActive: req.SetActive, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + var req UpdateSourceProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.UpdateSourceProfile(c.Request.Context(), service.DataManagementUpdateSourceProfileInput{ + SourceType: sourceType, + ProfileID: profileID, + Name: req.Name, + Config: req.Config, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or 
redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteSourceProfile(c.Request.Context(), sourceType, profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveSourceProfile(c *gin.Context) { + sourceType := strings.TrimSpace(c.Param("source_type")) + if sourceType != "postgres" && sourceType != "redis" { + response.BadRequest(c, "source_type must be postgres or redis") + return + } + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveSourceProfile(c.Request.Context(), sourceType, profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListS3Profiles(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + items, err := h.dataManagementService.ListS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"items": items}) +} + +func (h *DataManagementHandler) CreateS3Profile(c *gin.Context) { + var req CreateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.CreateS3Profile(c.Request.Context(), service.DataManagementCreateS3ProfileInput{ + ProfileID: req.ProfileID, + Name: req.Name, + SetActive: req.SetActive, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: 
req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) UpdateS3Profile(c *gin.Context) { + var req UpdateS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + + profile, err := h.dataManagementService.UpdateS3Profile(c.Request.Context(), service.DataManagementUpdateS3ProfileInput{ + ProfileID: profileID, + Name: req.Name, + S3: service.DataManagementS3Config{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + UseSSL: req.UseSSL, + }, + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) DeleteS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + if err := h.dataManagementService.DeleteS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +func (h *DataManagementHandler) SetActiveS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Invalid profile_id") + return + } + + if !h.requireAgentEnabled(c) { + return + } + profile, err := h.dataManagementService.SetActiveS3Profile(c.Request.Context(), 
profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, profile) +} + +func (h *DataManagementHandler) ListBackupJobs(c *gin.Context) { + if !h.requireAgentEnabled(c) { + return + } + + pageSize := int32(20) + if raw := strings.TrimSpace(c.Query("page_size")); raw != "" { + v, err := strconv.Atoi(raw) + if err != nil || v <= 0 { + response.BadRequest(c, "Invalid page_size") + return + } + pageSize = int32(v) + } + + result, err := h.dataManagementService.ListBackupJobs(c.Request.Context(), service.DataManagementListBackupJobsInput{ + PageSize: pageSize, + PageToken: c.Query("page_token"), + Status: c.Query("status"), + BackupType: c.Query("backup_type"), + }) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, result) +} + +func (h *DataManagementHandler) GetBackupJob(c *gin.Context) { + jobID := strings.TrimSpace(c.Param("job_id")) + if jobID == "" { + response.BadRequest(c, "Invalid backup job ID") + return + } + + if !h.requireAgentEnabled(c) { + return + } + job, err := h.dataManagementService.GetBackupJob(c.Request.Context(), jobID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, job) +} + +func (h *DataManagementHandler) requireAgentEnabled(c *gin.Context) bool { + if h.dataManagementService == nil { + err := infraerrors.ServiceUnavailable( + service.DataManagementAgentUnavailableReason, + "data management agent service is not configured", + ).WithMetadata(map[string]string{"socket_path": service.DefaultDataManagementAgentSocketPath}) + response.ErrorFrom(c, err) + return false + } + + if err := h.dataManagementService.EnsureAgentEnabled(c.Request.Context()); err != nil { + response.ErrorFrom(c, err) + return false + } + + return true +} + +func (h *DataManagementHandler) getAgentHealth(c *gin.Context) service.DataManagementAgentHealth { + if h.dataManagementService == nil { + return service.DataManagementAgentHealth{ + Enabled: false, + Reason: 
service.DataManagementAgentUnavailableReason, + SocketPath: service.DefaultDataManagementAgentSocketPath, + } + } + return h.dataManagementService.GetAgentHealth(c.Request.Context()) +} + +func normalizeBackupIdempotencyKey(headerValue, bodyValue string) string { + headerKey := strings.TrimSpace(headerValue) + if headerKey != "" { + return headerKey + } + return strings.TrimSpace(bodyValue) +} diff --git a/backend/internal/handler/admin/data_management_handler_test.go b/backend/internal/handler/admin/data_management_handler_test.go new file mode 100644 index 00000000..ce8ee835 --- /dev/null +++ b/backend/internal/handler/admin/data_management_handler_test.go @@ -0,0 +1,78 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "path/filepath" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type apiEnvelope struct { + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason"` + Data json.RawMessage `json:"data"` +} + +func TestDataManagementHandler_AgentHealthAlways200(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/agent/health", h.GetAgentHealth) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/agent/health", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, 0, envelope.Code) + + var data struct { + Enabled bool `json:"enabled"` + Reason string `json:"reason"` + SocketPath string `json:"socket_path"` + } + require.NoError(t, json.Unmarshal(envelope.Data, &data)) + require.False(t, data.Enabled) + 
require.Equal(t, service.DataManagementDeprecatedReason, data.Reason) + require.Equal(t, svc.SocketPath(), data.SocketPath) +} + +func TestDataManagementHandler_NonHealthRouteReturns503WhenDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + svc := service.NewDataManagementServiceWithOptions(filepath.Join(t.TempDir(), "missing.sock"), 50*time.Millisecond) + h := NewDataManagementHandler(svc) + + r := gin.New() + r.GET("/api/v1/admin/data-management/config", h.GetConfig) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/data-management/config", nil) + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + + var envelope apiEnvelope + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &envelope)) + require.Equal(t, http.StatusServiceUnavailable, envelope.Code) + require.Equal(t, service.DataManagementDeprecatedReason, envelope.Reason) +} + +func TestNormalizeBackupIdempotencyKey(t *testing.T) { + require.Equal(t, "from-header", normalizeBackupIdempotencyKey("from-header", "from-body")) + require.Equal(t, "from-body", normalizeBackupIdempotencyKey(" ", " from-body ")) + require.Equal(t, "", normalizeBackupIdempotencyKey("", "")) +} diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index 25ff3c96..1edf4dcc 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -51,6 +51,8 @@ type CreateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string `json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -84,6 +86,8 @@ type UpdateGroupRequest struct { MCPXMLInject *bool `json:"mcp_xml_inject"` // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string 
`json:"supported_model_scopes"` + // Sora 存储配额 + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -198,6 +202,7 @@ func (h *GroupHandler) Create(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -248,6 +253,7 @@ func (h *GroupHandler) Update(c *gin.Context) { ModelRoutingEnabled: req.ModelRoutingEnabled, MCPXMLInject: req.MCPXMLInject, SupportedModelScopes: req.SupportedModelScopes, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { diff --git a/backend/internal/handler/admin/openai_oauth_handler.go b/backend/internal/handler/admin/openai_oauth_handler.go index cf43f89e..5d354fd3 100644 --- a/backend/internal/handler/admin/openai_oauth_handler.go +++ b/backend/internal/handler/admin/openai_oauth_handler.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/service" @@ -47,7 +48,12 @@ func (h *OpenAIOAuthHandler) GenerateAuthURL(c *gin.Context) { req = OpenAIGenerateAuthURLRequest{} } - result, err := h.openaiOAuthService.GenerateAuthURL(c.Request.Context(), req.ProxyID, req.RedirectURI) + result, err := h.openaiOAuthService.GenerateAuthURL( + c.Request.Context(), + req.ProxyID, + req.RedirectURI, + oauthPlatformFromPath(c), + ) if err != nil { response.ErrorFrom(c, err) return @@ -123,7 +129,14 @@ func (h *OpenAIOAuthHandler) RefreshToken(c *gin.Context) { } } - tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, 
proxyURL, strings.TrimSpace(req.ClientID)) + // 未指定 client_id 时,根据请求路径平台自动设置默认值,避免 repository 层盲猜 + clientID := strings.TrimSpace(req.ClientID) + if clientID == "" { + platform := oauthPlatformFromPath(c) + clientID, _ = openai.OAuthClientConfigByPlatform(platform) + } + + tokenInfo, err := h.openaiOAuthService.RefreshTokenWithClientID(c.Request.Context(), refreshToken, proxyURL, clientID) if err != nil { response.ErrorFrom(c, err) return diff --git a/backend/internal/handler/admin/ops_ws_handler.go b/backend/internal/handler/admin/ops_ws_handler.go index c030d303..75fd7ea0 100644 --- a/backend/internal/handler/admin/ops_ws_handler.go +++ b/backend/internal/handler/admin/ops_ws_handler.go @@ -62,7 +62,8 @@ const ( ) var wsConnCount atomic.Int32 -var wsConnCountByIP sync.Map // map[string]*atomic.Int32 +var wsConnCountByIPMu sync.Mutex +var wsConnCountByIP = make(map[string]int32) const qpsWSIdleStopDelay = 30 * time.Second @@ -389,42 +390,31 @@ func tryAcquireOpsWSIPSlot(clientIP string, limit int32) bool { if strings.TrimSpace(clientIP) == "" || limit <= 0 { return true } - - v, _ := wsConnCountByIP.LoadOrStore(clientIP, &atomic.Int32{}) - counter, ok := v.(*atomic.Int32) - if !ok { + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current := wsConnCountByIP[clientIP] + if current >= limit { return false } - - for { - current := counter.Load() - if current >= limit { - return false - } - if counter.CompareAndSwap(current, current+1) { - return true - } - } + wsConnCountByIP[clientIP] = current + 1 + return true } func releaseOpsWSIPSlot(clientIP string) { if strings.TrimSpace(clientIP) == "" { return } - - v, ok := wsConnCountByIP.Load(clientIP) + wsConnCountByIPMu.Lock() + defer wsConnCountByIPMu.Unlock() + current, ok := wsConnCountByIP[clientIP] if !ok { return } - counter, ok := v.(*atomic.Int32) - if !ok { + if current <= 1 { + delete(wsConnCountByIP, clientIP) return } - next := counter.Add(-1) - if next <= 0 { - // Best-effort cleanup; safe 
even if a new slot was acquired concurrently. - wsConnCountByIP.Delete(clientIP) - } + wsConnCountByIP[clientIP] = current - 1 } func handleQPSWebSocket(parentCtx context.Context, conn *websocket.Conn) { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 1e723ee5..c7b93497 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -1,6 +1,7 @@ package admin import ( + "fmt" "log" "strings" "time" @@ -20,15 +21,17 @@ type SettingHandler struct { emailService *service.EmailService turnstileService *service.TurnstileService opsService *service.OpsService + soraS3Storage *service.SoraS3Storage } // NewSettingHandler 创建系统设置处理器 -func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService) *SettingHandler { +func NewSettingHandler(settingService *service.SettingService, emailService *service.EmailService, turnstileService *service.TurnstileService, opsService *service.OpsService, soraS3Storage *service.SoraS3Storage) *SettingHandler { return &SettingHandler{ settingService: settingService, emailService: emailService, turnstileService: turnstileService, opsService: opsService, + soraS3Storage: soraS3Storage, } } @@ -76,6 +79,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, EnableModelFallback: settings.EnableModelFallback, @@ -133,6 +137,7 @@ type UpdateSettingsRequest struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled *bool `json:"purchase_subscription_enabled"` 
PurchaseSubscriptionURL *string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` // 默认配置 DefaultConcurrency int `json:"default_concurrency"` @@ -319,6 +324,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: req.HideCcsImportButton, PurchaseSubscriptionEnabled: purchaseEnabled, PurchaseSubscriptionURL: purchaseURL, + SoraClientEnabled: req.SoraClientEnabled, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, EnableModelFallback: req.EnableModelFallback, @@ -400,6 +406,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { HideCcsImportButton: updatedSettings.HideCcsImportButton, PurchaseSubscriptionEnabled: updatedSettings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: updatedSettings.PurchaseSubscriptionURL, + SoraClientEnabled: updatedSettings.SoraClientEnabled, DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -750,6 +757,384 @@ func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) { }) } +func toSoraS3SettingsDTO(settings *service.SoraS3Settings) dto.SoraS3Settings { + if settings == nil { + return dto.SoraS3Settings{} + } + return dto.SoraS3Settings{ + Enabled: settings.Enabled, + Endpoint: settings.Endpoint, + Region: settings.Region, + Bucket: settings.Bucket, + AccessKeyID: settings.AccessKeyID, + SecretAccessKeyConfigured: settings.SecretAccessKeyConfigured, + Prefix: settings.Prefix, + ForcePathStyle: settings.ForcePathStyle, + CDNURL: settings.CDNURL, + DefaultStorageQuotaBytes: settings.DefaultStorageQuotaBytes, + } +} + +func toSoraS3ProfileDTO(profile service.SoraS3Profile) dto.SoraS3Profile { + return dto.SoraS3Profile{ + ProfileID: profile.ProfileID, + Name: profile.Name, + IsActive: profile.IsActive, + Enabled: profile.Enabled, + Endpoint: profile.Endpoint, + Region: profile.Region, + Bucket: profile.Bucket, + 
AccessKeyID: profile.AccessKeyID, + SecretAccessKeyConfigured: profile.SecretAccessKeyConfigured, + Prefix: profile.Prefix, + ForcePathStyle: profile.ForcePathStyle, + CDNURL: profile.CDNURL, + DefaultStorageQuotaBytes: profile.DefaultStorageQuotaBytes, + UpdatedAt: profile.UpdatedAt, + } +} + +func validateSoraS3RequiredWhenEnabled(enabled bool, endpoint, bucket, accessKeyID, secretAccessKey string, hasStoredSecret bool) error { + if !enabled { + return nil + } + if strings.TrimSpace(endpoint) == "" { + return fmt.Errorf("S3 Endpoint is required when enabled") + } + if strings.TrimSpace(bucket) == "" { + return fmt.Errorf("S3 Bucket is required when enabled") + } + if strings.TrimSpace(accessKeyID) == "" { + return fmt.Errorf("S3 Access Key ID is required when enabled") + } + if strings.TrimSpace(secretAccessKey) != "" || hasStoredSecret { + return nil + } + return fmt.Errorf("S3 Secret Access Key is required when enabled") +} + +func findSoraS3ProfileByID(items []service.SoraS3Profile, profileID string) *service.SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置接口) +// GET /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) GetSoraS3Settings(c *gin.Context) { + settings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(settings)) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置 +// GET /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) ListSoraS3Profiles(c *gin.Context) { + result, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + items := make([]dto.SoraS3Profile, 0, len(result.Items)) + for idx := range result.Items { + items = append(items, toSoraS3ProfileDTO(result.Items[idx])) + } + response.Success(c, dto.ListSoraS3ProfilesResponse{ 
+ ActiveProfileID: result.ActiveProfileID, + Items: items, + }) +} + +// UpdateSoraS3SettingsRequest 更新/测试 Sora S3 配置请求(兼容旧接口) +type UpdateSoraS3SettingsRequest struct { + ProfileID string `json:"profile_id"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type CreateSoraS3ProfileRequest struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + SetActive bool `json:"set_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +type UpdateSoraS3ProfileRequest struct { + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles +func (h *SettingHandler) CreateSoraS3Profile(c *gin.Context) { + var req CreateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if 
req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + if strings.TrimSpace(req.ProfileID) == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, false); err != nil { + response.BadRequest(c, err.Error()) + return + } + + created, err := h.settingService.CreateSoraS3Profile(c.Request.Context(), &service.SoraS3Profile{ + ProfileID: req.ProfileID, + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }, req.SetActive) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, toSoraS3ProfileDTO(*created)) +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +// PUT /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) UpdateSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + + var req UpdateSoraS3ProfileRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if strings.TrimSpace(req.Name) == "" { + response.BadRequest(c, "Name is required") + return + } + + existingList, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + existing := findSoraS3ProfileByID(existingList.Items, profileID) + if existing == nil { + response.ErrorFrom(c, 
service.ErrSoraS3ProfileNotFound) + return + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + updated, updateErr := h.settingService.UpdateSoraS3Profile(c.Request.Context(), profileID, &service.SoraS3Profile{ + Name: req.Name, + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + }) + if updateErr != nil { + response.ErrorFrom(c, updateErr) + return + } + + response.Success(c, toSoraS3ProfileDTO(*updated)) +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +// DELETE /api/v1/admin/settings/sora-s3/profiles/:profile_id +func (h *SettingHandler) DeleteSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + if err := h.settingService.DeleteSoraS3Profile(c.Request.Context(), profileID); err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, gin.H{"deleted": true}) +} + +// SetActiveSoraS3Profile 切换激活 Sora S3 配置 +// POST /api/v1/admin/settings/sora-s3/profiles/:profile_id/activate +func (h *SettingHandler) SetActiveSoraS3Profile(c *gin.Context) { + profileID := strings.TrimSpace(c.Param("profile_id")) + if profileID == "" { + response.BadRequest(c, "Profile ID is required") + return + } + active, err := h.settingService.SetActiveSoraS3Profile(c.Request.Context(), profileID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3ProfileDTO(*active)) +} + +// UpdateSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置接口) +// PUT /api/v1/admin/settings/sora-s3 +func (h *SettingHandler) 
UpdateSoraS3Settings(c *gin.Context) { + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + + if req.DefaultStorageQuotaBytes < 0 { + req.DefaultStorageQuotaBytes = 0 + } + if err := validateSoraS3RequiredWhenEnabled(req.Enabled, req.Endpoint, req.Bucket, req.AccessKeyID, req.SecretAccessKey, existing.SecretAccessKeyConfigured); err != nil { + response.BadRequest(c, err.Error()) + return + } + + settings := &service.SoraS3Settings{ + Enabled: req.Enabled, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + DefaultStorageQuotaBytes: req.DefaultStorageQuotaBytes, + } + if err := h.settingService.SetSoraS3Settings(c.Request.Context(), settings); err != nil { + response.ErrorFrom(c, err) + return + } + + updatedSettings, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, toSoraS3SettingsDTO(updatedSettings)) +} + +// TestSoraS3Connection 测试 Sora S3 连接(HeadBucket) +// POST /api/v1/admin/settings/sora-s3/test +func (h *SettingHandler) TestSoraS3Connection(c *gin.Context) { + if h.soraS3Storage == nil { + response.Error(c, 500, "S3 存储服务未初始化") + return + } + + var req UpdateSoraS3SettingsRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + if !req.Enabled { + response.BadRequest(c, "S3 未启用,无法测试连接") + return + } + + if req.SecretAccessKey == "" { + if req.ProfileID != "" { + profiles, err := h.settingService.ListSoraS3Profiles(c.Request.Context()) + if err == nil { + profile := 
findSoraS3ProfileByID(profiles.Items, req.ProfileID) + if profile != nil { + req.SecretAccessKey = profile.SecretAccessKey + } + } + } + if req.SecretAccessKey == "" { + existing, err := h.settingService.GetSoraS3Settings(c.Request.Context()) + if err == nil { + req.SecretAccessKey = existing.SecretAccessKey + } + } + } + + testCfg := &service.SoraS3Settings{ + Enabled: true, + Endpoint: req.Endpoint, + Region: req.Region, + Bucket: req.Bucket, + AccessKeyID: req.AccessKeyID, + SecretAccessKey: req.SecretAccessKey, + Prefix: req.Prefix, + ForcePathStyle: req.ForcePathStyle, + CDNURL: req.CDNURL, + } + if err := h.soraS3Storage.TestConnectionWithSettings(c.Request.Context(), testCfg); err != nil { + response.Error(c, 400, "S3 连接测试失败: "+err.Error()) + return + } + response.Success(c, gin.H{"message": "S3 连接成功"}) +} + // UpdateStreamTimeoutSettingsRequest 更新流超时配置请求 type UpdateStreamTimeoutSettingsRequest struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go index ed1c7cc2..6152d5e9 100644 --- a/backend/internal/handler/admin/usage_cleanup_handler_test.go +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -225,6 +225,92 @@ func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { require.Equal(t, http.StatusBadRequest, recorder.Code) } +func TestUsageHandlerCreateCleanupTaskInvalidRequestType(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "invalid", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, 
"/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskRequestTypePriority(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "request_type": "ws_v2", + "stream": false, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.NotNil(t, created.Filters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *created.Filters.RequestType) + require.Nil(t, created.Filters.Stream) +} + +func TestUsageHandlerCreateCleanupTaskWithLegacyStream(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-01-02", + "timezone": "UTC", + "stream": true, + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + 
req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.Nil(t, created.Filters.RequestType) + require.NotNil(t, created.Filters.Stream) + require.True(t, *created.Filters.Stream) +} + func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { repo := &cleanupRepoStub{} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} diff --git a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index 5cbf18e6..d0bba773 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -51,6 +51,7 @@ type CreateUsageCleanupTaskRequest struct { AccountID *int64 `json:"account_id"` GroupID *int64 `json:"group_id"` Model *string `json:"model"` + RequestType *string `json:"request_type"` Stream *bool `json:"stream"` BillingType *int8 `json:"billing_type"` Timezone string `json:"timezone"` @@ -101,8 +102,17 @@ func (h *UsageHandler) List(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -152,6 +162,7 @@ func (h *UsageHandler) List(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, 
StartTime: startTime, @@ -214,8 +225,17 @@ func (h *UsageHandler) Stats(c *gin.Context) { model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -278,6 +298,7 @@ func (h *UsageHandler) Stats(c *gin.Context) { AccountID: accountID, GroupID: groupID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: &startTime, @@ -432,6 +453,19 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { } endTime = endTime.Add(24*time.Hour - time.Nanosecond) + var requestType *int16 + stream := req.Stream + if req.RequestType != nil { + parsed, err := service.ParseUsageRequestType(*req.RequestType) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + stream = nil + } + filters := service.UsageCleanupFilters{ StartTime: startTime, EndTime: endTime, @@ -440,7 +474,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { AccountID: req.AccountID, GroupID: req.GroupID, Model: req.Model, - Stream: req.Stream, + RequestType: requestType, + Stream: stream, BillingType: req.BillingType, } @@ -464,9 +499,13 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { if filters.Model != nil { model = *filters.Model } - var stream any + var streamValue any if filters.Stream != nil { - stream = *filters.Stream + streamValue = *filters.Stream + } + var requestTypeName any + if filters.RequestType != nil { + requestTypeName = 
service.RequestTypeFromInt16(*filters.RequestType).String() } var billingType any if filters.BillingType != nil { @@ -481,7 +520,7 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { Body: req, } executeAdminIdempotentJSON(c, "admin.usage.cleanup_tasks.create", idempotencyPayload, service.DefaultWriteIdempotencyTTL(), func(ctx context.Context) (any, error) { - logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", + logger.LegacyPrintf("handler.admin.usage", "[UsageCleanup] 请求创建清理任务: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v request_type=%v stream=%v billing_type=%v tz=%q", subject.UserID, filters.StartTime.Format(time.RFC3339), filters.EndTime.Format(time.RFC3339), @@ -490,7 +529,8 @@ func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { accountID, groupID, model, - stream, + requestTypeName, + streamValue, billingType, req.Timezone, ) diff --git a/backend/internal/handler/admin/usage_handler_request_type_test.go b/backend/internal/handler/admin/usage_handler_request_type_test.go new file mode 100644 index 00000000..21add574 --- /dev/null +++ b/backend/internal/handler/admin/usage_handler_request_type_test.go @@ -0,0 +1,117 @@ +package admin + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type adminUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters + statsFilters usagestats.UsageLogFilters +} + +func (s *adminUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, 
*pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func (s *adminUsageRepoCapture) GetStatsWithFilters(ctx context.Context, filters usagestats.UsageLogFilters) (*usagestats.UsageStats, error) { + s.statsFilters = filters + return &usagestats.UsageStats{}, nil +} + +func newAdminUsageRequestTypeTestRouter(repo *adminUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil, nil, nil) + router := gin.New() + router.GET("/admin/usage", handler.List) + router.GET("/admin/usage/stats", handler.Stats) + return router +} + +func TestAdminUsageListRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=ws_v2&stream=false", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestAdminUsageListInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?request_type=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageListInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage?stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, 
http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsRequestTypePriority(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=stream&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.statsFilters.RequestType) + require.Equal(t, int16(service.RequestTypeStream), *repo.statsFilters.RequestType) + require.Nil(t, repo.statsFilters.Stream) +} + +func TestAdminUsageStatsInvalidRequestType(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?request_type=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestAdminUsageStatsInvalidStream(t *testing.T) { + repo := &adminUsageRepoCapture{} + router := newAdminUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/admin/usage/stats?stream=oops", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index d85202e5..f85c060e 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -34,13 +34,14 @@ func NewUserHandler(adminService service.AdminService, concurrencyService *servi // CreateUserRequest represents admin create user request type CreateUserRequest struct { - Email string `json:"email" binding:"required,email"` - Password string `json:"password" binding:"required,min=6"` - Username string `json:"username"` - Notes string `json:"notes"` - Balance float64 `json:"balance"` - Concurrency int `json:"concurrency"` - AllowedGroups 
[]int64 `json:"allowed_groups"` + Email string `json:"email" binding:"required,email"` + Password string `json:"password" binding:"required,min=6"` + Username string `json:"username"` + Notes string `json:"notes"` + Balance float64 `json:"balance"` + Concurrency int `json:"concurrency"` + AllowedGroups []int64 `json:"allowed_groups"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` } // UpdateUserRequest represents admin update user request @@ -56,7 +57,8 @@ type UpdateUserRequest struct { AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 `json:"group_rates"` + GroupRates map[int64]*float64 `json:"group_rates"` + SoraStorageQuotaBytes *int64 `json:"sora_storage_quota_bytes"` } // UpdateBalanceRequest represents balance update request @@ -174,13 +176,14 @@ func (h *UserHandler) Create(c *gin.Context) { } user, err := h.adminService.CreateUser(c.Request.Context(), &service.CreateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - AllowedGroups: req.AllowedGroups, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + Balance: req.Balance, + Concurrency: req.Concurrency, + AllowedGroups: req.AllowedGroups, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) @@ -207,15 +210,16 @@ func (h *UserHandler) Update(c *gin.Context) { // 使用指针类型直接传递,nil 表示未提供该字段 user, err := h.adminService.UpdateUser(c.Request.Context(), userID, &service.UpdateUserInput{ - Email: req.Email, - Password: req.Password, - Username: req.Username, - Notes: req.Notes, - Balance: req.Balance, - Concurrency: req.Concurrency, - Status: req.Status, - AllowedGroups: req.AllowedGroups, - GroupRates: req.GroupRates, + Email: req.Email, + Password: req.Password, + Username: req.Username, + Notes: req.Notes, + 
Balance: req.Balance, + Concurrency: req.Concurrency, + Status: req.Status, + AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, + SoraStorageQuotaBytes: req.SoraStorageQuotaBytes, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index e0078e14..1ffa9d71 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -113,9 +113,8 @@ func (h *AuthHandler) Register(c *gin.Context) { return } - // Turnstile 验证 — 始终执行,防止绕过 - // TODO: 确认前端在提交邮箱验证码注册时也传递了 turnstile_token - if err := h.authService.VerifyTurnstile(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c)); err != nil { + // Turnstile 验证(邮箱验证码注册场景避免重复校验一次性 token) + if err := h.authService.VerifyTurnstileForRegister(c.Request.Context(), req.TurnstileToken, ip.GetClientIP(c), req.VerifyCode); err != nil { response.ErrorFrom(c, err) return } diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 42ff4a84..49c74522 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -59,9 +59,11 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, - GroupRates: u.GroupRates, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, } } @@ -152,6 +154,7 @@ func groupFromServiceBase(g *service.Group) Group { ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, FallbackGroupIDOnInvalidRequest: g.FallbackGroupIDOnInvalidRequest, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } @@ -385,6 +388,8 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary { func usageLogFromServiceUser(l *service.UsageLog) UsageLog { // 普通用户 DTO:严禁包含管理员字段(例如 
account_rate_multiplier、ip_address、account)。 + requestType := l.EffectiveRequestType() + stream, openAIWSMode := service.ApplyLegacyRequestFields(requestType, l.Stream, l.OpenAIWSMode) return UsageLog{ ID: l.ID, UserID: l.UserID, @@ -409,7 +414,9 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog { ActualCost: l.ActualCost, RateMultiplier: l.RateMultiplier, BillingType: l.BillingType, - Stream: l.Stream, + RequestType: requestType.String(), + Stream: stream, + OpenAIWSMode: openAIWSMode, DurationMs: l.DurationMs, FirstTokenMs: l.FirstTokenMs, ImageCount: l.ImageCount, @@ -464,6 +471,7 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa AccountID: task.Filters.AccountID, GroupID: task.Filters.GroupID, Model: task.Filters.Model, + RequestType: requestTypeStringPtr(task.Filters.RequestType), Stream: task.Filters.Stream, BillingType: task.Filters.BillingType, }, @@ -479,6 +487,14 @@ func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTa } } +func requestTypeStringPtr(requestType *int16) *string { + if requestType == nil { + return nil + } + value := service.RequestTypeFromInt16(*requestType).String() + return &value +} + func SettingFromService(s *service.Setting) *Setting { if s == nil { return nil diff --git a/backend/internal/handler/dto/mappers_usage_test.go b/backend/internal/handler/dto/mappers_usage_test.go new file mode 100644 index 00000000..d716bdc4 --- /dev/null +++ b/backend/internal/handler/dto/mappers_usage_test.go @@ -0,0 +1,73 @@ +package dto + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogFromService_IncludesOpenAIWSMode(t *testing.T) { + t.Parallel() + + wsLog := &service.UsageLog{ + RequestID: "req_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: true, + } + httpLog := &service.UsageLog{ + RequestID: "resp_1", + Model: "gpt-5.3-codex", + OpenAIWSMode: false, + } + + require.True(t, 
UsageLogFromService(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromService(httpLog).OpenAIWSMode) + require.True(t, UsageLogFromServiceAdmin(wsLog).OpenAIWSMode) + require.False(t, UsageLogFromServiceAdmin(httpLog).OpenAIWSMode) +} + +func TestUsageLogFromService_PrefersRequestTypeForLegacyFields(t *testing.T) { + t.Parallel() + + log := &service.UsageLog{ + RequestID: "req_2", + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + } + + userDTO := UsageLogFromService(log) + adminDTO := UsageLogFromServiceAdmin(log) + + require.Equal(t, "ws_v2", userDTO.RequestType) + require.True(t, userDTO.Stream) + require.True(t, userDTO.OpenAIWSMode) + require.Equal(t, "ws_v2", adminDTO.RequestType) + require.True(t, adminDTO.Stream) + require.True(t, adminDTO.OpenAIWSMode) +} + +func TestUsageCleanupTaskFromService_RequestTypeMapping(t *testing.T) { + t.Parallel() + + requestType := int16(service.RequestTypeStream) + task := &service.UsageCleanupTask{ + ID: 1, + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{ + RequestType: &requestType, + }, + } + + dtoTask := UsageCleanupTaskFromService(task) + require.NotNil(t, dtoTask) + require.NotNil(t, dtoTask.Filters.RequestType) + require.Equal(t, "stream", *dtoTask.Filters.RequestType) +} + +func TestRequestTypeStringPtrNil(t *testing.T) { + t.Parallel() + require.Nil(t, requestTypeStringPtr(nil)) +} diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index be94bc16..adee53c7 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -37,6 +37,7 @@ type SystemSettings struct { HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` + SoraClientEnabled bool `json:"sora_client_enabled"` DefaultConcurrency int 
`json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` @@ -79,9 +80,48 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` + SoraClientEnabled bool `json:"sora_client_enabled"` Version string `json:"version"` } +// SoraS3Settings Sora S3 存储配置 DTO(响应用,不含敏感字段) +type SoraS3Settings struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 存储配置项 DTO(响应用,不含敏感字段) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// ListSoraS3ProfilesResponse Sora S3 配置列表响应 +type ListSoraS3ProfilesResponse struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` +} + // StreamTimeoutSettings 流超时处理配置 DTO type StreamTimeoutSettings struct { Enabled bool `json:"enabled"` diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 0cd1b241..73243397 100644 --- 
a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,7 +26,9 @@ type AdminUser struct { Notes string `json:"notes"` // GroupRates 用户专属分组倍率配置 // map[groupID]rateMultiplier - GroupRates map[int64]float64 `json:"group_rates,omitempty"` + GroupRates map[int64]float64 `json:"group_rates,omitempty"` + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + SoraStorageUsedBytes int64 `json:"sora_storage_used_bytes"` } type APIKey struct { @@ -80,6 +82,9 @@ type Group struct { // 无效请求兜底分组 FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request"` + // Sora 存储配额 + SoraStorageQuotaBytes int64 `json:"sora_storage_quota_bytes"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } @@ -278,10 +283,12 @@ type UsageLog struct { ActualCost float64 `json:"actual_cost"` RateMultiplier float64 `json:"rate_multiplier"` - BillingType int8 `json:"billing_type"` - Stream bool `json:"stream"` - DurationMs *int `json:"duration_ms"` - FirstTokenMs *int `json:"first_token_ms"` + BillingType int8 `json:"billing_type"` + RequestType string `json:"request_type"` + Stream bool `json:"stream"` + OpenAIWSMode bool `json:"openai_ws_mode"` + DurationMs *int `json:"duration_ms"` + FirstTokenMs *int `json:"first_token_ms"` // 图片生成字段 ImageCount int `json:"image_count"` @@ -324,6 +331,7 @@ type UsageCleanupFilters struct { AccountID *int64 `json:"account_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"` Model *string `json:"model,omitempty"` + RequestType *string `json:"request_type,omitempty"` Stream *bool `json:"stream,omitempty"` BillingType *int8 `json:"billing_type,omitempty"` } diff --git a/backend/internal/handler/failover_loop.go b/backend/internal/handler/failover_loop.go index 1f8a7e9a..b2583301 100644 --- a/backend/internal/handler/failover_loop.go +++ b/backend/internal/handler/failover_loop.go @@ -2,11 +2,12 @@ package handler import ( "context" - "log" "net/http" "time" + 
"github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/service" + "go.uber.org/zap" ) // TempUnscheduler 用于 HandleFailoverError 中同账号重试耗尽后的临时封禁。 @@ -78,8 +79,12 @@ func (s *FailoverState) HandleFailoverError( // 同账号重试:对 RetryableOnSameAccount 的临时性错误,先在同一账号上重试 if failoverErr.RetryableOnSameAccount && s.SameAccountRetryCount[accountID] < maxSameAccountRetries { s.SameAccountRetryCount[accountID]++ - log.Printf("Account %d: retryable error %d, same-account retry %d/%d", - accountID, failoverErr.StatusCode, s.SameAccountRetryCount[accountID], maxSameAccountRetries) + logger.FromContext(ctx).Warn("gateway.failover_same_account_retry", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("same_account_retry_count", s.SameAccountRetryCount[accountID]), + zap.Int("same_account_retry_max", maxSameAccountRetries), + ) if !sleepWithContext(ctx, sameAccountRetryDelay) { return FailoverCanceled } @@ -101,8 +106,12 @@ func (s *FailoverState) HandleFailoverError( // 递增切换计数 s.SwitchCount++ - log.Printf("Account %d: upstream error %d, switching account %d/%d", - accountID, failoverErr.StatusCode, s.SwitchCount, s.MaxSwitches) + logger.FromContext(ctx).Warn("gateway.failover_switch_account", + zap.Int64("account_id", accountID), + zap.Int("upstream_status", failoverErr.StatusCode), + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) // Antigravity 平台换号线性递增延时 if platform == service.PlatformAntigravity { @@ -127,13 +136,18 @@ func (s *FailoverState) HandleSelectionExhausted(ctx context.Context) FailoverAc s.LastFailoverErr.StatusCode == http.StatusServiceUnavailable && s.SwitchCount <= s.MaxSwitches { - log.Printf("Antigravity single-account 503 backoff: waiting %v before retry (attempt %d)", - singleAccountBackoffDelay, s.SwitchCount) + logger.FromContext(ctx).Warn("gateway.failover_single_account_backoff", + zap.Duration("backoff_delay", singleAccountBackoffDelay), + 
zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) if !sleepWithContext(ctx, singleAccountBackoffDelay) { return FailoverCanceled } - log.Printf("Antigravity single-account 503 retry: clearing failed accounts, retry %d/%d", - s.SwitchCount, s.MaxSwitches) + logger.FromContext(ctx).Warn("gateway.failover_single_account_retry", + zap.Int("switch_count", s.SwitchCount), + zap.Int("max_switches", s.MaxSwitches), + ) s.FailedAccountIDs = make(map[int64]struct{}) return FailoverContinue } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index fe40e9d2..9262df7e 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -6,9 +6,10 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" + "strconv" "strings" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -17,6 +18,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/claude" "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -27,6 +29,10 @@ import ( "go.uber.org/zap" ) +const gatewayCompatibilityMetricsLogInterval = 1024 + +var gatewayCompatibilityMetricsLogCounter atomic.Uint64 + // GatewayHandler handles API gateway requests type GatewayHandler struct { gatewayService *service.GatewayService @@ -109,9 +115,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { 
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -140,16 +147,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 设置 max_tokens=1 + haiku 探测请求标识到 context 中 // 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断 if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) { - ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) + // 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。 + SetClaudeCodeClientContext(c, body, parsedReq) isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context()) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) setOpsRequestContext(c, reqModel, reqStream, body) @@ -247,8 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { if apiKey.GroupID != nil { prefetchedGroupID = *apiKey.GroupID } - ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) - ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } } @@ -261,7 +267,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if 
h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } @@ -275,7 +281,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { action := fs.HandleSelectionExhausted(c.Request.Context()) switch action { case FailoverContinue: - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) continue case FailoverCanceled: @@ -364,7 +370,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession) @@ -439,7 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } @@ -458,7 +464,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { action := fs.HandleSelectionExhausted(c.Request.Context()) switch action { case FailoverContinue: - ctx := 
context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) continue case FailoverCanceled: @@ -547,7 +553,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession) @@ -956,20 +962,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in SSE format with proper JSON marshaling - errorData := map[string]any{ - "type": "error", - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -1024,9 +1018,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + defer h.maybeLogCompatibilityFallbackMetrics(reqLog) // 读取请求体 - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := 
extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -1041,9 +1036,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { return } - // 检查是否为 Claude Code 客户端,设置到 context 中 - SetClaudeCodeClientContext(c, body) - setOpsRequestContext(c, "", false, body) parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic) @@ -1051,9 +1043,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body") return } + // count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。 + SetClaudeCodeClientContext(c, body, parsedReq) reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream)) // 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用 - c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled)) + c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled())) // 验证 model 必填 if parsedReq.Model == "" { @@ -1217,24 +1211,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce textDeltas = []string{"New", " Conversation"} } - // Build message_start event with proper JSON marshaling - messageStart := map[string]any{ - "type": "message_start", - "message": map[string]any{ - "id": msgID, - "type": "message", - "role": "assistant", - "model": model, - "content": []any{}, - "stop_reason": nil, - "stop_sequence": nil, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": 0, - }, - }, - } - messageStartJSON, _ := json.Marshal(messageStart) + // Build message_start event with fixed schema. 
+ messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}` // Build events events := []string{ @@ -1244,31 +1222,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce // Add text deltas for _, text := range textDeltas { - delta := map[string]any{ - "type": "content_block_delta", - "index": 0, - "delta": map[string]string{ - "type": "text_delta", - "text": text, - }, - } - deltaJSON, _ := json.Marshal(delta) + deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}` events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON)) } // Add final events - messageDelta := map[string]any{ - "type": "message_delta", - "delta": map[string]any{ - "stop_reason": "end_turn", - "stop_sequence": nil, - }, - "usage": map[string]int{ - "input_tokens": 10, - "output_tokens": outputTokens, - }, - } - messageDeltaJSON, _ := json.Marshal(messageDelta) + messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}` events = append(events, `event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`, @@ -1366,6 +1325,30 @@ func billingErrorDetails(err error) (status int, code, message string) { return http.StatusForbidden, "billing_error", msg } +func (h *GatewayHandler) metadataBridgeEnabled() bool { + if h == nil || h.cfg == nil { + return true + } + return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled +} + +func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) { + if reqLog == nil { + return + } + if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 { + return 
+ } + metrics := service.SnapshotOpenAICompatibilityFallbackMetrics() + reqLog.Info("gateway.compatibility_fallback_metrics", + zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal), + zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit), + zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal), + zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate), + zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal), + ) +} + func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { if task == nil { return @@ -1377,5 +1360,13 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) { // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.gateway.messages"), + zap.Any("panic", recovered), + ).Error("gateway.usage_record_task_panic_recovered") + } + }() task(ctx) } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 15d85949..76141521 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -119,6 +119,13 @@ func (f *fakeConcurrencyCache) GetAccountsLoadBatch(context.Context, []service.A func (f *fakeConcurrencyCache) GetUsersLoadBatch(context.Context, []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { return map[int64]*service.UserLoadInfo{}, nil } +func (f *fakeConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, 
id := range accountIDs { + result[id] = 0 + } + return result, nil +} func (f *fakeConcurrencyCache) CleanupExpiredAccountSlots(context.Context, int64) error { return nil } func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*service.Account) (*GatewayHandler, func()) { diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index efff7997..ea8a5f1a 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -18,12 +18,17 @@ import ( // claudeCodeValidator is a singleton validator for Claude Code client detection var claudeCodeValidator = service.NewClaudeCodeValidator() +const claudeCodeParsedRequestContextKey = "claude_code_parsed_request" + // SetClaudeCodeClientContext 检查请求是否来自 Claude Code 客户端,并设置到 context 中 // 返回更新后的 context -func SetClaudeCodeClientContext(c *gin.Context, body []byte) { +func SetClaudeCodeClientContext(c *gin.Context, body []byte, parsedReq *service.ParsedRequest) { if c == nil || c.Request == nil { return } + if parsedReq != nil { + c.Set(claudeCodeParsedRequestContextKey, parsedReq) + } // Fast path:非 Claude CLI UA 直接判定 false,避免热路径二次 JSON 反序列化。 if !claudeCodeValidator.ValidateUserAgent(c.GetHeader("User-Agent")) { ctx := service.SetClaudeCodeClient(c.Request.Context(), false) @@ -37,8 +42,11 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) { isClaudeCode = true } else { // 仅在确认为 Claude CLI 且 messages 路径时再做 body 解析。 - var bodyMap map[string]any - if len(body) > 0 { + bodyMap := claudeCodeBodyMapFromParsedRequest(parsedReq) + if bodyMap == nil { + bodyMap = claudeCodeBodyMapFromContextCache(c) + } + if bodyMap == nil && len(body) > 0 { _ = json.Unmarshal(body, &bodyMap) } isClaudeCode = claudeCodeValidator.Validate(c.Request, bodyMap) @@ -49,6 +57,42 @@ func SetClaudeCodeClientContext(c *gin.Context, body []byte) { c.Request = c.Request.WithContext(ctx) } +func claudeCodeBodyMapFromParsedRequest(parsedReq 
*service.ParsedRequest) map[string]any { + if parsedReq == nil { + return nil + } + bodyMap := map[string]any{ + "model": parsedReq.Model, + } + if parsedReq.System != nil || parsedReq.HasSystem { + bodyMap["system"] = parsedReq.System + } + if parsedReq.MetadataUserID != "" { + bodyMap["metadata"] = map[string]any{"user_id": parsedReq.MetadataUserID} + } + return bodyMap +} + +func claudeCodeBodyMapFromContextCache(c *gin.Context) map[string]any { + if c == nil { + return nil + } + if cached, ok := c.Get(service.OpenAIParsedRequestBodyKey); ok { + if bodyMap, ok := cached.(map[string]any); ok { + return bodyMap + } + } + if cached, ok := c.Get(claudeCodeParsedRequestContextKey); ok { + switch v := cached.(type) { + case *service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(v) + case service.ParsedRequest: + return claudeCodeBodyMapFromParsedRequest(&v) + } + } + return nil +} + // 并发槽位等待相关常量 // // 性能优化说明: diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go index 3e6c376b..31d489f0 100644 --- a/backend/internal/handler/gateway_helper_fastpath_test.go +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -33,6 +33,14 @@ func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accoun return 0, nil } +func (m *concurrencyCacheMock) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } diff --git a/backend/internal/handler/gateway_helper_hotpath_test.go b/backend/internal/handler/gateway_helper_hotpath_test.go index 3fdf1bfc..f8f7eaca 100644 --- a/backend/internal/handler/gateway_helper_hotpath_test.go +++ 
b/backend/internal/handler/gateway_helper_hotpath_test.go @@ -49,6 +49,14 @@ func (s *helperConcurrencyCacheStub) GetAccountConcurrency(ctx context.Context, return 0, nil } +func (s *helperConcurrencyCacheStub) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + out := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + out[accountID] = 0 + } + return out, nil +} + func (s *helperConcurrencyCacheStub) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } @@ -133,7 +141,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c.Request.Header.Set("User-Agent", "curl/8.6.0") - SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) require.False(t, service.IsClaudeCodeClient(c.Request.Context())) }) @@ -141,7 +149,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c, _ := newHelperTestContext(http.MethodGet, "/v1/models") c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") - SetClaudeCodeClientContext(c, nil) + SetClaudeCodeClientContext(c, nil, nil) require.True(t, service.IsClaudeCodeClient(c.Request.Context())) }) @@ -152,7 +160,7 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") c.Request.Header.Set("anthropic-version", "2023-06-01") - SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON()) + SetClaudeCodeClientContext(c, validClaudeCodeBodyJSON(), nil) require.True(t, service.IsClaudeCodeClient(c.Request.Context())) }) @@ -160,11 +168,51 @@ func TestSetClaudeCodeClientContext_FastPathAndStrictPath(t *testing.T) { c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") // 缺少严格校验所需 header + body 字段 - 
SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`)) + SetClaudeCodeClientContext(c, []byte(`{"model":"x"}`), nil) require.False(t, service.IsClaudeCodeClient(c.Request.Context())) }) } +func TestSetClaudeCodeClientContext_ReuseParsedRequestAndContextCache(t *testing.T) { + t.Run("reuse parsed request without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + + parsedReq := &service.ParsedRequest{ + Model: "claude-3-5-sonnet-20241022", + System: []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + MetadataUserID: "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123", + } + + // body 非法 JSON,如果函数复用 parsedReq 成功则仍应判定为 Claude Code。 + SetClaudeCodeClientContext(c, []byte(`{invalid`), parsedReq) + require.True(t, service.IsClaudeCodeClient(c.Request.Context())) + }) + + t.Run("reuse context cache without body unmarshal", func(t *testing.T) { + c, _ := newHelperTestContext(http.MethodPost, "/v1/messages") + c.Request.Header.Set("User-Agent", "claude-cli/1.0.1") + c.Request.Header.Set("X-App", "claude-code") + c.Request.Header.Set("anthropic-beta", "message-batches-2024-09-24") + c.Request.Header.Set("anthropic-version", "2023-06-01") + c.Set(service.OpenAIParsedRequestBodyKey, map[string]any{ + "model": "claude-3-5-sonnet-20241022", + "system": []any{ + map[string]any{"text": "You are Claude Code, Anthropic's official CLI for Claude."}, + }, + "metadata": map[string]any{"user_id": "user_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_account__session_abc-123"}, + }) + + SetClaudeCodeClientContext(c, []byte(`{invalid`), nil) + require.True(t, 
service.IsClaudeCodeClient(c.Request.Context())) + }) +} + func TestWaitForSlotWithPingTimeout_AccountAndUserAcquire(t *testing.T) { cache := &helperConcurrencyCacheStub{ accountSeq: []bool{false, true}, diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 2da0570b..50af9c8f 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -7,16 +7,15 @@ import ( "encoding/hex" "encoding/json" "errors" - "io" "net/http" "regexp" "strings" "github.com/Wei-Shaw/sub2api/internal/domain" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/gemini" "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -168,7 +167,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { stream := action == "streamGenerateContent" reqLog = reqLog.With(zap.String("model", modelName), zap.String("action", action), zap.Bool("stream", stream)) - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { googleError(c, http.StatusRequestEntityTooLarge, buildBodyTooLargeMessage(maxErr.Limit)) @@ -268,8 +267,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if apiKey.GroupID != nil { prefetchedGroupID = *apiKey.GroupID } - ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID) - ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID) + ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, 
h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } } @@ -349,7 +347,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。 // 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。 if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) { - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) } @@ -363,7 +361,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { action := fs.HandleSelectionExhausted(c.Request.Context()) switch action { case FailoverContinue: - ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true) + ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled()) c.Request = c.Request.WithContext(ctx) continue case FailoverCanceled: @@ -456,7 +454,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { - requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount) + requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled()) } if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey { result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, modelName, action, stream, body, hasBoundSession) diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b999180b..bbf4be4b 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -11,6 +11,7 @@ type AdminHandlers struct { Group *admin.GroupHandler Account *admin.AccountHandler Announcement *admin.AnnouncementHandler + DataManagement 
*admin.DataManagementHandler OAuth *admin.OAuthHandler OpenAIOAuth *admin.OpenAIOAuthHandler GeminiOAuth *admin.GeminiOAuthHandler @@ -40,6 +41,7 @@ type Handlers struct { Gateway *GatewayHandler OpenAIGateway *OpenAIGatewayHandler SoraGateway *SoraGatewayHandler + SoraClient *SoraClientHandler Setting *SettingHandler Totp *TotpHandler } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 50af684d..4bbd17ba 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -5,17 +5,20 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" + "runtime/debug" + "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/tidwall/gjson" "go.uber.org/zap" @@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler( // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { + // 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。 + streamStarted := false + defer h.recoverResponsesPanic(c, &streamStarted) + setOpenAIClientTransportHTTP(c) + requestStart := time.Now() // Get apiKey and user from context (set by ApiKeyAuth middleware) @@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { zap.Int64("api_key_id", apiKey.ID), zap.Any("group_id", apiKey.GroupID), ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } // Read request body - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := 
extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } reqStream := streamResult.Bool() reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream)) + previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()) + if previousResponseID != "" { + previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + reqLog = reqLog.With( + zap.Bool("has_previous_response_id", true), + zap.String("previous_response_id_kind", previousResponseIDKind), + zap.Int("previous_response_id_len", len(previousResponseID)), + ) + if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "previous_response_id_looks_like_message_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id") + return + } + } setOpsRequestContext(c, reqModel, reqStream, body) // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 - // 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call, - // 或带 id 且与 call_id 匹配的 item_reference。 - // 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal - if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err == nil { - c.Set(service.OpenAIParsedRequestBodyKey, reqBody) - if service.HasFunctionCallOutput(reqBody) { - previousResponseID, _ := reqBody["previous_response_id"].(string) - if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { - if service.HasFunctionCallOutputMissingCallID(reqBody) { - reqLog.Warn("openai.request_validation_failed", - zap.String("reason", 
"function_call_output_missing_call_id"), - ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") - return - } - callIDs := service.FunctionCallOutputCallIDs(reqBody) - if !service.HasItemReferenceForCallIDs(reqBody, callIDs) { - reqLog.Warn("openai.request_validation_failed", - zap.String("reason", "function_call_output_missing_item_reference"), - ) - h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") - return - } - } - } - } + if !h.validateFunctionCallOutputRequest(c, body, reqLog) { + return } - // Track if we've started streaming (for error handling) - streamStarted := false - // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) routingStart := time.Now() - // 0. 
先尝试直接抢占用户槽位(快速路径) - userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency) - if err != nil { - reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) + userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog) + if !acquired { return } - - waitCounted := false - if !userAcquired { - // 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。 - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - if waitErr != nil { - reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) - // 按现有降级语义:等待计数异常时放行后续抢槽流程 - } else if !canWait { - reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if waitErr == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() - - userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) - if err != nil { - reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) - h.handleConcurrencyError(c, err, "user", streamStarted) - return - } - } - - // 用户槽位已获取:退出等待队列计数。 - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - waitCounted = false - } // 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏 - userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc) if userReleaseFunc != nil { defer userReleaseFunc() } @@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { for { // Select account supporting the requested model 
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs))) - selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + c.Request.Context(), + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + failedAccountIDs, + service.OpenAIUpstreamTransportAny, + ) if err != nil { reqLog.Warn("openai.account_select_failed", zap.Error(err), @@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } return } + if selection == nil || selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) + return + } + if previousResponseID != "" && selection != nil && selection.Account != nil { + reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID)) + } + reqLog.Debug("openai.account_schedule_decision", + zap.String("layer", scheduleDecision.Layer), + zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit), + zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + zap.Int("top_k", scheduleDecision.TopK), + zap.Int64("latency_ms", scheduleDecision.LatencyMs), + zap.Float64("load_skew", scheduleDecision.LoadSkew), + ) account := selection.Account reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name)) setOpsSelectedAccount(c, account.ID, account.Platform) - // 3. 
Acquire account concurrency slot - accountReleaseFunc := selection.ReleaseFunc - if !selection.Acquired { - if selection.WaitPlan == nil { - h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) - return - } - - // 先快速尝试一次账号槽位,命中则跳过等待计数写入。 - fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( - c.Request.Context(), - account.ID, - selection.WaitPlan.MaxConcurrency, - ) - if err != nil { - reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - if fastAcquired { - accountReleaseFunc = fastReleaseFunc - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } else { - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } else if !canWait { - reqLog.Info("openai.account_wait_queue_full", - zap.Int64("account_id", account.ID), - zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), - ) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - releaseWait := func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - } - - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, - account.ID, - selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - 
reqStream, - &streamStarted, - ) - if err != nil { - reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - releaseWait() - h.handleConcurrencyError(c, err, "account", streamStarted) - return - } - // Slot acquired: no longer waiting in queue. - releaseWait() - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) - } - } + accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog) + if !acquired { + return } - // 账号槽位/等待计数需要在超时或断开时安全回收 - accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) @@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + h.gatewayService.RecordOpenAIAccountSwitch() failedAccountIDs[account.ID] = struct{}{} lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { @@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { ) continue } + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) wroteFallback := h.ensureForwardErrorResponse(c, streamStarted) - reqLog.Error("openai.forward_failed", + fields := []zap.Field{ zap.Int64("account_id", account.ID), zap.Bool("fallback_error_response_written", wroteFallback), zap.Error(err), - ) + } + if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) { + reqLog.Warn("openai.forward_failed", fields...) + return + } + reqLog.Error("openai.forward_failed", fields...) 
return } + if result != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + } else { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil) + } // 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context) userAgent := c.GetHeader("User-Agent") @@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } +func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool { + if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { + return true + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + // 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。 + return true + } + + c.Set(service.OpenAIParsedRequestBodyKey, reqBody) + validation := service.ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext { + return true + } + + if validation.HasFunctionCallOutputMissingCallID { + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_call_id"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id") + return false + } + if validation.HasItemReferenceForAllCallIDs { + return true + } + + reqLog.Warn("openai.request_validation_failed", + zap.String("reason", "function_call_output_missing_item_reference"), + ) + h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id") + return false +} + +func (h 
*OpenAIGatewayHandler) acquireResponsesUserSlot( + c *gin.Context, + userID int64, + userConcurrency int, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + ctx := c.Request.Context() + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + if userAcquired { + return wrapReleaseOnDone(ctx, userReleaseFunc), true + } + + maxWait := service.CalculateMaxWait(userConcurrency) + canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait) + if waitErr != nil { + reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr)) + // 按现有降级语义:等待计数异常时放行后续抢槽流程 + } else if !canWait { + reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait)) + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return nil, false + } + + waitCounted := waitErr == nil && canWait + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + } + }() + + userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted) + if err != nil { + reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err)) + h.handleConcurrencyError(c, err, "user", *streamStarted) + return nil, false + } + + // 槽位获取成功后,立刻退出等待计数。 + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(ctx, userID) + waitCounted = false + } + return wrapReleaseOnDone(ctx, userReleaseFunc), true +} + +func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot( + c *gin.Context, + groupID *int64, + sessionHash string, + selection *service.AccountSelectionResult, + reqStream bool, + streamStarted *bool, + reqLog *zap.Logger, +) (func(), bool) { + if selection == nil || 
selection.Account == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + ctx := c.Request.Context() + account := selection.Account + if selection.Acquired { + return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true + } + if selection.WaitPlan == nil { + h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted) + return nil, false + } + + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + if fastAcquired { + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, fastReleaseFunc), true + } + + canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting) + if waitErr != nil { + reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr)) + } else if !canWait { + reqLog.Info("openai.account_wait_queue_full", + zap.Int64("account_id", account.ID), + zap.Int("max_waiting", selection.WaitPlan.MaxWaiting), + ) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted) + return nil, false + } + + accountWaitCounted := waitErr == nil && canWait + releaseWait := func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID) + accountWaitCounted = false + } + } + defer releaseWait() + + accountReleaseFunc, 
err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + streamStarted, + ) + if err != nil { + reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + h.handleConcurrencyError(c, err, "account", *streamStarted) + return nil, false + } + + // Slot acquired: no longer waiting in queue. + releaseWait() + if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + return wrapReleaseOnDone(ctx, accountReleaseFunc), true +} + +// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint +// GET /openai/v1/responses (Upgrade: websocket) +func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { + if !isOpenAIWSUpgradeRequest(c.Request) { + h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)") + return + } + setOpenAIClientTransportWS(c) + + apiKey, ok := middleware2.GetAPIKeyFromContext(c) + if !ok { + h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key") + return + } + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found") + return + } + + reqLog := requestLogger( + c, + "handler.openai_gateway.responses_ws", + zap.Int64("user_id", subject.UserID), + zap.Int64("api_key_id", apiKey.ID), + zap.Any("group_id", apiKey.GroupID), + zap.Bool("openai_ws_mode", true), + ) + if !h.ensureResponsesDependencies(c, reqLog) { + return + } + reqLog.Info("openai.websocket_ingress_started") + clientIP := ip.GetClientIP(c) + userAgent := strings.TrimSpace(c.GetHeader("User-Agent")) + + wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{ + 
CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + reqLog.Warn("openai.websocket_accept_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("request_user_agent", userAgent), + zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))), + zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))), + zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))), + zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""), + ) + return + } + defer func() { + _ = wsConn.CloseNow() + }() + wsConn.SetReadLimit(16 * 1024 * 1024) + + ctx := c.Request.Context() + readCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + msgType, firstMessage, err := wsConn.Read(readCtx) + cancel() + if err != nil { + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_read_first_message_failed", + zap.Error(err), + zap.String("client_ip", clientIP), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + zap.Duration("read_timeout", 30*time.Second), + ) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message") + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type") + return + } + if !gjson.ValidBytes(firstMessage) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload") + return + } + + reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String()) + if reqModel == "" { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload") + return + } + previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String()) + previousResponseIDKind := 
service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id") + return + } + reqLog = reqLog.With( + zap.Bool("ws_ingress", true), + zap.String("model", reqModel), + zap.Bool("has_previous_response_id", previousResponseID != ""), + zap.String("previous_response_id_kind", previousResponseIDKind), + ) + setOpsRequestContext(c, reqModel, true, firstMessage) + + var currentUserRelease func() + var currentAccountRelease func() + releaseTurnSlots := func() { + if currentAccountRelease != nil { + currentAccountRelease() + currentAccountRelease = nil + } + if currentUserRelease != nil { + currentUserRelease() + currentUserRelease = nil + } + } + // 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。 + defer releaseTurnSlots() + + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot") + return + } + if !userAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later") + return + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + + subscription, _ := middleware2.GetSubscriptionFromContext(c) + if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil { + reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed") + return + } + + sessionHash := h.gatewayService.GenerateSessionHashWithFallback( + c, + firstMessage, + 
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID), + ) + selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler( + ctx, + apiKey.GroupID, + previousResponseID, + sessionHash, + reqModel, + nil, + service.OpenAIUpstreamTransportResponsesWebsocketV2, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + if selection == nil || selection.Account == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account") + return + } + + account := selection.Account + accountMaxConcurrency := account.Concurrency + if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 { + accountMaxConcurrency = selection.WaitPlan.MaxConcurrency + } + accountReleaseFunc := selection.ReleaseFunc + if !selection.Acquired { + if selection.WaitPlan == nil { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + ctx, + account.ID, + selection.WaitPlan.MaxConcurrency, + ) + if err != nil { + reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot") + return + } + if !fastAcquired { + closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later") + return + } + accountReleaseFunc = fastReleaseFunc + } + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil { + reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + } + + token, _, err := h.gatewayService.GetAccessToken(ctx, 
account) + if err != nil { + reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err)) + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token") + return + } + + reqLog.Debug("openai.websocket_account_selected", + zap.Int64("account_id", account.ID), + zap.String("account_name", account.Name), + zap.String("schedule_layer", scheduleDecision.Layer), + zap.Int("candidate_count", scheduleDecision.CandidateCount), + ) + + hooks := &service.OpenAIWSIngressHooks{ + BeforeTurn: func(turn int) error { + if turn == 1 { + return nil + } + // 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。 + releaseTurnSlots() + // 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。 + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency) + if err != nil { + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err) + } + if !userAcquired { + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil) + } + accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency) + if err != nil { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err) + } + if !accountAcquired { + if userReleaseFunc != nil { + userReleaseFunc() + } + return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil) + } + currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc) + currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc) + return nil + }, + AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) { + releaseTurnSlots() + if turnErr != nil || result == nil { + return + } + 
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs) + h.submitUsageRecordTask(func(taskCtx context.Context) { + if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{ + Result: result, + APIKey: apiKey, + User: apiKey.User, + Account: account, + Subscription: subscription, + UserAgent: userAgent, + IPAddress: clientIP, + APIKeyService: h.apiKeyService, + }); err != nil { + reqLog.Error("openai.websocket_record_usage_failed", + zap.Int64("account_id", account.ID), + zap.String("request_id", result.RequestID), + zap.Error(err), + ) + } + }) + }, + } + + if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil { + h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil) + closeStatus, closeReason := summarizeWSCloseErrorForLog(err) + reqLog.Warn("openai.websocket_proxy_failed", + zap.Int64("account_id", account.ID), + zap.Error(err), + zap.String("close_status", closeStatus), + zap.String("close_reason", closeReason), + ) + var closeErr *service.OpenAIWSClientCloseError + if errors.As(err, &closeErr) { + closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason()) + return + } + closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed") + return + } + reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID)) +} + +func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) { + recovered := recover() + if recovered == nil { + return + } + + started := false + if streamStarted != nil { + started = *streamStarted + } + wroteFallback := h.ensureForwardErrorResponse(c, started) + requestLogger(c, "handler.openai_gateway.responses").Error( + "openai.responses_panic_recovered", + zap.Bool("fallback_error_response_written", wroteFallback), + zap.Any("panic", recovered), + zap.ByteString("stack", debug.Stack()), + ) +} + +func (h 
*OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool { + missing := h.missingResponsesDependencies() + if len(missing) == 0 { + return true + } + + if reqLog == nil { + reqLog = requestLogger(c, "handler.openai_gateway.responses") + } + reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing)) + + if c != nil && c.Writer != nil && !c.Writer.Written() { + c.JSON(http.StatusServiceUnavailable, gin.H{ + "error": gin.H{ + "type": "api_error", + "message": "Service temporarily unavailable", + }, + }) + } + return false +} + +func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string { + missing := make([]string, 0, 5) + if h == nil { + return append(missing, "handler") + } + if h.gatewayService == nil { + missing = append(missing, "gatewayService") + } + if h.billingCacheService == nil { + missing = append(missing, "billingCacheService") + } + if h.apiKeyService == nil { + missing = append(missing, "apiKeyService") + } + if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil { + missing = append(missing, "concurrencyHelper") + } + return missing +} + func getContextInt64(c *gin.Context, key string) (int64, bool) { if c == nil || key == "" { return 0, false @@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.openai_gateway.responses"), + zap.Any("panic", recovered), + ).Error("openai.usage_record_task_panic_recovered") + } + }() task(ctx) } @@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send 
error event in OpenAI SSE format with proper JSON marshaling - errorData := map[string]any{ - "error": map[string]string{ - "type": errType, - "message": message, - }, - } - jsonBytes, err := json.Marshal(errorData) - if err != nil { - _ = c.Error(err) - return - } - errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) + // SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。 + errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n" if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream return true } +func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool { + if wroteFallback { + return false + } + if c == nil || c.Writer == nil { + return false + } + return c.Writer.Written() +} + // errorResponse returns OpenAI API format error response func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) { c.JSON(status, gin.H{ @@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType }, }) } + +func setOpenAIClientTransportHTTP(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP) +} + +func setOpenAIClientTransportWS(c *gin.Context) { + service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS) +} + +func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string { + gid := int64(0) + if groupID != nil { + gid = *groupID + } + return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", gid, userID, apiKeyID) +} + +func isOpenAIWSUpgradeRequest(r *http.Request) bool { + if r == nil { + return false + } + if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") { + return false + } + return 
strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade") +} + +func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) { + if conn == nil { + return + } + reason = strings.TrimSpace(reason) + if len(reason) > 120 { + reason = reason[:120] + } + _ = conn.Close(status, reason) + _ = conn.CloseNow() +} + +func summarizeWSCloseErrorForLog(err error) (string, string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reason := strings.TrimSpace(closeErr.Reason) + if reason != "" { + closeReason = reason + } + } + return closeStatus, closeReason +} diff --git a/backend/internal/handler/openai_gateway_handler_test.go b/backend/internal/handler/openai_gateway_handler_test.go index 1ca52c2d..a26b3a0c 100644 --- a/backend/internal/handler/openai_gateway_handler_test.go +++ b/backend/internal/handler/openai_gateway_handler_test.go @@ -1,12 +1,19 @@ package handler import ( + "context" "encoding/json" + "errors" "net/http" "net/http/httptest" "strings" "testing" + "time" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + coderws "github.com/coder/websocket" "github.com/gin-gonic/gin" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -105,6 +112,27 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) { assert.Equal(t, "test error", errorObj["message"]) } +func TestReadRequestBodyWithPrealloc(t *testing.T) { + payload := `{"model":"gpt-5","input":"hello"}` + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(payload)) + req.ContentLength = int64(len(payload)) + + body, err := 
pkghttputil.ReadRequestBodyWithPrealloc(req) + require.NoError(t, err) + require.Equal(t, payload, string(body)) +} + +func TestReadRequestBodyWithPrealloc_MaxBytesError(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(strings.Repeat("x", 8))) + req.Body = http.MaxBytesReader(rec, req.Body, 4) + + _, err := pkghttputil.ReadRequestBodyWithPrealloc(req) + require.Error(t, err) + var maxErr *http.MaxBytesError + require.ErrorAs(t, err, &maxErr) +} + func TestOpenAIEnsureForwardErrorResponse_WritesFallbackWhenNotWritten(t *testing.T) { gin.SetMode(gin.TestMode) w := httptest.NewRecorder() @@ -141,6 +169,387 @@ func TestOpenAIEnsureForwardErrorResponse_DoesNotOverrideWrittenResponse(t *test assert.Equal(t, "already written", w.Body.String()) } +func TestShouldLogOpenAIForwardFailureAsWarn(t *testing.T) { + gin.SetMode(gin.TestMode) + + t.Run("fallback_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, true)) + }) + + t.Run("context_nil_should_not_downgrade", func(t *testing.T) { + require.False(t, shouldLogOpenAIForwardFailureAsWarn(nil, false)) + }) + + t.Run("response_not_written_should_not_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + require.False(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) + + t.Run("response_already_written_should_downgrade", func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + c.String(http.StatusForbidden, "already written") + require.True(t, shouldLogOpenAIForwardFailureAsWarn(c, false)) + }) +} + +func TestOpenAIRecoverResponsesPanic_WritesFallbackResponse(t 
*testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusBadGateway, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "upstream_error", errorObj["type"]) + assert.Equal(t, "Upstream request failed", errorObj["message"]) +} + +func TestOpenAIRecoverResponsesPanic_NoPanicNoWrite(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + }() + }) + + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) +} + +func TestOpenAIRecoverResponsesPanic_DoesNotOverrideWrittenResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + streamStarted := false + require.NotPanics(t, func() { + func() { + defer h.recoverResponsesPanic(c, &streamStarted) + panic("test panic") + }() + }) + + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) +} + +func TestOpenAIMissingResponsesDependencies(t *testing.T) { + t.Run("nil_handler", func(t *testing.T) { + var h *OpenAIGatewayHandler + require.Equal(t, []string{"handler"}, 
h.missingResponsesDependencies()) + }) + + t.Run("all_dependencies_missing", func(t *testing.T) { + h := &OpenAIGatewayHandler{} + require.Equal(t, + []string{"gatewayService", "billingCacheService", "apiKeyService", "concurrencyHelper"}, + h.missingResponsesDependencies(), + ) + }) + + t.Run("all_dependencies_present", func(t *testing.T) { + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + require.Empty(t, h.missingResponsesDependencies()) + }) +} + +func TestOpenAIEnsureResponsesDependencies(t *testing.T) { + t.Run("missing_dependencies_returns_503", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusServiceUnavailable, w.Code) + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + errorObj, exists := parsed["error"].(map[string]any) + require.True(t, exists) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) + }) + + t.Run("already_written_response_not_overridden", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + c.String(http.StatusTeapot, "already written") + + h := &OpenAIGatewayHandler{} + ok := h.ensureResponsesDependencies(c, nil) + + require.False(t, ok) + require.Equal(t, http.StatusTeapot, w.Code) + assert.Equal(t, "already written", w.Body.String()) + }) + + 
t.Run("dependencies_ready_returns_true_and_no_write", func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + h := &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: &ConcurrencyHelper{ + concurrencyService: &service.ConcurrencyService{}, + }, + } + ok := h.ensureResponsesDependencies(c, nil) + + require.True(t, ok) + require.False(t, c.Writer.Written()) + assert.Equal(t, "", w.Body.String()) + }) +} + +func TestOpenAIResponses_MissingDependencies_ReturnsServiceUnavailable(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"gpt-5","stream":false}`)) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 10, + GroupID: &groupID, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + // 故意使用未初始化依赖,验证快速失败而不是崩溃。 + h := &OpenAIGatewayHandler{} + require.NotPanics(t, func() { + h.Responses(c) + }) + + require.Equal(t, http.StatusServiceUnavailable, w.Code) + + var parsed map[string]any + err := json.Unmarshal(w.Body.Bytes(), &parsed) + require.NoError(t, err) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + assert.Equal(t, "api_error", errorObj["type"]) + assert.Equal(t, "Service temporarily unavailable", errorObj["message"]) +} + +func TestOpenAIResponses_SetsClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", 
strings.NewReader(`{"model":"gpt-5"}`)) + c.Request.Header.Set("Content-Type", "application/json") + + h := &OpenAIGatewayHandler{} + h.Responses(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponses_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", strings.NewReader( + `{"model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456","input":[{"type":"input_text","text":"hello"}]}`, + )) + c.Request.Header.Set("Content-Type", "application/json") + + groupID := int64(2) + c.Set(string(middleware.ContextKeyAPIKey), &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: 1}, + }) + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ + UserID: 1, + Concurrency: 1, + }) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + h.Responses(c) + + require.Equal(t, http.StatusBadRequest, w.Code) + require.Contains(t, w.Body.String(), "previous_response_id must be a response.id") +} + +func TestOpenAIResponsesWebSocket_SetsClientTransportWSWhenUpgradeValid(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + c.Request.Header.Set("Upgrade", "websocket") + c.Request.Header.Set("Connection", "Upgrade") + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUnauthorized, w.Code) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_InvalidUpgradeDoesNotSetTransport(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = 
httptest.NewRequest(http.MethodGet, "/openai/v1/responses", nil) + + h := &OpenAIGatewayHandler{} + h.ResponsesWebSocket(c) + + require.Equal(t, http.StatusUpgradeRequired, w.Code) + require.Equal(t, service.OpenAIClientTransportUnknown, service.GetOpenAIClientTransport(c)) +} + +func TestOpenAIResponsesWebSocket_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + h := newOpenAIHandlerForPreviousResponseIDValidation(t, nil) + wsServer := newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_abc123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "previous_response_id") +} + +func TestOpenAIResponsesWebSocket_PreviousResponseIDKindLoggedBeforeAcquireFailure(t *testing.T) { + gin.SetMode(gin.TestMode) + + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return false, errors.New("user slot unavailable") + }, + } + h := newOpenAIHandlerForPreviousResponseIDValidation(t, cache) + wsServer := 
newOpenAIWSHandlerTestServer(t, h, middleware.AuthSubject{UserID: 1, Concurrency: 1}) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http")+"/openai/v1/responses", nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte( + `{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_123"}`, + )) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + _, _, err = clientConn.Read(readCtx) + cancelRead() + require.Error(t, err) + var closeErr coderws.CloseError + require.ErrorAs(t, err, &closeErr) + require.Equal(t, coderws.StatusInternalError, closeErr.Code) + require.Contains(t, strings.ToLower(closeErr.Reason), "failed to acquire user concurrency slot") +} + +func TestSetOpenAIClientTransportHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportHTTP(c) + require.Equal(t, service.OpenAIClientTransportHTTP, service.GetOpenAIClientTransport(c)) +} + +func TestSetOpenAIClientTransportWS(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + setOpenAIClientTransportWS(c) + require.Equal(t, service.OpenAIClientTransportWS, service.GetOpenAIClientTransport(c)) +} + // TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性 func TestOpenAIHandler_GjsonExtraction(t *testing.T) { tests := []struct { @@ -228,3 +637,41 @@ func TestOpenAIHandler_InstructionsInjection(t *testing.T) { require.NoError(t, setErr) require.True(t, gjson.ValidBytes(result)) } + +func 
newOpenAIHandlerForPreviousResponseIDValidation(t *testing.T, cache *concurrencyCacheMock) *OpenAIGatewayHandler { + t.Helper() + if cache == nil { + cache = &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + } + return &OpenAIGatewayHandler{ + gatewayService: &service.OpenAIGatewayService{}, + billingCacheService: &service.BillingCacheService{}, + apiKeyService: &service.APIKeyService{}, + concurrencyHelper: NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second), + } +} + +func newOpenAIWSHandlerTestServer(t *testing.T, h *OpenAIGatewayHandler, subject middleware.AuthSubject) *httptest.Server { + t.Helper() + groupID := int64(2) + apiKey := &service.APIKey{ + ID: 101, + GroupID: &groupID, + User: &service.User{ID: subject.UserID}, + } + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyAPIKey), apiKey) + c.Set(string(middleware.ContextKeyUser), subject) + c.Next() + }) + router.GET("/openai/v1/responses", h.ResponsesWebSocket) + return httptest.NewServer(router) +} diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index ab9a2167..6fbf7952 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -311,6 +311,35 @@ type opsCaptureWriter struct { buf bytes.Buffer } +const opsCaptureWriterLimit = 64 * 1024 + +var opsCaptureWriterPool = sync.Pool{ + New: func() any { + return &opsCaptureWriter{limit: opsCaptureWriterLimit} + }, +} + +func acquireOpsCaptureWriter(rw gin.ResponseWriter) *opsCaptureWriter { + w, ok := opsCaptureWriterPool.Get().(*opsCaptureWriter) + if !ok || w == nil { + w = &opsCaptureWriter{} + } + 
w.ResponseWriter = rw + w.limit = opsCaptureWriterLimit + w.buf.Reset() + return w +} + +func releaseOpsCaptureWriter(w *opsCaptureWriter) { + if w == nil { + return + } + w.ResponseWriter = nil + w.limit = opsCaptureWriterLimit + w.buf.Reset() + opsCaptureWriterPool.Put(w) +} + func (w *opsCaptureWriter) Write(b []byte) (int, error) { if w.Status() >= 400 && w.limit > 0 && w.buf.Len() < w.limit { remaining := w.limit - w.buf.Len() @@ -342,7 +371,16 @@ func (w *opsCaptureWriter) WriteString(s string) (int, error) { // - Streaming errors after the response has started (SSE) may still need explicit logging. func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { return func(c *gin.Context) { - w := &opsCaptureWriter{ResponseWriter: c.Writer, limit: 64 * 1024} + originalWriter := c.Writer + w := acquireOpsCaptureWriter(originalWriter) + defer func() { + // Restore the original writer before returning so outer middlewares + // don't observe a pooled wrapper that has been released. 
+ if c.Writer == w { + c.Writer = originalWriter + } + releaseOpsCaptureWriter(w) + }() c.Writer = w c.Next() diff --git a/backend/internal/handler/ops_error_logger_test.go b/backend/internal/handler/ops_error_logger_test.go index a11fa1f2..731b36ab 100644 --- a/backend/internal/handler/ops_error_logger_test.go +++ b/backend/internal/handler/ops_error_logger_test.go @@ -6,6 +6,7 @@ import ( "sync" "testing" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" @@ -173,3 +174,43 @@ func TestEnqueueOpsErrorLog_EarlyReturnBranches(t *testing.T) { enqueueOpsErrorLog(ops, entry) require.Equal(t, int64(0), OpsErrorLogEnqueuedTotal()) } + +func TestOpsCaptureWriterPool_ResetOnRelease(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodGet, "/test", nil) + + writer := acquireOpsCaptureWriter(c.Writer) + require.NotNil(t, writer) + _, err := writer.buf.WriteString("temp-error-body") + require.NoError(t, err) + + releaseOpsCaptureWriter(writer) + + reused := acquireOpsCaptureWriter(c.Writer) + defer releaseOpsCaptureWriter(reused) + + require.Zero(t, reused.buf.Len(), "writer should be reset before reuse") +} + +func TestOpsErrorLoggerMiddleware_DoesNotBreakOuterMiddlewares(t *testing.T) { + gin.SetMode(gin.TestMode) + + r := gin.New() + r.Use(middleware2.Recovery()) + r.Use(middleware2.RequestLogger()) + r.Use(middleware2.Logger()) + r.GET("/v1/messages", OpsErrorLoggerMiddleware(nil), func(c *gin.Context) { + c.Status(http.StatusNoContent) + }) + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/v1/messages", nil) + + require.NotPanics(t, func() { + r.ServeHTTP(rec, req) + }) + require.Equal(t, http.StatusNoContent, rec.Code) +} diff --git a/backend/internal/handler/setting_handler.go 
b/backend/internal/handler/setting_handler.go index 2029f116..2141a9ee 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -51,6 +51,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, + SoraClientEnabled: settings.SoraClientEnabled, Version: h.version, }) } diff --git a/backend/internal/handler/sora_client_handler.go b/backend/internal/handler/sora_client_handler.go new file mode 100644 index 00000000..80acc833 --- /dev/null +++ b/backend/internal/handler/sora_client_handler.go @@ -0,0 +1,979 @@ +package handler + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +const ( + // 上游模型缓存 TTL + modelCacheTTL = 1 * time.Hour // 上游获取成功 + modelCacheFailedTTL = 2 * time.Minute // 上游获取失败(降级到本地) +) + +// SoraClientHandler 处理 Sora 客户端 API 请求。 +type SoraClientHandler struct { + genService *service.SoraGenerationService + quotaService *service.SoraQuotaService + s3Storage *service.SoraS3Storage + soraGatewayService *service.SoraGatewayService + gatewayService *service.GatewayService + mediaStorage *service.SoraMediaStorage + apiKeyService *service.APIKeyService + + // 上游模型缓存 + modelCacheMu sync.RWMutex + cachedFamilies []service.SoraModelFamily + modelCacheTime time.Time + modelCacheUpstream bool // 是否来自上游(决定 TTL) +} + +// NewSoraClientHandler 创建 Sora 客户端 Handler。 +func NewSoraClientHandler( + genService 
*service.SoraGenerationService, + quotaService *service.SoraQuotaService, + s3Storage *service.SoraS3Storage, + soraGatewayService *service.SoraGatewayService, + gatewayService *service.GatewayService, + mediaStorage *service.SoraMediaStorage, + apiKeyService *service.APIKeyService, +) *SoraClientHandler { + return &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + s3Storage: s3Storage, + soraGatewayService: soraGatewayService, + gatewayService: gatewayService, + mediaStorage: mediaStorage, + apiKeyService: apiKeyService, + } +} + +// GenerateRequest 生成请求。 +type GenerateRequest struct { + Model string `json:"model" binding:"required"` + Prompt string `json:"prompt" binding:"required"` + MediaType string `json:"media_type"` // video / image,默认 video + VideoCount int `json:"video_count,omitempty"` // 视频数量(1-3) + ImageInput string `json:"image_input,omitempty"` // 参考图(base64 或 URL) + APIKeyID *int64 `json:"api_key_id,omitempty"` // 前端传递的 API Key ID +} + +// Generate 异步生成 — 创建 pending 记录后立即返回。 +// POST /api/v1/sora/generate +func (h *SoraClientHandler) Generate(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + var req GenerateRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.Error(c, http.StatusBadRequest, "参数错误: "+err.Error()) + return + } + + if req.MediaType == "" { + req.MediaType = "video" + } + req.VideoCount = normalizeVideoCount(req.MediaType, req.VideoCount) + + // 并发数检查(最多 3 个) + activeCount, err := h.genService.CountActiveByUser(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + if activeCount >= 3 { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + + // 配额检查(粗略检查,实际文件大小在上传后才知道) + if h.quotaService != nil { + if err := h.quotaService.CheckQuota(c.Request.Context(), userID, 0); err != nil { + var quotaErr *service.QuotaExceededError + if errors.As(err, 
&quotaErr) { + response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") + return + } + response.Error(c, http.StatusForbidden, err.Error()) + return + } + } + + // 获取 API Key ID 和 Group ID + var apiKeyID *int64 + var groupID *int64 + + if req.APIKeyID != nil && h.apiKeyService != nil { + // 前端传递了 api_key_id,需要校验 + apiKey, err := h.apiKeyService.GetByID(c.Request.Context(), *req.APIKeyID) + if err != nil { + response.Error(c, http.StatusBadRequest, "API Key 不存在") + return + } + if apiKey.UserID != userID { + response.Error(c, http.StatusForbidden, "API Key 不属于当前用户") + return + } + if apiKey.Status != service.StatusAPIKeyActive { + response.Error(c, http.StatusForbidden, "API Key 不可用") + return + } + apiKeyID = &apiKey.ID + groupID = apiKey.GroupID + } else if id, ok := c.Get("api_key_id"); ok { + // 兼容 API Key 认证路径(/sora/v1/ 网关路由) + if v, ok := id.(int64); ok { + apiKeyID = &v + } + } + + gen, err := h.genService.CreatePending(c.Request.Context(), userID, apiKeyID, req.Model, req.Prompt, req.MediaType) + if err != nil { + if errors.Is(err, service.ErrSoraGenerationConcurrencyLimit) { + response.Error(c, http.StatusTooManyRequests, "同时进行中的任务不能超过 3 个") + return + } + response.ErrorFrom(c, err) + return + } + + // 启动后台异步生成 goroutine + go h.processGeneration(gen.ID, userID, groupID, req.Model, req.Prompt, req.MediaType, req.ImageInput, req.VideoCount) + + response.Success(c, gin.H{ + "generation_id": gen.ID, + "status": gen.Status, + }) +} + +// processGeneration 后台异步执行 Sora 生成任务。 +// 流程:选择账号 → Forward → 提取媒体 URL → 三层降级存储(S3 → 本地 → 上游)→ 更新记录。 +func (h *SoraClientHandler) processGeneration(genID int64, userID int64, groupID *int64, model, prompt, mediaType, imageInput string, videoCount int) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + + // 标记为生成中 + if err := h.genService.MarkGenerating(ctx, genID, ""); err != nil { + if errors.Is(err, service.ErrSoraGenerationStateConflict) { + 
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务状态已变化,跳过生成 id=%d", genID) + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记生成中失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 开始生成 id=%d user=%d group=%d model=%s media_type=%s video_count=%d has_image=%v prompt_len=%d", + genID, + userID, + groupIDForLog(groupID), + model, + mediaType, + videoCount, + strings.TrimSpace(imageInput) != "", + len(strings.TrimSpace(prompt)), + ) + + // 有 groupID 时由分组决定平台,无 groupID 时用 ForcePlatform 兜底 + if groupID == nil { + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + } + + if h.gatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: gatewayService 未初始化") + return + } + + // 选择 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, groupID, "", model) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选择账号失败 id=%d user=%d group=%d model=%s err=%v", + genID, + userID, + groupIDForLog(groupID), + model, + err, + ) + _ = h.genService.MarkFailed(ctx, genID, "选择账号失败: "+err.Error()) + return + } + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 选中账号 id=%d user=%d group=%d model=%s account_id=%d account_name=%s platform=%s type=%s", + genID, + userID, + groupIDForLog(groupID), + model, + account.ID, + account.Name, + account.Platform, + account.Type, + ) + + // 构建 chat completions 请求体(非流式) + body := buildAsyncRequestBody(model, prompt, imageInput, normalizeVideoCount(mediaType, videoCount)) + + if h.soraGatewayService == nil { + _ = h.genService.MarkFailed(ctx, genID, "内部错误: soraGatewayService 未初始化") + return + } + + // 创建 mock gin 上下文用于 Forward(捕获响应以提取媒体 URL) + recorder := httptest.NewRecorder() + mockGinCtx, _ := gin.CreateTestContext(recorder) + mockGinCtx.Request, _ = http.NewRequest("POST", "/", nil) + + // 调用 Forward(非流式) + result, err := h.soraGatewayService.Forward(ctx, mockGinCtx, 
account, body, false) + if err != nil { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] Forward失败 id=%d account_id=%d model=%s status=%d body=%s err=%v", + genID, + account.ID, + model, + recorder.Code, + trimForLog(recorder.Body.String(), 400), + err, + ) + // 检查是否已取消 + gen, _ := h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + return + } + _ = h.genService.MarkFailed(ctx, genID, "生成失败: "+err.Error()) + return + } + + // 提取媒体 URL(优先从 ForwardResult,其次从响应体解析) + mediaURL, mediaURLs := extractMediaURLsFromResult(result, recorder) + if mediaURL == "" { + logger.LegacyPrintf( + "handler.sora_client", + "[SoraClient] 未提取到媒体URL id=%d account_id=%d model=%s status=%d body=%s", + genID, + account.ID, + model, + recorder.Code, + trimForLog(recorder.Body.String(), 400), + ) + _ = h.genService.MarkFailed(ctx, genID, "未获取到媒体 URL") + return + } + + // 检查任务是否已被取消 + gen, _ := h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 任务已取消,跳过存储 id=%d", genID) + return + } + + // 三层降级存储:S3 → 本地 → 上游临时 URL + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation(ctx, userID, mediaType, mediaURL, mediaURLs) + + usageAdded := false + if (storageType == service.SoraStorageTypeS3 || storageType == service.SoraStorageTypeLocal) && fileSize > 0 && h.quotaService != nil { + if err := h.quotaService.AddUsage(ctx, userID, fileSize); err != nil { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + var quotaErr *service.QuotaExceededError + if errors.As(err, &quotaErr) { + _ = h.genService.MarkFailed(ctx, genID, "存储配额已满,请删除不需要的作品释放空间") + return + } + _ = h.genService.MarkFailed(ctx, genID, "存储配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + // 存储完成后再做一次取消检查,防止取消被 completed 覆盖。 + gen, _ = h.genService.GetByID(ctx, genID, userID) + if gen != nil && gen.Status == 
service.SoraGenStatusCancelled { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 存储后检测到任务已取消,回滚存储 id=%d", genID) + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + + // 标记完成 + if err := h.genService.MarkCompleted(ctx, genID, storedURL, storedURLs, storageType, s3Keys, fileSize); err != nil { + if errors.Is(err, service.ErrSoraGenerationStateConflict) { + h.cleanupStoredMedia(ctx, storageType, s3Keys, storedURLs) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(ctx, userID, fileSize) + } + return + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 标记完成失败 id=%d err=%v", genID, err) + return + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成完成 id=%d storage=%s size=%d", genID, storageType, fileSize) +} + +// storeMediaWithDegradation 实现三层降级存储链:S3 → 本地 → 上游。 +func (h *SoraClientHandler) storeMediaWithDegradation( + ctx context.Context, userID int64, mediaType string, + mediaURL string, mediaURLs []string, +) (storedURL string, storedURLs []string, storageType string, s3Keys []string, fileSize int64) { + urls := mediaURLs + if len(urls) == 0 { + urls = []string{mediaURL} + } + + // 第一层:尝试 S3 + if h.s3Storage != nil && h.s3Storage.Enabled(ctx) { + keys := make([]string, 0, len(urls)) + var totalSize int64 + allOK := true + for _, u := range urls { + key, size, err := h.s3Storage.UploadFromURL(ctx, userID, u) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] S3 上传失败 err=%v", err) + allOK = false + // 清理已上传的文件 + if len(keys) > 0 { + _ = h.s3Storage.DeleteObjects(ctx, keys) + } + break + } + keys = append(keys, key) + totalSize += size + } + if allOK && len(keys) > 0 { + accessURLs := make([]string, 0, len(keys)) + for _, key := range keys { + accessURL, err := h.s3Storage.GetAccessURL(ctx, key) + if err != nil { + 
logger.LegacyPrintf("handler.sora_client", "[SoraClient] 生成 S3 访问 URL 失败 err=%v", err) + _ = h.s3Storage.DeleteObjects(ctx, keys) + allOK = false + break + } + accessURLs = append(accessURLs, accessURL) + } + if allOK && len(accessURLs) > 0 { + return accessURLs[0], accessURLs, service.SoraStorageTypeS3, keys, totalSize + } + } + } + + // 第二层:尝试本地存储 + if h.mediaStorage != nil && h.mediaStorage.Enabled() { + storedPaths, err := h.mediaStorage.StoreFromURLs(ctx, mediaType, urls) + if err == nil && len(storedPaths) > 0 { + firstPath := storedPaths[0] + totalSize, sizeErr := h.mediaStorage.TotalSizeByRelativePaths(storedPaths) + if sizeErr != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 统计本地文件大小失败 err=%v", sizeErr) + } + return firstPath, storedPaths, service.SoraStorageTypeLocal, nil, totalSize + } + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 本地存储失败 err=%v", err) + } + + // 第三层:保留上游临时 URL + return urls[0], urls, service.SoraStorageTypeUpstream, nil, 0 +} + +// buildAsyncRequestBody 构建 Sora 异步生成的 chat completions 请求体。 +func buildAsyncRequestBody(model, prompt, imageInput string, videoCount int) []byte { + body := map[string]any{ + "model": model, + "messages": []map[string]string{ + {"role": "user", "content": prompt}, + }, + "stream": false, + } + if imageInput != "" { + body["image_input"] = imageInput + } + if videoCount > 1 { + body["video_count"] = videoCount + } + b, _ := json.Marshal(body) + return b +} + +func normalizeVideoCount(mediaType string, videoCount int) int { + if mediaType != "video" { + return 1 + } + if videoCount <= 0 { + return 1 + } + if videoCount > 3 { + return 3 + } + return videoCount +} + +// extractMediaURLsFromResult 从 Forward 结果和响应体中提取媒体 URL。 +// OAuth 路径:ForwardResult.MediaURL 已填充。 +// APIKey 路径:需从响应体解析 media_url / media_urls 字段。 +func extractMediaURLsFromResult(result *service.ForwardResult, recorder *httptest.ResponseRecorder) (string, []string) { + // 优先从 ForwardResult 获取(OAuth 路径) + if result != nil 
&& result.MediaURL != "" { + // 尝试从响应体获取完整 URL 列表 + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + return result.MediaURL, []string{result.MediaURL} + } + + // 从响应体解析(APIKey 路径) + if urls := parseMediaURLsFromBody(recorder.Body.Bytes()); len(urls) > 0 { + return urls[0], urls + } + + return "", nil +} + +// parseMediaURLsFromBody 从 JSON 响应体中解析 media_url / media_urls 字段。 +func parseMediaURLsFromBody(body []byte) []string { + if len(body) == 0 { + return nil + } + var resp map[string]any + if err := json.Unmarshal(body, &resp); err != nil { + return nil + } + + // 优先 media_urls(多图数组) + if rawURLs, ok := resp["media_urls"]; ok { + if arr, ok := rawURLs.([]any); ok && len(arr) > 0 { + urls := make([]string, 0, len(arr)) + for _, item := range arr { + if s, ok := item.(string); ok && s != "" { + urls = append(urls, s) + } + } + if len(urls) > 0 { + return urls + } + } + } + + // 回退到 media_url(单个 URL) + if url, ok := resp["media_url"].(string); ok && url != "" { + return []string{url} + } + + return nil +} + +// ListGenerations 查询生成记录列表。 +// GET /api/v1/sora/generations +func (h *SoraClientHandler) ListGenerations(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) + pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "20")) + + params := service.SoraGenerationListParams{ + UserID: userID, + Status: c.Query("status"), + StorageType: c.Query("storage_type"), + MediaType: c.Query("media_type"), + Page: page, + PageSize: pageSize, + } + + gens, total, err := h.genService.List(c.Request.Context(), params) + if err != nil { + response.ErrorFrom(c, err) + return + } + + // 为 S3 记录动态生成预签名 URL + for _, gen := range gens { + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + } + + response.Success(c, gin.H{ + "data": gens, + "total": total, + "page": page, + }) +} + +// 
GetGeneration 查询生成记录详情。 +// GET /api/v1/sora/generations/:id +func (h *SoraClientHandler) GetGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + _ = h.genService.ResolveMediaURLs(c.Request.Context(), gen) + response.Success(c, gen) +} + +// DeleteGeneration 删除生成记录。 +// DELETE /api/v1/sora/generations/:id +func (h *SoraClientHandler) DeleteGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + // 先尝试清理本地文件,再删除记录(清理失败不阻塞删除)。 + if gen.StorageType == service.SoraStorageTypeLocal && h.mediaStorage != nil { + paths := gen.MediaURLs + if len(paths) == 0 && gen.MediaURL != "" { + paths = []string{gen.MediaURL} + } + if err := h.mediaStorage.DeleteByRelativePaths(paths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 删除本地文件失败 id=%d err=%v", id, err) + } + } + + if err := h.genService.Delete(c.Request.Context(), id, userID); err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已删除"}) +} + +// GetQuota 查询用户存储配额。 +// GET /api/v1/sora/quota +func (h *SoraClientHandler) GetQuota(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } 
+ + if h.quotaService == nil { + response.Success(c, service.QuotaInfo{QuotaSource: "unlimited", Source: "unlimited"}) + return + } + + quota, err := h.quotaService.GetQuota(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, quota) +} + +// CancelGeneration 取消生成任务。 +// POST /api/v1/sora/generations/:id/cancel +func (h *SoraClientHandler) CancelGeneration(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + // 权限校验 + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + _ = gen + + if err := h.genService.MarkCancelled(c.Request.Context(), id); err != nil { + if errors.Is(err, service.ErrSoraGenerationNotActive) { + response.Error(c, http.StatusConflict, "任务已结束,无法取消") + return + } + response.Error(c, http.StatusBadRequest, err.Error()) + return + } + + response.Success(c, gin.H{"message": "已取消"}) +} + +// SaveToStorage 手动保存 upstream 记录到 S3。 +// POST /api/v1/sora/generations/:id/save +func (h *SoraClientHandler) SaveToStorage(c *gin.Context) { + userID := getUserIDFromContext(c) + if userID == 0 { + response.Error(c, http.StatusUnauthorized, "未登录") + return + } + + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.Error(c, http.StatusBadRequest, "无效的 ID") + return + } + + gen, err := h.genService.GetByID(c.Request.Context(), id, userID) + if err != nil { + response.Error(c, http.StatusNotFound, err.Error()) + return + } + + if gen.StorageType != service.SoraStorageTypeUpstream { + response.Error(c, http.StatusBadRequest, "仅 upstream 类型的记录可手动保存") + return + } + if gen.MediaURL == "" { + response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + 
+ + if h.s3Storage == nil || !h.s3Storage.Enabled(c.Request.Context()) { + response.Error(c, http.StatusServiceUnavailable, "云存储未配置,请联系管理员") + return + } + + sourceURLs := gen.MediaURLs + if len(sourceURLs) == 0 && gen.MediaURL != "" { + sourceURLs = []string{gen.MediaURL} + } + if len(sourceURLs) == 0 { + response.Error(c, http.StatusBadRequest, "媒体 URL 为空,可能已过期") + return + } + + uploadedKeys := make([]string, 0, len(sourceURLs)) + accessURLs := make([]string, 0, len(sourceURLs)) + var totalSize int64 + + for _, sourceURL := range sourceURLs { + objectKey, fileSize, uploadErr := h.s3Storage.UploadFromURL(c.Request.Context(), userID, sourceURL) + if uploadErr != nil { + if len(uploadedKeys) > 0 { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + } + var upstreamErr *service.UpstreamDownloadError + if errors.As(uploadErr, &upstreamErr) && (upstreamErr.StatusCode == http.StatusForbidden || upstreamErr.StatusCode == http.StatusNotFound) { + response.Error(c, http.StatusGone, "媒体链接已过期,无法保存") + return + } + response.Error(c, http.StatusInternalServerError, "上传到 S3 失败: "+uploadErr.Error()) + return + } + accessURL, err := h.s3Storage.GetAccessURL(c.Request.Context(), objectKey) + if err != nil { + uploadedKeys = append(uploadedKeys, objectKey) + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + response.Error(c, http.StatusInternalServerError, "生成 S3 访问链接失败: "+err.Error()) + return + } + uploadedKeys = append(uploadedKeys, objectKey) + accessURLs = append(accessURLs, accessURL) + totalSize += fileSize + } + + usageAdded := false + if totalSize > 0 && h.quotaService != nil { + if err := h.quotaService.AddUsage(c.Request.Context(), userID, totalSize); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + var quotaErr *service.QuotaExceededError + if errors.As(err, &quotaErr) { + response.Error(c, http.StatusTooManyRequests, "存储配额已满,请删除不需要的作品释放空间") + return + } + response.Error(c, http.StatusInternalServerError,
"配额更新失败: "+err.Error()) + return + } + usageAdded = true + } + + if err := h.genService.UpdateStorageForCompleted( + c.Request.Context(), + id, + accessURLs[0], + accessURLs, + service.SoraStorageTypeS3, + uploadedKeys, + totalSize, + ); err != nil { + _ = h.s3Storage.DeleteObjects(c.Request.Context(), uploadedKeys) + if usageAdded && h.quotaService != nil { + _ = h.quotaService.ReleaseUsage(c.Request.Context(), userID, totalSize) + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{ + "message": "已保存到 S3", + "object_key": uploadedKeys[0], + "object_keys": uploadedKeys, + }) +} + +// GetStorageStatus 返回存储状态。 +// GET /api/v1/sora/storage-status +func (h *SoraClientHandler) GetStorageStatus(c *gin.Context) { + s3Enabled := h.s3Storage != nil && h.s3Storage.Enabled(c.Request.Context()) + s3Healthy := false + if s3Enabled { + s3Healthy = h.s3Storage.IsHealthy(c.Request.Context()) + } + localEnabled := h.mediaStorage != nil && h.mediaStorage.Enabled() + response.Success(c, gin.H{ + "s3_enabled": s3Enabled, + "s3_healthy": s3Healthy, + "local_enabled": localEnabled, + }) +} + +func (h *SoraClientHandler) cleanupStoredMedia(ctx context.Context, storageType string, s3Keys []string, localPaths []string) { + switch storageType { + case service.SoraStorageTypeS3: + if h.s3Storage != nil && len(s3Keys) > 0 { + if err := h.s3Storage.DeleteObjects(ctx, s3Keys); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理 S3 文件失败 keys=%v err=%v", s3Keys, err) + } + } + case service.SoraStorageTypeLocal: + if h.mediaStorage != nil && len(localPaths) > 0 { + if err := h.mediaStorage.DeleteByRelativePaths(localPaths); err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 清理本地文件失败 paths=%v err=%v", localPaths, err) + } + } + } +} + +// getUserIDFromContext 从 gin 上下文中提取用户 ID。 +func getUserIDFromContext(c *gin.Context) int64 { + if subject, ok := middleware2.GetAuthSubjectFromContext(c); ok && subject.UserID > 0 { + return 
subject.UserID + } + + if id, ok := c.Get("user_id"); ok { + switch v := id.(type) { + case int64: + return v + case float64: + return int64(v) + case string: + n, _ := strconv.ParseInt(v, 10, 64) + return n + } + } + // 尝试从 JWT claims 获取 + if id, ok := c.Get("userID"); ok { + if v, ok := id.(int64); ok { + return v + } + } + return 0 +} + +func groupIDForLog(groupID *int64) int64 { + if groupID == nil { + return 0 + } + return *groupID +} + +func trimForLog(raw string, maxLen int) string { + trimmed := strings.TrimSpace(raw) + if maxLen <= 0 || len(trimmed) <= maxLen { + return trimmed + } + return trimmed[:maxLen] + "...(truncated)" +} + +// GetModels 获取可用 Sora 模型家族列表。 +// 优先从上游 Sora API 同步模型列表,失败时降级到本地配置。 +// GET /api/v1/sora/models +func (h *SoraClientHandler) GetModels(c *gin.Context) { + families := h.getModelFamilies(c.Request.Context()) + response.Success(c, families) +} + +// getModelFamilies 获取模型家族列表(带缓存)。 +func (h *SoraClientHandler) getModelFamilies(ctx context.Context) []service.SoraModelFamily { + // 读锁检查缓存 + h.modelCacheMu.RLock() + ttl := modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + families := h.cachedFamilies + h.modelCacheMu.RUnlock() + return families + } + h.modelCacheMu.RUnlock() + + // 写锁更新缓存 + h.modelCacheMu.Lock() + defer h.modelCacheMu.Unlock() + + // double-check + ttl = modelCacheTTL + if !h.modelCacheUpstream { + ttl = modelCacheFailedTTL + } + if h.cachedFamilies != nil && time.Since(h.modelCacheTime) < ttl { + return h.cachedFamilies + } + + // 尝试从上游获取 + families, err := h.fetchUpstreamModels(ctx) + if err != nil { + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 上游模型获取失败,使用本地配置: %v", err) + families = service.BuildSoraModelFamilies() + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = false + return families + } + + logger.LegacyPrintf("handler.sora_client", "[SoraClient] 从上游同步到 %d 个模型家族", 
len(families)) + h.cachedFamilies = families + h.modelCacheTime = time.Now() + h.modelCacheUpstream = true + return families +} + +// fetchUpstreamModels 从上游 Sora API 获取模型列表。 +func (h *SoraClientHandler) fetchUpstreamModels(ctx context.Context) ([]service.SoraModelFamily, error) { + if h.gatewayService == nil { + return nil, fmt.Errorf("gatewayService 未初始化") + } + + // 设置 ForcePlatform 用于 Sora 账号选择 + ctx = context.WithValue(ctx, ctxkey.ForcePlatform, service.PlatformSora) + + // 选择一个 Sora 账号 + account, err := h.gatewayService.SelectAccountForModel(ctx, nil, "", "sora2-landscape-10s") + if err != nil { + return nil, fmt.Errorf("选择 Sora 账号失败: %w", err) + } + + // 仅支持 API Key 类型账号 + if account.Type != service.AccountTypeAPIKey { + return nil, fmt.Errorf("当前账号类型 %s 不支持模型同步", account.Type) + } + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return nil, fmt.Errorf("账号缺少 api_key") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return nil, fmt.Errorf("账号缺少 base_url") + } + + // 构建上游模型列表请求 + modelsURL := strings.TrimRight(baseURL, "/") + "/sora/v1/models" + + reqCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, modelsURL, nil) + if err != nil { + return nil, fmt.Errorf("创建请求失败: %w", err) + } + req.Header.Set("Authorization", "Bearer "+apiKey) + + client := &http.Client{Timeout: 10 * time.Second} + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("请求上游失败: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("上游返回状态码 %d", resp.StatusCode) + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1*1024*1024)) + if err != nil { + return nil, fmt.Errorf("读取响应失败: %w", err) + } + + // 解析 OpenAI 格式的模型列表 + var modelsResp struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err := json.Unmarshal(body, &modelsResp); err != nil { + return 
nil, fmt.Errorf("解析响应失败: %w", err) + } + + if len(modelsResp.Data) == 0 { + return nil, fmt.Errorf("上游返回空模型列表") + } + + // 提取模型 ID + modelIDs := make([]string, 0, len(modelsResp.Data)) + for _, m := range modelsResp.Data { + modelIDs = append(modelIDs, m.ID) + } + + // 转换为模型家族 + families := service.BuildSoraModelFamiliesFromIDs(modelIDs) + if len(families) == 0 { + return nil, fmt.Errorf("未能从上游模型列表中识别出有效的模型家族") + } + + return families, nil +} diff --git a/backend/internal/handler/sora_client_handler_test.go b/backend/internal/handler/sora_client_handler_test.go new file mode 100644 index 00000000..a639e504 --- /dev/null +++ b/backend/internal/handler/sora_client_handler_test.go @@ -0,0 +1,3135 @@ +//go:build unit + +package handler + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "os" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func init() { + gin.SetMode(gin.TestMode) +} + +// ==================== Stub: SoraGenerationRepository ==================== + +var _ service.SoraGenerationRepository = (*stubSoraGenRepo)(nil) + +type stubSoraGenRepo struct { + gens map[int64]*service.SoraGeneration + nextID int64 + createErr error + getErr error + updateErr error + deleteErr error + listErr error + countErr error + countValue int64 + + // 条件性 Update 失败:前 updateFailAfterN 次成功,之后失败 + updateCallCount *int32 + updateFailAfterN int32 + + // 条件性 GetByID 状态覆盖:前 getByIDOverrideAfterN 次正常返回,之后返回 overrideStatus + getByIDCallCount int32 + getByIDOverrideAfterN int32 // 0 = 不覆盖 + getByIDOverrideStatus string +} + +func newStubSoraGenRepo() *stubSoraGenRepo { + return &stubSoraGenRepo{gens: make(map[int64]*service.SoraGeneration), nextID: 1} 
+} + +func (r *stubSoraGenRepo) Create(_ context.Context, gen *service.SoraGeneration) error { + if r.createErr != nil { + return r.createErr + } + gen.ID = r.nextID + r.nextID++ + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) GetByID(_ context.Context, id int64) (*service.SoraGeneration, error) { + if r.getErr != nil { + return nil, r.getErr + } + gen, ok := r.gens[id] + if !ok { + return nil, fmt.Errorf("not found") + } + // 条件性状态覆盖:模拟外部取消等场景 + if r.getByIDOverrideAfterN > 0 { + n := atomic.AddInt32(&r.getByIDCallCount, 1) + if n > r.getByIDOverrideAfterN { + cp := *gen + cp.Status = r.getByIDOverrideStatus + return &cp, nil + } + } + return gen, nil +} +func (r *stubSoraGenRepo) Update(_ context.Context, gen *service.SoraGeneration) error { + // 条件性失败:前 N 次成功,之后失败 + if r.updateCallCount != nil { + n := atomic.AddInt32(r.updateCallCount, 1) + if n > r.updateFailAfterN { + return fmt.Errorf("conditional update error (call #%d)", n) + } + } + if r.updateErr != nil { + return r.updateErr + } + r.gens[gen.ID] = gen + return nil +} +func (r *stubSoraGenRepo) Delete(_ context.Context, id int64) error { + if r.deleteErr != nil { + return r.deleteErr + } + delete(r.gens, id) + return nil +} +func (r *stubSoraGenRepo) List(_ context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + if r.listErr != nil { + return nil, 0, r.listErr + } + var result []*service.SoraGeneration + for _, gen := range r.gens { + if gen.UserID != params.UserID { + continue + } + result = append(result, gen) + } + return result, int64(len(result)), nil +} +func (r *stubSoraGenRepo) CountByUserAndStatus(_ context.Context, _ int64, _ []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + return r.countValue, nil +} + +// ==================== 辅助函数 ==================== + +func newTestSoraClientHandler(repo *stubSoraGenRepo) *SoraClientHandler { + genService := service.NewSoraGenerationService(repo, nil, nil) + 
return &SoraClientHandler{genService: genService} +} + +func makeGinContext(method, path, body string, userID int64) (*gin.Context, *httptest.ResponseRecorder) { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + if body != "" { + c.Request = httptest.NewRequest(method, path, strings.NewReader(body)) + c.Request.Header.Set("Content-Type", "application/json") + } else { + c.Request = httptest.NewRequest(method, path, nil) + } + if userID > 0 { + c.Set("user_id", userID) + } + return c, rec +} + +func parseResponse(t *testing.T, rec *httptest.ResponseRecorder) map[string]any { + t.Helper() + var resp map[string]any + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + return resp +} + +// ==================== 纯函数测试: buildAsyncRequestBody ==================== + +func TestBuildAsyncRequestBody(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "sora2-landscape-10s", parsed["model"]) + require.Equal(t, false, parsed["stream"]) + + msgs := parsed["messages"].([]any) + require.Len(t, msgs, 1) + msg := msgs[0].(map[string]any) + require.Equal(t, "user", msg["role"]) + require.Equal(t, "一只猫在跳舞", msg["content"]) +} + +func TestBuildAsyncRequestBody_EmptyPrompt(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "", "", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "gpt-image", parsed["model"]) + msgs := parsed["messages"].([]any) + msg := msgs[0].(map[string]any) + require.Equal(t, "", msg["content"]) +} + +func TestBuildAsyncRequestBody_WithImageInput(t *testing.T) { + body := buildAsyncRequestBody("gpt-image", "一只猫", "https://example.com/ref.png", 1) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, "https://example.com/ref.png", parsed["image_input"]) +} + +func 
TestBuildAsyncRequestBody_WithVideoCount(t *testing.T) { + body := buildAsyncRequestBody("sora2-landscape-10s", "一只猫在跳舞", "", 3) + var parsed map[string]any + require.NoError(t, json.Unmarshal(body, &parsed)) + require.Equal(t, float64(3), parsed["video_count"]) +} + +func TestNormalizeVideoCount(t *testing.T) { + require.Equal(t, 1, normalizeVideoCount("video", 0)) + require.Equal(t, 2, normalizeVideoCount("video", 2)) + require.Equal(t, 3, normalizeVideoCount("video", 5)) + require.Equal(t, 1, normalizeVideoCount("image", 3)) +} + +// ==================== 纯函数测试: parseMediaURLsFromBody ==================== + +func TestParseMediaURLsFromBody_MediaURLs(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_SingleMediaURL(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_url":"https://a.com/video.mp4"}`)) + require.Equal(t, []string{"https://a.com/video.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_EmptyBody(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody(nil)) + require.Nil(t, parseMediaURLsFromBody([]byte{})) +} + +func TestParseMediaURLsFromBody_InvalidJSON(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte("not json"))) +} + +func TestParseMediaURLsFromBody_NoMediaFields(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"data":"something"}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURL(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":""}`))) +} + +func TestParseMediaURLsFromBody_EmptyMediaURLs(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":[]}`))) +} + +func TestParseMediaURLsFromBody_MediaURLsPriority(t *testing.T) { + body := `{"media_url":"https://single.com/1.mp4","media_urls":["https://multi.com/a.mp4","https://multi.com/b.mp4"]}` + urls := 
parseMediaURLsFromBody([]byte(body)) + require.Len(t, urls, 2) + require.Equal(t, "https://multi.com/a.mp4", urls[0]) +} + +func TestParseMediaURLsFromBody_FilterEmpty(t *testing.T) { + urls := parseMediaURLsFromBody([]byte(`{"media_urls":["https://a.com/1.mp4","","https://a.com/2.mp4"]}`)) + require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls) +} + +func TestParseMediaURLsFromBody_AllEmpty(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":["",""]}`))) +} + +func TestParseMediaURLsFromBody_NonStringArray(t *testing.T) { + // media_urls 不是 string 数组 + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_urls":"not-array"}`))) +} + +func TestParseMediaURLsFromBody_MediaURLNotString(t *testing.T) { + require.Nil(t, parseMediaURLsFromBody([]byte(`{"media_url":123}`))) +} + +// ==================== 纯函数测试: extractMediaURLsFromResult ==================== + +func TestExtractMediaURLsFromResult_OAuthPath(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://oauth.com/video.mp4", url) + require.Equal(t, []string{"https://oauth.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_OAuthWithBody(t *testing.T) { + result := &service.ForwardResult{MediaURL: "https://oauth.com/video.mp4"} + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_urls":["https://body.com/1.mp4","https://body.com/2.mp4"]}`)) + url, urls := extractMediaURLsFromResult(result, recorder) + require.Equal(t, "https://body.com/1.mp4", url) + require.Len(t, urls, 2) +} + +func TestExtractMediaURLsFromResult_APIKeyPath(t *testing.T) { + recorder := httptest.NewRecorder() + _, _ = recorder.Write([]byte(`{"media_url":"https://upstream.com/video.mp4"}`)) + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Equal(t, "https://upstream.com/video.mp4", url) 
+ require.Equal(t, []string{"https://upstream.com/video.mp4"}, urls) +} + +func TestExtractMediaURLsFromResult_NilResultEmptyBody(t *testing.T) { + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(nil, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +func TestExtractMediaURLsFromResult_EmptyMediaURL(t *testing.T) { + result := &service.ForwardResult{MediaURL: ""} + recorder := httptest.NewRecorder() + url, urls := extractMediaURLsFromResult(result, recorder) + require.Empty(t, url) + require.Nil(t, urls) +} + +// ==================== getUserIDFromContext ==================== + +func TestGetUserIDFromContext_Int64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", int64(42)) + require.Equal(t, int64(42), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_AuthSubject(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 777}) + require.Equal(t, int64(777), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_Float64(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", float64(99)) + require.Equal(t, int64(99), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_String(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "123") + require.Equal(t, int64(123), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_UserIDFallback(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("userID", int64(55)) + require.Equal(t, int64(55), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_NoID(t *testing.T) { + 
c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +func TestGetUserIDFromContext_InvalidString(t *testing.T) { + c, _ := gin.CreateTestContext(httptest.NewRecorder()) + c.Request = httptest.NewRequest("GET", "/", nil) + c.Set("user_id", "not-a-number") + require.Equal(t, int64(0), getUserIDFromContext(c)) +} + +// ==================== Handler: Generate ==================== + +func TestGenerate_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 0) + h.Generate(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGenerate_BadRequest_MissingModel(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_MissingPrompt(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_BadRequest_InvalidJSON(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{invalid`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGenerate_TooManyRequests(t *testing.T) { + repo := newStubSoraGenRepo() + repo.countValue = 3 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_CountError(t *testing.T) { + repo := 
newStubSoraGenRepo() + repo.countErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_Success(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"测试生成"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + require.Equal(t, "pending", data["status"]) +} + +func TestGenerate_DefaultMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "video", repo.gens[1].MediaType) +} + +func TestGenerate_ImageMediaType(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"gpt-image","prompt":"test","media_type":"image"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "image", repo.gens[1].MediaType) +} + +func TestGenerate_CreatePendingError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.createErr = fmt.Errorf("create failed") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestGenerate_NilQuotaServiceSkipsCheck(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", 
"/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestGenerate_APIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(42)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyInContext(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_ConcurrencyBoundary(t *testing.T) { + // activeCount == 2 应该允许 + repo := newStubSoraGenRepo() + repo.countValue = 2 + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Handler: ListGenerations ==================== + +func TestListGenerations_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 0) + h.ListGenerations(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestListGenerations_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream"} + repo.gens[2] = &service.SoraGeneration{ID: 2, UserID: 1, Model: "gpt-image", Status: "pending", StorageType: "none"} + repo.nextID = 3 + h := newTestSoraClientHandler(repo) + c, rec 
:= makeGinContext("GET", "/api/v1/sora/generations?page=1&page_size=10", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + items := data["data"].([]any) + require.Len(t, items, 2) + require.Equal(t, float64(2), data["total"]) +} + +func TestListGenerations_ListError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.listErr = fmt.Errorf("db error") + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +func TestListGenerations_DefaultPagination(t *testing.T) { + repo := newStubSoraGenRepo() + h := newTestSoraClientHandler(repo) + // 不传分页参数,应默认 page=1 page_size=20 + c, rec := makeGinContext("GET", "/api/v1/sora/generations", "", 1) + h.ListGenerations(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["page"]) +} + +// ==================== Handler: GetGeneration ==================== + +func TestGetGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.GetGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestGetGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + 
h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestGetGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Model: "sora2-landscape-10s", Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("GET", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.GetGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, float64(1), data["id"]) +} + +// ==================== Handler: DeleteGeneration ==================== + +func TestDeleteGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestDeleteGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/abc", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestDeleteGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/999", "", 1) + c.Params = gin.Params{{Key: "id", Value: 
"999"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestDeleteGeneration_Success(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +// ==================== Handler: CancelGeneration ==================== + +func TestCancelGeneration_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestCancelGeneration_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestCancelGeneration_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/999/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func 
TestCancelGeneration_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestCancelGeneration_Pending(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Generating(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "generating"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "cancelled", repo.gens[1].Status) +} + +func TestCancelGeneration_Completed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Failed(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "failed"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + 
c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestCancelGeneration_Cancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} + +// ==================== Handler: GetQuota ==================== + +func TestGetQuota_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 0) + h.GetQuota(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestGetQuota_NilQuotaService(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "unlimited", data["source"]) +} + +// ==================== Handler: GetModels ==================== + +func TestGetModels(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/models", "", 0) + h.GetModels(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].([]any) + require.Len(t, data, 4) + // 验证类型分布 + videoCount, imageCount := 0, 0 + for _, item := range data { + m := item.(map[string]any) + if m["type"] == "video" { + videoCount++ + } else if m["type"] == "image" { + imageCount++ + } + } + require.Equal(t, 3, videoCount) + require.Equal(t, 1, imageCount) +} + +// ==================== Handler: GetStorageStatus ==================== + +func TestGetStorageStatus_NilS3(t *testing.T) { + h := 
newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, false, data["local_enabled"]) +} + +func TestGetStorageStatus_LocalEnabled(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-storage-status-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, false, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) + require.Equal(t, true, data["local_enabled"]) +} + +// ==================== Handler: SaveToStorage ==================== + +func TestSaveToStorage_Unauthorized(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 0) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestSaveToStorage_InvalidID(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/abc/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "abc"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_NotFound(t *testing.T) { + h := newTestSoraClientHandler(newStubSoraGenRepo()) + c, rec := 
makeGinContext("POST", "/api/v1/sora/generations/999/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "999"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestSaveToStorage_NotUpstream(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "s3", MediaURL: "https://example.com/v.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_EmptyMediaURL(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: ""} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestSaveToStorage_S3Nil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusServiceUnavailable, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "云存储") +} + +func TestSaveToStorage_WrongUser(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 2, Status: "completed", StorageType: "upstream", MediaURL: "https://example.com/video.mp4"} + h := newTestSoraClientHandler(repo) + c, rec := makeGinContext("POST", 
"/api/v1/sora/generations/1/save", "", 1)
	c.Params = gin.Params{{Key: "id", Value: "1"}}
	h.SaveToStorage(c)
	require.Equal(t, http.StatusNotFound, rec.Code)
}

// ==================== storeMediaWithDegradation — nil-guard paths ====================

// With neither S3 nor local media storage wired up, storeMediaWithDegradation
// must degrade to the upstream URL untouched: storage type "upstream", no
// object keys, and zero stored bytes.
func TestStoreMediaWithDegradation_NilS3NilMedia(t *testing.T) {
	h := &SoraClientHandler{}
	url, urls, storageType, keys, size := h.storeMediaWithDegradation(
		context.Background(), 1, "video", "https://upstream.com/v.mp4", nil,
	)
	require.Equal(t, service.SoraStorageTypeUpstream, storageType)
	require.Equal(t, "https://upstream.com/v.mp4", url)
	require.Equal(t, []string{"https://upstream.com/v.mp4"}, urls)
	require.Nil(t, keys)
	require.Equal(t, int64(0), size)
}

// Multiple upstream URLs: the first element becomes the primary URL and the
// full slice is passed through unchanged; the single mediaURL argument is
// ignored in favor of the slice.
func TestStoreMediaWithDegradation_NilGuardsMultiURL(t *testing.T) {
	h := &SoraClientHandler{}
	url, urls, storageType, keys, size := h.storeMediaWithDegradation(
		context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{"https://a.com/1.mp4", "https://a.com/2.mp4"},
	)
	require.Equal(t, service.SoraStorageTypeUpstream, storageType)
	require.Equal(t, "https://a.com/1.mp4", url)
	require.Equal(t, []string{"https://a.com/1.mp4", "https://a.com/2.mp4"}, urls)
	require.Nil(t, keys)
	require.Equal(t, int64(0), size)
}

// An empty (but non-nil) mediaURLs slice falls back to the single mediaURL.
func TestStoreMediaWithDegradation_EmptyMediaURLsFallback(t *testing.T) {
	h := &SoraClientHandler{}
	url, _, storageType, _, _ := h.storeMediaWithDegradation(
		context.Background(), 1, "video", "https://upstream.com/v.mp4", []string{},
	)
	require.Equal(t, service.SoraStorageTypeUpstream, storageType)
	require.Equal(t, "https://upstream.com/v.mp4", url)
}

// ==================== Stub: UserRepository (used by SoraQuotaService) ====================

var _ service.UserRepository = (*stubUserRepoForHandler)(nil)

// stubUserRepoForHandler is an in-memory service.UserRepository. Only GetByID
// and Update are functional; every other method is a no-op present solely to
// satisfy the interface.
type stubUserRepoForHandler struct {
	users     map[int64]*service.User // backing store, keyed by user ID
	updateErr error                   // when set, Update returns this error
}

func newStubUserRepoForHandler() *stubUserRepoForHandler {
	return &stubUserRepoForHandler{users: make(map[int64]*service.User)}
}

func (r *stubUserRepoForHandler) GetByID(_ context.Context, id int64) (*service.User, error) {
	if u, ok := r.users[id]; ok {
		return u, nil
	}
	return nil, fmt.Errorf("user not found")
}
func (r *stubUserRepoForHandler) Update(_ context.Context, user *service.User) error {
	if r.updateErr != nil {
		return r.updateErr
	}
	r.users[user.ID] = user
	return nil
}
func (r *stubUserRepoForHandler) Create(context.Context, *service.User) error { return nil }
func (r *stubUserRepoForHandler) GetByEmail(context.Context, string) (*service.User, error) {
	return nil, nil
}
func (r *stubUserRepoForHandler) GetFirstAdmin(context.Context) (*service.User, error) {
	return nil, nil
}
func (r *stubUserRepoForHandler) Delete(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.User, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubUserRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, service.UserListFilters) ([]service.User, *pagination.PaginationResult, error) {
	return nil, nil, nil
}
func (r *stubUserRepoForHandler) UpdateBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) DeductBalance(context.Context, int64, float64) error { return nil }
func (r *stubUserRepoForHandler) UpdateConcurrency(context.Context, int64, int) error { return nil }
func (r *stubUserRepoForHandler) ExistsByEmail(context.Context, string) (bool, error) {
	return false, nil
}
func (r *stubUserRepoForHandler) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
	return 0, nil
}
func (r *stubUserRepoForHandler) UpdateTotpSecret(context.Context, int64, *string) error { return nil }
func (r *stubUserRepoForHandler) EnableTotp(context.Context, int64) error { return nil }
func (r *stubUserRepoForHandler)
DisableTotp(context.Context, int64) error { return nil } + +// ==================== NewSoraClientHandler ==================== + +func TestNewSoraClientHandler(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) +} + +func TestNewSoraClientHandler_WithAPIKeyService(t *testing.T) { + h := NewSoraClientHandler(nil, nil, nil, nil, nil, nil, nil) + require.NotNil(t, h) + require.Nil(t, h.apiKeyService) +} + +// ==================== Stub: APIKeyRepository (用于 API Key 校验测试) ==================== + +var _ service.APIKeyRepository = (*stubAPIKeyRepoForHandler)(nil) + +type stubAPIKeyRepoForHandler struct { + keys map[int64]*service.APIKey + getErr error +} + +func newStubAPIKeyRepoForHandler() *stubAPIKeyRepoForHandler { + return &stubAPIKeyRepoForHandler{keys: make(map[int64]*service.APIKey)} +} + +func (r *stubAPIKeyRepoForHandler) GetByID(_ context.Context, id int64) (*service.APIKey, error) { + if r.getErr != nil { + return nil, r.getErr + } + if k, ok := r.keys[id]; ok { + return k, nil + } + return nil, fmt.Errorf("api key not found: %d", id) +} +func (r *stubAPIKeyRepoForHandler) Create(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) GetKeyAndOwnerID(_ context.Context, _ int64) (string, int64, error) { + return "", 0, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKey(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) GetByKeyForAuth(context.Context, string) (*service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) Update(context.Context, *service.APIKey) error { return nil } +func (r *stubAPIKeyRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubAPIKeyRepoForHandler) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) 
VerifyOwnership(context.Context, int64, []int64) ([]int64, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) CountByUserID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ExistsByKey(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubAPIKeyRepoForHandler) ListByGroupID(_ context.Context, _ int64, _ pagination.PaginationParams) ([]service.APIKey, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAPIKeyRepoForHandler) SearchAPIKeys(context.Context, int64, string, int) ([]service.APIKey, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ClearGroupIDByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) CountByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByUserID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) ListKeysByGroupID(context.Context, int64) ([]string, error) { + return nil, nil +} +func (r *stubAPIKeyRepoForHandler) IncrementQuotaUsed(_ context.Context, _ int64, _ float64) (float64, error) { + return 0, nil +} +func (r *stubAPIKeyRepoForHandler) UpdateLastUsed(context.Context, int64, time.Time) error { + return nil +} + +// newTestAPIKeyService 创建测试用的 APIKeyService +func newTestAPIKeyService(repo *stubAPIKeyRepoForHandler) *service.APIKeyService { + return service.NewAPIKeyService(repo, nil, nil, nil, nil, nil, &config.Config{}) +} + +// ==================== Generate: API Key 校验(前端传递 api_key_id)==================== + +func TestGenerate_WithAPIKeyID_Success(t *testing.T) { + // 前端传递 api_key_id,校验通过 → 成功生成,记录关联 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(5) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: 
service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.NotZero(t, data["generation_id"]) + + // 验证 api_key_id 已关联到生成记录 + gen := repo.gens[1] + require.NotNil(t, gen.APIKeyID) + require.Equal(t, int64(42), *gen.APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NotFound(t *testing.T) { + // 前端传递不存在的 api_key_id → 400 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":999}`, 1) + h.Generate(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不存在") +} + +func TestGenerate_WithAPIKeyID_WrongUser(t *testing.T) { + // 前端传递别人的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 999, // 属于 user 999 + Status: service.StatusAPIKeyActive, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := 
parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不属于") +} + +func TestGenerate_WithAPIKeyID_Disabled(t *testing.T) { + // 前端传递已禁用的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyDisabled, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "不可用") +} + +func TestGenerate_WithAPIKeyID_QuotaExhausted(t *testing.T) { + // 前端传递配额耗尽的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyQuotaExhausted, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_Expired(t *testing.T) { + // 前端传递已过期的 api_key_id → 403 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyExpired, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, 
apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +func TestGenerate_WithAPIKeyID_NilAPIKeyService(t *testing.T) { + // apiKeyService 为 nil 时忽略 api_key_id → 正常生成但不记录 api_key_id + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + h := &SoraClientHandler{genService: genService} // apiKeyService = nil + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // apiKeyService 为 nil → 跳过校验 → api_key_id 不记录 + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyID_NilGroupID(t *testing.T) { + // api_key 有效但 GroupID 为 nil → 成功,groupID 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: nil, // 无分组 + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_NoAPIKeyID_NoContext_NilResult(t *testing.T) { + // 既无 api_key_id 字段也无 context 中的 api_key_id → api_key_id 为 nil + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: 
genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + require.Nil(t, repo.gens[1].APIKeyID) +} + +func TestGenerate_WithAPIKeyIDInBody_OverridesContext(t *testing.T) { + // 同时有 body api_key_id 和 context api_key_id → 优先使用 body 的 + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + + groupID := int64(10) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyRepo.keys[42] = &service.APIKey{ + ID: 42, + UserID: 1, + Status: service.StatusAPIKeyActive, + GroupID: &groupID, + } + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test","api_key_id":42}`, 1) + c.Set("api_key_id", int64(99)) // context 中有另一个 api_key_id + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 body 中的 api_key_id=42,而不是 context 中的 99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(42), *repo.gens[1].APIKeyID) +} + +func TestGenerate_WithContextAPIKeyID_FallbackPath(t *testing.T) { + // 无 body api_key_id,但 context 有 → 使用 context 中的(兼容网关路由) + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + apiKeyRepo := newStubAPIKeyRepoForHandler() + apiKeyService := newTestAPIKeyService(apiKeyRepo) + + h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService} + c, rec := makeGinContext("POST", "/api/v1/sora/generate", + `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + c.Set("api_key_id", int64(99)) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) + // 应使用 context 中的 api_key_id=99 + require.NotNil(t, repo.gens[1].APIKeyID) + require.Equal(t, int64(99), *repo.gens[1].APIKeyID) +} + +func 
TestGenerate_APIKeyID_Zero_IgnoredInJSON(t *testing.T) {
	// api_key_id:0 in the JSON body still decodes to a non-nil *int64 holding 0
	// (pointer fields defeat omitempty on input), so it is NOT treated as
	// omitted and goes through key validation.
	repo := newStubSoraGenRepo()
	genService := service.NewSoraGenerationService(repo, nil, nil)
	apiKeyRepo := newStubAPIKeyRepoForHandler()
	apiKeyService := newTestAPIKeyService(apiKeyRepo)

	h := &SoraClientHandler{genService: genService, apiKeyService: apiKeyService}
	// Key ID 0 does not exist in the stub → validation fails → 400.
	c, rec := makeGinContext("POST", "/api/v1/sora/generate",
		`{"model":"sora2-landscape-10s","prompt":"test","api_key_id":0}`, 1)
	h.Generate(c)
	require.Equal(t, http.StatusBadRequest, rec.Code)
}

// ==================== processGeneration: groupID propagation and ForcePlatform ====================

// Non-nil groupID → ForcePlatform is not set. The gatewayService is nil here,
// so processing ends in MarkFailed; the stored error message should mention
// gatewayService.
func TestProcessGeneration_WithGroupID_NoForcePlatform(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}

	gid := int64(5)
	h.processGeneration(1, 1, &gid, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}

// Nil groupID → ForcePlatform is set; with a nil gatewayService the task
// still ends in MarkFailed.
func TestProcessGeneration_NilGroupID_SetsForcePlatform(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "failed", repo.gens[1].Status)
	require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService")
}

// If the task was already cancelled before processing starts, MarkGenerating
// returns a state-conflict error and processing is skipped: the record must
// remain "cancelled" (not flipped to generating/failed).
func TestProcessGeneration_MarkGeneratingStateConflict(t *testing.T) {
	repo := newStubSoraGenRepo()
	repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "cancelled"}
	genService := service.NewSoraGenerationService(repo, nil, nil)
	h := &SoraClientHandler{genService: genService}

	h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1)
	require.Equal(t, "cancelled", repo.gens[1].Status)
}

// ==================== GenerateRequest JSON parsing ====================

// api_key_id decodes into a non-nil *int64 carrying the value.
func TestGenerateRequest_WithAPIKeyID_JSONParsing(t *testing.T) {
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":42}`), &req)
	require.NoError(t, err)
	require.NotNil(t, req.APIKeyID)
	require.Equal(t, int64(42), *req.APIKeyID)
}

// An omitted api_key_id decodes to nil.
func TestGenerateRequest_WithoutAPIKeyID_JSONParsing(t *testing.T) {
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test"}`), &req)
	require.NoError(t, err)
	require.Nil(t, req.APIKeyID)
}

// An explicit JSON null likewise decodes to nil.
func TestGenerateRequest_NullAPIKeyID_JSONParsing(t *testing.T) {
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{"model":"sora2","prompt":"test","api_key_id":null}`), &req)
	require.NoError(t, err)
	require.Nil(t, req.APIKeyID)
}

// All request fields decode into the expected struct members.
func TestGenerateRequest_FullFields_JSONParsing(t *testing.T) {
	var req GenerateRequest
	err := json.Unmarshal([]byte(`{
		"model":"sora2-landscape-10s",
		"prompt":"test prompt",
		"media_type":"video",
		"video_count":2,
		"image_input":"data:image/png;base64,abc",
		"api_key_id":100
	}`), &req)
	require.NoError(t, err)
	require.Equal(t, "sora2-landscape-10s", req.Model)
	require.Equal(t, "test prompt", req.Prompt)
	require.Equal(t, "video",
req.MediaType) + require.Equal(t, 2, req.VideoCount) + require.Equal(t, "data:image/png;base64,abc", req.ImageInput) + require.NotNil(t, req.APIKeyID) + require.Equal(t, int64(100), *req.APIKeyID) +} + +func TestGenerateRequest_JSONSerialize_OmitsNilAPIKeyID(t *testing.T) { + // api_key_id 为 nil 时 JSON 序列化应省略 + req := GenerateRequest{Model: "sora2", Prompt: "test"} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + _, hasAPIKeyID := parsed["api_key_id"] + require.False(t, hasAPIKeyID, "api_key_id 为 nil 时应省略") +} + +func TestGenerateRequest_JSONSerialize_IncludesAPIKeyID(t *testing.T) { + // api_key_id 不为 nil 时 JSON 序列化应包含 + id := int64(42) + req := GenerateRequest{Model: "sora2", Prompt: "test", APIKeyID: &id} + b, err := json.Marshal(req) + require.NoError(t, err) + var parsed map[string]any + require.NoError(t, json.Unmarshal(b, &parsed)) + require.Equal(t, float64(42), parsed["api_key_id"]) +} + +// ==================== GetQuota: 有配额服务 ==================== + +func TestGetQuota_WithQuotaService_Success(t *testing.T) { + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 3 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 1) + h.GetQuota(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, "user", data["source"]) + require.Equal(t, float64(10*1024*1024), data["quota_bytes"]) + require.Equal(t, float64(3*1024*1024), data["used_bytes"]) +} + +func TestGetQuota_WithQuotaService_Error(t *testing.T) { + // 用户不存在时 GetQuota 返回错误 
+ userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("GET", "/api/v1/sora/quota", "", 999) + h.GetQuota(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== Generate: 配额检查 ==================== + +func TestGenerate_QuotaCheckFailed(t *testing.T) { + // 配额超限时返回 429 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 1024, + SoraStorageUsedBytes: 1025, // 已超限 + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +func TestGenerate_QuotaCheckPassed(t *testing.T) { + // 配额充足时允许生成 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + repo := newStubSoraGenRepo() + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{ + genService: genService, + quotaService: quotaService, + } + + c, rec := makeGinContext("POST", "/api/v1/sora/generate", `{"model":"sora2-landscape-10s","prompt":"test"}`, 1) + h.Generate(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== Stub: SettingRepository (用于 S3 存储测试) ==================== + +var _ service.SettingRepository = (*stubSettingRepoForHandler)(nil) + 
+type stubSettingRepoForHandler struct { + values map[string]string +} + +func newStubSettingRepoForHandler(values map[string]string) *stubSettingRepoForHandler { + if values == nil { + values = make(map[string]string) + } + return &stubSettingRepoForHandler{values: values} +} + +func (r *stubSettingRepoForHandler) Get(_ context.Context, key string) (*service.Setting, error) { + if v, ok := r.values[key]; ok { + return &service.Setting{Key: key, Value: v}, nil + } + return nil, service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) GetValue(_ context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", service.ErrSettingNotFound +} +func (r *stubSettingRepoForHandler) Set(_ context.Context, key, value string) error { + r.values[key] = value + return nil +} +func (r *stubSettingRepoForHandler) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + for _, k := range keys { + if v, ok := r.values[k]; ok { + result[k] = v + } + } + return result, nil +} +func (r *stubSettingRepoForHandler) SetMultiple(_ context.Context, settings map[string]string) error { + for k, v := range settings { + r.values[k] = v + } + return nil +} +func (r *stubSettingRepoForHandler) GetAll(_ context.Context) (map[string]string, error) { + return r.values, nil +} +func (r *stubSettingRepoForHandler) Delete(_ context.Context, key string) error { + delete(r.values, key) + return nil +} + +// ==================== S3 / MediaStorage 辅助函数 ==================== + +// newS3StorageForHandler 创建指向指定 endpoint 的 S3Storage(用于测试)。 +func newS3StorageForHandler(endpoint string) *service.SoraS3Storage { + settingRepo := newStubSettingRepoForHandler(map[string]string{ + "sora_s3_enabled": "true", + "sora_s3_endpoint": endpoint, + "sora_s3_region": "us-east-1", + "sora_s3_bucket": "test-bucket", + "sora_s3_access_key_id": "AKIATEST", + "sora_s3_secret_access_key": "test-secret", + 
"sora_s3_prefix": "sora", + "sora_s3_force_path_style": "true", + }) + settingService := service.NewSettingService(settingRepo, &config.Config{}) + return service.NewSoraS3Storage(settingService) +} + +// newFakeSourceServer 创建返回固定内容的 HTTP 服务器(模拟上游媒体文件)。 +func newFakeSourceServer() *httptest.Server { + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "video/mp4") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("fake video data for test")) + })) +} + +// newFakeS3Server 创建模拟 S3 的 HTTP 服务器。 +// mode: "ok" 接受所有请求,"fail" 返回 403,"fail-second" 第一次成功第二次失败。 +func newFakeS3Server(mode string) *httptest.Server { + var counter atomic.Int32 + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = r.Body.Close() + + switch mode { + case "ok": + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + case "fail": + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + case "fail-second": + n := counter.Add(1) + if n <= 1 { + w.Header().Set("ETag", `"test-etag"`) + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`AccessDenied`)) + } + } + })) +} + +// ==================== processGeneration 直接调用测试 ==================== + +func TestProcessGeneration_MarkGeneratingFails(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + repo.updateErr = fmt.Errorf("db error") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + // 直接调用(非 goroutine),MarkGenerating 失败 → 早退 + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // MarkGenerating 在调用 repo.Update 前已修改内存对象为 "generating" + // repo.Update 返回错误 → processGeneration 早退,不会继续到 MarkFailed + // 因此 ErrorMessage 为空(证明未调用 MarkFailed) + 
require.Equal(t, "generating", repo.gens[1].Status) + require.Empty(t, repo.gens[1].ErrorMessage) +} + +func TestProcessGeneration_GatewayServiceNil(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + // gatewayService 未设置 → MarkFailed + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "gatewayService") +} + +// ==================== storeMediaWithDegradation: S3 路径 ==================== + +func TestStoreMediaWithDegradation_S3SuccessSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 1) + require.NotEmpty(t, s3Keys[0]) + require.Len(t, storedURLs, 1) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURL, fakeS3.URL) + require.Contains(t, storedURL, "/test-bucket/") + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3SuccessMultiURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + storedURL, storedURLs, storageType, s3Keys, fileSize := h.storeMediaWithDegradation( + context.Background(), 1, 
"video", sourceServer.URL+"/a.mp4", urls, + ) + require.Equal(t, service.SoraStorageTypeS3, storageType) + require.Len(t, s3Keys, 2) + require.Len(t, storedURLs, 2) + require.Equal(t, storedURL, storedURLs[0]) + require.Contains(t, storedURLs[0], fakeS3.URL) + require.Contains(t, storedURLs[1], fakeS3.URL) + require.Greater(t, fileSize, int64(0)) +} + +func TestStoreMediaWithDegradation_S3DownloadFails(t *testing.T) { + // 上游返回 404 → 下载失败 → S3 上传不会开始 + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + badSource := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusNotFound) + })) + defer badSource.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", badSource.URL+"/missing.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_S3FailsSingleURL(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +func TestStoreMediaWithDegradation_S3PartialFailureCleanup(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + urls := []string{sourceServer.URL + "/a.mp4", sourceServer.URL + "/b.mp4"} + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + 
context.Background(), 1, "video", sourceServer.URL+"/a.mp4", urls, + ) + // 第二个 URL 上传失败 → 清理已上传 → 降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) + require.Nil(t, s3Keys) +} + +// ==================== storeMediaWithDegradation: 本地存储路径 ==================== + +func TestStoreMediaWithDegradation_LocalStorageFails(t *testing.T) { + // 使用无效路径,EnsureLocalDirs 失败 → StoreFromURLs 返回 error + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: "/dev/null/invalid_dir", + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", "https://upstream.com/v.mp4", nil, + ) + // 本地存储失败,降级到 upstream + require.Equal(t, service.SoraStorageTypeUpstream, storageType) +} + +func TestStoreMediaWithDegradation_LocalStorageSuccess(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + DownloadTimeoutSeconds: 5, + MaxDownloadBytes: 10 * 1024 * 1024, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + _, _, storageType, s3Keys, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + require.Equal(t, service.SoraStorageTypeLocal, storageType) + require.Nil(t, s3Keys) // 本地存储不返回 S3 keys +} + +func TestStoreMediaWithDegradation_S3FailsFallbackToLocal(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-handler-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := 
newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + DownloadTimeoutSeconds: 5, + MaxDownloadBytes: 10 * 1024 * 1024, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{ + s3Storage: s3Storage, + mediaStorage: mediaStorage, + } + + _, _, storageType, _, _ := h.storeMediaWithDegradation( + context.Background(), 1, "video", sourceServer.URL+"/v.mp4", nil, + ) + // S3 失败 → 本地存储成功 + require.Equal(t, service.SoraStorageTypeLocal, storageType) +} + +// ==================== SaveToStorage: S3 路径 ==================== + +func TestSaveToStorage_S3EnabledButUploadFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "S3") +} + +func TestSaveToStorage_UpstreamURLExpired(t *testing.T) { + expiredServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusForbidden) + })) + defer expiredServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + 
StorageType: "upstream", + MediaURL: expiredServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusGone, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, fmt.Sprint(resp["message"]), "过期") +} + +func TestSaveToStorage_S3EnabledUploadSuccess(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Contains(t, data["message"], "S3") + require.NotEmpty(t, data["object_key"]) + // 验证记录已更新为 S3 存储 + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) +} + +func TestSaveToStorage_S3EnabledUploadSuccess_MultiMediaURLs(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v1.mp4", + MediaURLs: []string{ + sourceServer.URL + 
"/v1.mp4", + sourceServer.URL + "/v2.mp4", + }, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Len(t, data["object_keys"].([]any), 2) + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) + require.Len(t, repo.gens[1].S3ObjectKeys, 2) + require.Len(t, repo.gens[1].MediaURLs, 2) +} + +func TestSaveToStorage_S3EnabledUploadSuccessWithQuota(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusOK, rec.Code) + // 验证配额已累加 + require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) +} + +func TestSaveToStorage_S3UploadSuccessMarkCompletedFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := 
newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + // S3 上传成功后,MarkCompleted 会调用 repo.Update → 失败 + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== GetStorageStatus: S3 路径 ==================== + +func TestGetStorageStatus_S3EnabledNotHealthy(t *testing.T) { + // S3 启用但 TestConnection 失败(fake 端点不响应 HeadBucket) + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, true, data["s3_enabled"]) + require.Equal(t, false, data["s3_healthy"]) +} + +func TestGetStorageStatus_S3EnabledHealthy(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + c, rec := makeGinContext("GET", "/api/v1/sora/storage-status", "", 0) + h.GetStorageStatus(c) + require.Equal(t, http.StatusOK, rec.Code) + resp := parseResponse(t, rec) + data := resp["data"].(map[string]any) + require.Equal(t, true, data["s3_enabled"]) + require.Equal(t, true, data["s3_healthy"]) +} + +// ==================== Stub: AccountRepository (用于 GatewayService) ==================== 
+ +var _ service.AccountRepository = (*stubAccountRepoForHandler)(nil) + +type stubAccountRepoForHandler struct { + accounts []service.Account +} + +func (r *stubAccountRepoForHandler) Create(context.Context, *service.Account) error { return nil } +func (r *stubAccountRepoForHandler) GetByID(_ context.Context, id int64) (*service.Account, error) { + for i := range r.accounts { + if r.accounts[i].ID == id { + return &r.accounts[i], nil + } + } + return nil, fmt.Errorf("account not found") +} +func (r *stubAccountRepoForHandler) GetByIDs(context.Context, []int64) ([]*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ExistsByID(context.Context, int64) (bool, error) { + return false, nil +} +func (r *stubAccountRepoForHandler) GetByCRSAccountID(context.Context, string) (*service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) FindByExtraField(context.Context, string, any) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListCRSAccountIDs(context.Context) (map[string]int64, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) Update(context.Context, *service.Account) error { return nil } +func (r *stubAccountRepoForHandler) Delete(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) List(context.Context, pagination.PaginationParams) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepoForHandler) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, string, int64) ([]service.Account, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubAccountRepoForHandler) ListByGroup(context.Context, int64) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) ListActive(context.Context) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) 
ListByPlatform(context.Context, string) ([]service.Account, error) { + return nil, nil +} +func (r *stubAccountRepoForHandler) UpdateLastUsed(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) BatchUpdateLastUsed(context.Context, map[int64]time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetError(context.Context, int64, string) error { return nil } +func (r *stubAccountRepoForHandler) ClearError(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) SetSchedulable(context.Context, int64, bool) error { + return nil +} +func (r *stubAccountRepoForHandler) AutoPauseExpiredAccounts(context.Context, time.Time) (int64, error) { + return 0, nil +} +func (r *stubAccountRepoForHandler) BindGroups(context.Context, int64, []int64) error { return nil } +func (r *stubAccountRepoForHandler) ListSchedulable(context.Context) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupID(context.Context, int64) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByPlatform(_ context.Context, _ string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatform(context.Context, int64, string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByPlatforms(context.Context, []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) ListSchedulableByGroupIDAndPlatforms(context.Context, int64, []string) ([]service.Account, error) { + return r.accounts, nil +} +func (r *stubAccountRepoForHandler) SetRateLimited(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetModelRateLimit(context.Context, int64, string, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) 
SetOverloaded(context.Context, int64, time.Time) error { + return nil +} +func (r *stubAccountRepoForHandler) SetTempUnschedulable(context.Context, int64, time.Time, string) error { + return nil +} +func (r *stubAccountRepoForHandler) ClearTempUnschedulable(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) ClearRateLimit(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) ClearAntigravityQuotaScopes(context.Context, int64) error { + return nil +} +func (r *stubAccountRepoForHandler) ClearModelRateLimits(context.Context, int64) error { return nil } +func (r *stubAccountRepoForHandler) UpdateSessionWindow(context.Context, int64, *time.Time, *time.Time, string) error { + return nil +} +func (r *stubAccountRepoForHandler) UpdateExtra(context.Context, int64, map[string]any) error { + return nil +} +func (r *stubAccountRepoForHandler) BulkUpdate(context.Context, []int64, service.AccountBulkUpdate) (int64, error) { + return 0, nil +} + +// ==================== Stub: SoraClient (用于 SoraGatewayService) ==================== + +var _ service.SoraClient = (*stubSoraClientForHandler)(nil) + +type stubSoraClientForHandler struct { + videoStatus *service.SoraVideoTaskStatus +} + +func (s *stubSoraClientForHandler) Enabled() bool { return true } +func (s *stubSoraClientForHandler) UploadImage(context.Context, *service.Account, []byte, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) CreateImageTask(context.Context, *service.Account, service.SoraImageRequest) (string, error) { + return "task-image", nil +} +func (s *stubSoraClientForHandler) CreateVideoTask(context.Context, *service.Account, service.SoraVideoRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForHandler) CreateStoryboardTask(context.Context, *service.Account, service.SoraStoryboardRequest) (string, error) { + return "task-video", nil +} +func (s *stubSoraClientForHandler) 
UploadCharacterVideo(context.Context, *service.Account, []byte) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) GetCameoStatus(context.Context, *service.Account, string) (*service.SoraCameoStatus, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) DownloadCharacterImage(context.Context, *service.Account, string) ([]byte, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) UploadCharacterImage(context.Context, *service.Account, []byte) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) FinalizeCharacter(context.Context, *service.Account, service.SoraCharacterFinalizeRequest) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) SetCharacterPublic(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) DeleteCharacter(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) PostVideoForWatermarkFree(context.Context, *service.Account, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) DeletePost(context.Context, *service.Account, string) error { + return nil +} +func (s *stubSoraClientForHandler) GetWatermarkFreeURLCustom(context.Context, *service.Account, string, string, string) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) EnhancePrompt(context.Context, *service.Account, string, string, int) (string, error) { + return "", nil +} +func (s *stubSoraClientForHandler) GetImageTask(context.Context, *service.Account, string) (*service.SoraImageTaskStatus, error) { + return nil, nil +} +func (s *stubSoraClientForHandler) GetVideoTask(_ context.Context, _ *service.Account, _ string) (*service.SoraVideoTaskStatus, error) { + return s.videoStatus, nil +} + +// ==================== 辅助:创建最小 GatewayService 和 SoraGatewayService ==================== + +// newMinimalGatewayService 创建仅包含 accountRepo 的最小 GatewayService(用于测试 
SelectAccountForModel)。 +func newMinimalGatewayService(accountRepo service.AccountRepository) *service.GatewayService { + return service.NewGatewayService( + accountRepo, nil, nil, nil, nil, nil, nil, nil, + nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, + ) +} + +// newMinimalSoraGatewayService 创建最小 SoraGatewayService(用于测试 Forward)。 +func newMinimalSoraGatewayService(soraClient service.SoraClient) *service.SoraGatewayService { + cfg := &config.Config{ + Sora: config.SoraConfig{ + Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + return service.NewSoraGatewayService(soraClient, nil, nil, cfg) +} + +// ==================== processGeneration: 更多路径测试 ==================== + +func TestProcessGeneration_SelectAccountError(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + // accountRepo 返回空列表 → SelectAccountForModel 返回 "no available accounts" + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +func TestProcessGeneration_SoraGatewayServiceNil(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + // 提供可用账号使 SelectAccountForModel 成功 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := 
newMinimalGatewayService(accountRepo) + // soraGatewayService 为 nil + h := &SoraClientHandler{genService: genService, gatewayService: gatewayService} + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "soraGatewayService") +} + +func TestProcessGeneration_ForwardError(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // SoraClient 返回视频任务失败 + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "failed", + ErrorMsg: "content policy violation", + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "生成失败") +} + +func TestProcessGeneration_ForwardErrorCancelled(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // MarkGenerating 内部调用 GetByID(第 1 次),Forward 失败后 processGeneration + // 调用 GetByID(第 2 次)。模拟外部在 Forward 期间取消了任务。 + repo.getByIDOverrideAfterN = 1 + repo.getByIDOverrideStatus = "cancelled" + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: 
service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{Status: "failed", ErrorMsg: "reject"}, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // Forward 失败后检测到外部取消,不应调用 MarkFailed(状态保持 generating) + require.Equal(t, "generating", repo.gens[1].Status) +} + +func TestProcessGeneration_ForwardSuccessNoMediaURL(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + // SoraClient 返回 completed 但无 URL + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: nil, // 无 URL + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "未获取到媒体 URL") +} + +func TestProcessGeneration_ForwardSuccessCancelledBeforeStore(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // MarkGenerating 调用 GetByID(第 
1 次),之后 processGeneration 行 176 调用 GetByID(第 2 次) + // 第 2 次返回 "cancelled" 状态,模拟外部取消 + repo.getByIDOverrideAfterN = 1 + repo.getByIDOverrideStatus = "cancelled" + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + // Forward 成功后检测到外部取消,不应调用存储和 MarkCompleted(状态保持 generating) + require.Equal(t, "generating", repo.gens[1].Status) +} + +func TestProcessGeneration_FullSuccessUpstream(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + // 无 S3 和本地存储,降级到 upstream + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, 
"sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "completed", repo.gens[1].Status) + require.Equal(t, service.SoraStorageTypeUpstream, repo.gens[1].StorageType) + require.NotEmpty(t, repo.gens[1].MediaURL) +} + +func TestProcessGeneration_FullSuccessWithS3(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{sourceServer.URL + "/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + s3Storage := newS3StorageForHandler(fakeS3.URL) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, SoraStorageQuotaBytes: 100 * 1024 * 1024, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + s3Storage: s3Storage, + quotaService: quotaService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + require.Equal(t, "completed", repo.gens[1].Status) + require.Equal(t, service.SoraStorageTypeS3, repo.gens[1].StorageType) + require.NotEmpty(t, repo.gens[1].S3ObjectKeys) + require.Greater(t, repo.gens[1].FileSizeBytes, int64(0)) + // 验证配额已累加 + require.Greater(t, userRepo.users[1].SoraStorageUsedBytes, int64(0)) +} + +func 
TestProcessGeneration_MarkCompletedFails(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora processGeneration 集成测试,待流程稳定后恢复") + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + // 第 1 次 Update(MarkGenerating)成功,第 2 次(MarkCompleted)失败 + repo.updateCallCount = new(int32) + repo.updateFailAfterN = 1 + genService := service.NewSoraGenerationService(repo, nil, nil) + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + soraClient := &stubSoraClientForHandler{ + videoStatus: &service.SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/video.mp4"}, + }, + } + soraGatewayService := newMinimalSoraGatewayService(soraClient) + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + soraGatewayService: soraGatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test prompt", "video", "", 1) + // MarkCompleted 内部先修改内存对象状态为 completed,然后 Update 失败。 + // 由于 stub 存储的是指针,内存中的状态已被修改为 completed。 + // 此测试验证 processGeneration 在 MarkCompleted 失败后提前返回(不调用 AddUsage)。 + require.Equal(t, "completed", repo.gens[1].Status) +} + +// ==================== cleanupStoredMedia 直接测试 ==================== + +func TestCleanupStoredMedia_S3Path(t *testing.T) { + // S3 清理路径:s3Storage 为 nil 时不 panic + h := &SoraClientHandler{} + // 不应 panic + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalPath(t *testing.T) { + // 本地清理路径:mediaStorage 为 nil 时不 panic + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, []string{"/tmp/test.mp4"}) +} + +func TestCleanupStoredMedia_UpstreamPath(t *testing.T) { + // upstream 类型不清理 + h := &SoraClientHandler{} + 
h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeUpstream, nil, nil) +} + +func TestCleanupStoredMedia_EmptyKeys(t *testing.T) { + // 空 keys 不触发清理 + h := &SoraClientHandler{} + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, nil, nil) + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeLocal, nil, nil) +} + +// ==================== DeleteGeneration: 本地存储清理路径 ==================== + +func TestDeleteGeneration_LocalStorageCleanup(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-delete-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "video/test.mp4", + MediaURLs: []string{"video/test.mp4"}, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) + _, exists := repo.gens[1] + require.False(t, exists) +} + +func TestDeleteGeneration_LocalStorageCleanup_MediaURLFallback(t *testing.T) { + // MediaURLs 为空,使用 MediaURL 作为清理路径 + tmpDir, err := os.MkdirTemp("", "sora-delete-fallback-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + 
StorageType: service.SoraStorageTypeLocal, + MediaURL: "video/test.mp4", + MediaURLs: nil, // 空 + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestDeleteGeneration_NonLocalStorage_SkipCleanup(t *testing.T) { + // 非本地存储类型 → 跳过清理 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, + UserID: 1, + Status: "completed", + StorageType: service.SoraStorageTypeUpstream, + MediaURL: "https://upstream.com/v.mp4", + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestDeleteGeneration_DeleteError(t *testing.T) { + // repo.Delete 出错 + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed", StorageType: "upstream"} + repo.deleteErr = fmt.Errorf("delete failed") + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusNotFound, rec.Code) +} + +// ==================== fetchUpstreamModels 测试 ==================== + +func TestFetchUpstreamModels_NilGateway(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + h := &SoraClientHandler{} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "gatewayService 未初始化") +} + +func 
TestFetchUpstreamModels_NoAccounts(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "选择 Sora 账号失败") +} + +func TestFetchUpstreamModels_NonAPIKeyAccount(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: "oauth", Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "不支持模型同步") +} + +func TestFetchUpstreamModels_MissingAPIKey(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"base_url": "https://sora.test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "api_key") +} + +func TestFetchUpstreamModels_MissingBaseURL_FallsBackToDefault(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + // GetBaseURL() 在缺少 base_url 时返回默认值 "https://api.anthropic.com" + // 因此不会触发 "账号缺少 base_url" 错误,而是会尝试请求默认 URL 并失败 + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, 
Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test"}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) +} + +func TestFetchUpstreamModels_UpstreamReturns500(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "状态码 500") +} + +func TestFetchUpstreamModels_UpstreamReturnsInvalidJSON(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("not json")) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "解析响应失败") +} + +func 
TestFetchUpstreamModels_UpstreamReturnsEmptyList(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "空模型列表") +} + +func TestFetchUpstreamModels_Success(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // 验证请求头 + require.Equal(t, "Bearer sk-test", r.Header.Get("Authorization")) + require.True(t, strings.HasSuffix(r.URL.Path, "/sora/v1/models")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"sora2-portrait-10s"},{"id":"sora2-landscape-15s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + families, err := h.fetchUpstreamModels(context.Background()) + require.NoError(t, err) + require.NotEmpty(t, families) +} + +func TestFetchUpstreamModels_UnrecognizedModels(t *testing.T) { + t.Skip("TODO: 临时屏蔽 Sora 
上游模型同步相关测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"unknown-model-1"},{"id":"unknown-model-2"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + h := &SoraClientHandler{gatewayService: gatewayService} + _, err := h.fetchUpstreamModels(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "未能从上游模型列表中识别") +} + +// ==================== getModelFamilies 缓存测试 ==================== + +func TestGetModelFamilies_CachesLocalConfig(t *testing.T) { + // gatewayService 为 nil → fetchUpstreamModels 失败 → 降级到本地配置 + h := &SoraClientHandler{} + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + + // 第二次调用应命中缓存(modelCacheUpstream=false → 使用短 TTL) + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) + require.False(t, h.modelCacheUpstream) +} + +func TestGetModelFamilies_CachesUpstreamResult(t *testing.T) { + t.Skip("TODO: 临时屏蔽依赖 Sora 上游模型同步的缓存测试,待账号选择逻辑稳定后恢复") + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"data":[{"id":"sora2-landscape-10s"},{"id":"gpt-image"}]}`)) + })) + defer ts.Close() + + accountRepo := &stubAccountRepoForHandler{ + accounts: []service.Account{ + {ID: 1, Type: service.AccountTypeAPIKey, Platform: service.PlatformSora, Status: service.StatusActive, Schedulable: true, + Credentials: map[string]any{"api_key": "sk-test", "base_url": ts.URL}}, + }, + } + gatewayService := newMinimalGatewayService(accountRepo) + 
h := &SoraClientHandler{gatewayService: gatewayService} + + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + require.True(t, h.modelCacheUpstream) + + // 第二次调用命中缓存 + families2 := h.getModelFamilies(context.Background()) + require.Equal(t, families, families2) +} + +func TestGetModelFamilies_ExpiredCacheRefreshes(t *testing.T) { + // 预设过期的缓存(modelCacheUpstream=false → 短 TTL) + h := &SoraClientHandler{ + cachedFamilies: []service.SoraModelFamily{{ID: "old"}}, + modelCacheTime: time.Now().Add(-10 * time.Minute), // 已过期 + modelCacheUpstream: false, + } + // gatewayService 为 nil → fetchUpstreamModels 失败 → 使用本地配置刷新缓存 + families := h.getModelFamilies(context.Background()) + require.NotEmpty(t, families) + // 缓存已刷新,不再是 "old" + found := false + for _, f := range families { + if f.ID == "old" { + found = true + } + } + require.False(t, found, "过期缓存应被刷新") +} + +// ==================== processGeneration: groupID 与 ForcePlatform ==================== + +func TestProcessGeneration_NilGroupID_WithGateway_SelectAccountFails(t *testing.T) { + // groupID 为 nil → 设置 ForcePlatform=sora → 无可用 sora 账号 → MarkFailed + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "pending"} + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 空账号列表 → SelectAccountForModel 失败 + accountRepo := &stubAccountRepoForHandler{accounts: nil} + gatewayService := newMinimalGatewayService(accountRepo) + + h := &SoraClientHandler{ + genService: genService, + gatewayService: gatewayService, + } + + h.processGeneration(1, 1, nil, "sora2-landscape-10s", "test", "video", "", 1) + require.Equal(t, "failed", repo.gens[1].Status) + require.Contains(t, repo.gens[1].ErrorMessage, "选择账号失败") +} + +// ==================== Generate: 配额检查非 QuotaExceeded 错误 ==================== + +func TestGenerate_CheckQuotaNonQuotaError(t *testing.T) { + // quotaService.CheckQuota 返回非 QuotaExceededError → 返回 403 + repo := newStubSoraGenRepo() 
+ genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → CheckQuota 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + + h := NewSoraClientHandler(genService, quotaService, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusForbidden, rec.Code) +} + +// ==================== Generate: CreatePending 并发限制错误 ==================== + +// stubSoraGenRepoWithAtomicCreate 实现 soraGenerationRepoAtomicCreator 接口 +type stubSoraGenRepoWithAtomicCreate struct { + stubSoraGenRepo + limitErr error +} + +func (r *stubSoraGenRepoWithAtomicCreate) CreatePendingWithLimit(_ context.Context, gen *service.SoraGeneration, _ []string, _ int64) error { + if r.limitErr != nil { + return r.limitErr + } + return r.stubSoraGenRepo.Create(context.Background(), gen) +} + +func TestGenerate_CreatePendingConcurrencyLimit(t *testing.T) { + repo := &stubSoraGenRepoWithAtomicCreate{ + stubSoraGenRepo: *newStubSoraGenRepo(), + limitErr: service.ErrSoraGenerationConcurrencyLimit, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := NewSoraClientHandler(genService, nil, nil, nil, nil, nil, nil) + + body := `{"model":"sora2-landscape-10s","prompt":"test"}` + c, rec := makeGinContext("POST", "/api/v1/sora/generate", body, 1) + h.Generate(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "3") +} + +// ==================== SaveToStorage: 配额超限 ==================== + +func TestSaveToStorage_QuotaExceeded(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", 
+ StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户配额已满 + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 10, + SoraStorageUsedBytes: 10, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusTooManyRequests, rec.Code) +} + +// ==================== SaveToStorage: 配额非 QuotaExceeded 错误 ==================== + +func TestSaveToStorage_QuotaNonQuotaError(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + // 用户不存在 → GetByID 失败 → AddUsage 返回普通 error + userRepo := newStubUserRepoForHandler() + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: MediaURLs 全为空 ==================== + +func TestSaveToStorage_EmptyMediaURLs(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + 
repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: "", + MediaURLs: []string{}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusBadRequest, rec.Code) + resp := parseResponse(t, rec) + require.Contains(t, resp["message"], "已过期") +} + +// ==================== SaveToStorage: S3 上传失败时已有已上传文件需清理 ==================== + +func TestSaveToStorage_MultiURL_SecondUploadFails(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("fail-second") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: "upstream", + MediaURL: sourceServer.URL + "/v1.mp4", + MediaURLs: []string{sourceServer.URL + "/v1.mp4", sourceServer.URL + "/v2.mp4"}, + } + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== SaveToStorage: UpdateStorageForCompleted 失败(含配额回滚) ==================== + +func TestSaveToStorage_MarkCompletedFailsWithQuotaRollback(t *testing.T) { + sourceServer := newFakeSourceServer() + defer sourceServer.Close() + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", 
+ StorageType: "upstream", + MediaURL: sourceServer.URL + "/v.mp4", + } + repo.updateErr = fmt.Errorf("db error") + s3Storage := newS3StorageForHandler(fakeS3.URL) + genService := service.NewSoraGenerationService(repo, nil, nil) + + userRepo := newStubUserRepoForHandler() + userRepo.users[1] = &service.User{ + ID: 1, + SoraStorageQuotaBytes: 100 * 1024 * 1024, + SoraStorageUsedBytes: 0, + } + quotaService := service.NewSoraQuotaService(userRepo, nil, nil) + h := &SoraClientHandler{genService: genService, s3Storage: s3Storage, quotaService: quotaService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/save", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.SaveToStorage(c) + require.Equal(t, http.StatusInternalServerError, rec.Code) +} + +// ==================== cleanupStoredMedia: 实际 S3 删除路径 ==================== + +func TestCleanupStoredMedia_WithS3Storage_ActualDelete(t *testing.T) { + fakeS3 := newFakeS3Server("ok") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1", "key2"}, nil) +} + +func TestCleanupStoredMedia_S3DeleteFails_LogOnly(t *testing.T) { + fakeS3 := newFakeS3Server("fail") + defer fakeS3.Close() + s3Storage := newS3StorageForHandler(fakeS3.URL) + h := &SoraClientHandler{s3Storage: s3Storage} + + h.cleanupStoredMedia(context.Background(), service.SoraStorageTypeS3, []string{"key1"}, nil) +} + +func TestCleanupStoredMedia_LocalDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-cleanup-fail-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + h := &SoraClientHandler{mediaStorage: mediaStorage} + + h.cleanupStoredMedia(context.Background(), 
service.SoraStorageTypeLocal, nil, []string{"nonexistent/file.mp4"}) +} + +// ==================== DeleteGeneration: 本地文件删除失败(仅日志) ==================== + +func TestDeleteGeneration_LocalStorageDeleteFails_LogOnly(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "sora-del-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + cfg := &config.Config{ + Sora: config.SoraConfig{ + Storage: config.SoraStorageConfig{ + Type: "local", + LocalPath: tmpDir, + }, + }, + } + mediaStorage := service.NewSoraMediaStorage(cfg) + + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ + ID: 1, UserID: 1, Status: "completed", + StorageType: service.SoraStorageTypeLocal, + MediaURL: "nonexistent/video.mp4", + MediaURLs: []string{"nonexistent/video.mp4"}, + } + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService, mediaStorage: mediaStorage} + + c, rec := makeGinContext("DELETE", "/api/v1/sora/generations/1", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.DeleteGeneration(c) + require.Equal(t, http.StatusOK, rec.Code) +} + +// ==================== CancelGeneration: 任务已结束冲突 ==================== + +func TestCancelGeneration_AlreadyCompleted(t *testing.T) { + repo := newStubSoraGenRepo() + repo.gens[1] = &service.SoraGeneration{ID: 1, UserID: 1, Status: "completed"} + genService := service.NewSoraGenerationService(repo, nil, nil) + h := &SoraClientHandler{genService: genService} + + c, rec := makeGinContext("POST", "/api/v1/sora/generations/1/cancel", "", 1) + c.Params = gin.Params{{Key: "id", Value: "1"}} + h.CancelGeneration(c) + require.Equal(t, http.StatusConflict, rec.Code) +} diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index ab3a3f14..48c1e451 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -7,7 +7,6 @@ import ( "encoding/json" "errors" "fmt" 
- "io" "net/http" "os" "path" @@ -17,6 +16,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" "github.com/Wei-Shaw/sub2api/internal/pkg/ip" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" @@ -107,7 +107,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) { zap.Any("group_id", apiKey.GroupID), ) - body, err := io.ReadAll(c.Request.Body) + body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request) if err != nil { if maxErr, ok := extractMaxBytesError(err); ok { h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit)) @@ -461,6 +461,14 @@ func (h *SoraGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) // 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() + defer func() { + if recovered := recover(); recovered != nil { + logger.L().With( + zap.String("component", "handler.sora_gateway.chat_completions"), + zap.Any("panic", recovered), + ).Error("sora.usage_record_task_panic_recovered") + } + }() task(ctx) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index cc792350..01c684ca 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -314,10 +314,10 @@ func (s *stubUsageLogRepo) GetAccountTodayStats(ctx context.Context, accountID i func (s *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) { return nil, nil } -func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) 
{ +func (s *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, nil } -func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { +func (s *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, nil } func (s *stubUsageLogRepo) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) { diff --git a/backend/internal/handler/usage_handler.go b/backend/internal/handler/usage_handler.go index b8182dad..2bd0e0d7 100644 --- a/backend/internal/handler/usage_handler.go +++ b/backend/internal/handler/usage_handler.go @@ -2,6 +2,7 @@ package handler import ( "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -65,8 +66,17 @@ func (h *UsageHandler) List(c *gin.Context) { // Parse additional filters model := c.Query("model") + var requestType *int16 var stream *bool - if streamStr := c.Query("stream"); streamStr != "" { + if requestTypeStr := strings.TrimSpace(c.Query("request_type")); requestTypeStr != "" { + parsed, err := service.ParseUsageRequestType(requestTypeStr) + if err != nil { + response.BadRequest(c, err.Error()) + return + } + value := int16(parsed) + requestType = &value + } else if streamStr := c.Query("stream"); streamStr != "" { val, err := strconv.ParseBool(streamStr) if err != nil { response.BadRequest(c, "Invalid stream value, use true or false") @@ -114,6 +124,7 @@ func (h *UsageHandler) 
List(c *gin.Context) { UserID: subject.UserID, // Always filter by current user for security APIKeyID: apiKeyID, Model: model, + RequestType: requestType, Stream: stream, BillingType: billingType, StartTime: startTime, diff --git a/backend/internal/handler/usage_handler_request_type_test.go b/backend/internal/handler/usage_handler_request_type_test.go new file mode 100644 index 00000000..7c4c7913 --- /dev/null +++ b/backend/internal/handler/usage_handler_request_type_test.go @@ -0,0 +1,80 @@ +package handler + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type userUsageRepoCapture struct { + service.UsageLogRepository + listFilters usagestats.UsageLogFilters +} + +func (s *userUsageRepoCapture) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters usagestats.UsageLogFilters) ([]service.UsageLog, *pagination.PaginationResult, error) { + s.listFilters = filters + return []service.UsageLog{}, &pagination.PaginationResult{ + Total: 0, + Page: params.Page, + PageSize: params.PageSize, + Pages: 0, + }, nil +} + +func newUserUsageRequestTypeTestRouter(repo *userUsageRepoCapture) *gin.Engine { + gin.SetMode(gin.TestMode) + usageSvc := service.NewUsageService(repo, nil, nil, nil) + handler := NewUsageHandler(usageSvc, nil) + router := gin.New() + router.Use(func(c *gin.Context) { + c.Set(string(middleware2.ContextKeyUser), middleware2.AuthSubject{UserID: 42}) + c.Next() + }) + router.GET("/usage", handler.List) + return router +} + +func TestUserUsageListRequestTypePriority(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, 
"/usage?request_type=ws_v2&stream=bad", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, int64(42), repo.listFilters.UserID) + require.NotNil(t, repo.listFilters.RequestType) + require.Equal(t, int16(service.RequestTypeWSV2), *repo.listFilters.RequestType) + require.Nil(t, repo.listFilters.Stream) +} + +func TestUserUsageListInvalidRequestType(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?request_type=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} + +func TestUserUsageListInvalidStream(t *testing.T) { + repo := &userUsageRepoCapture{} + router := newUserUsageRequestTypeTestRouter(repo) + + req := httptest.NewRequest(http.MethodGet, "/usage?stream=invalid", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusBadRequest, rec.Code) +} diff --git a/backend/internal/handler/usage_record_submit_task_test.go b/backend/internal/handler/usage_record_submit_task_test.go index df759f44..c7c48e14 100644 --- a/backend/internal/handler/usage_record_submit_task_test.go +++ b/backend/internal/handler/usage_record_submit_task_test.go @@ -61,6 +61,22 @@ func TestGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { }) } +func TestGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &GatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { pool := newUsageRecordTestPool(t) h := 
&OpenAIGatewayHandler{usageRecordWorkerPool: pool} @@ -98,6 +114,22 @@ func TestOpenAIGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { }) } +func TestOpenAIGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &OpenAIGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} + func TestSoraGatewayHandlerSubmitUsageRecordTask_WithPool(t *testing.T) { pool := newUsageRecordTestPool(t) h := &SoraGatewayHandler{usageRecordWorkerPool: pool} @@ -134,3 +166,19 @@ func TestSoraGatewayHandlerSubmitUsageRecordTask_NilTask(t *testing.T) { h.submitUsageRecordTask(nil) }) } + +func TestSoraGatewayHandlerSubmitUsageRecordTask_WithoutPool_TaskPanicRecovered(t *testing.T) { + h := &SoraGatewayHandler{} + var called atomic.Bool + + require.NotPanics(t, func() { + h.submitUsageRecordTask(func(ctx context.Context) { + panic("usage task panic") + }) + }) + + h.submitUsageRecordTask(func(ctx context.Context) { + called.Store(true) + }) + require.True(t, called.Load(), "panic 后后续任务应仍可执行") +} diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 79d583fd..f1a21119 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -14,6 +14,7 @@ func ProvideAdminHandlers( groupHandler *admin.GroupHandler, accountHandler *admin.AccountHandler, announcementHandler *admin.AnnouncementHandler, + dataManagementHandler *admin.DataManagementHandler, oauthHandler *admin.OAuthHandler, openaiOAuthHandler *admin.OpenAIOAuthHandler, geminiOAuthHandler *admin.GeminiOAuthHandler, @@ -35,6 +36,7 @@ func ProvideAdminHandlers( Group: groupHandler, Account: accountHandler, Announcement: announcementHandler, + DataManagement: dataManagementHandler, OAuth: 
oauthHandler, OpenAIOAuth: openaiOAuthHandler, GeminiOAuth: geminiOAuthHandler, @@ -75,6 +77,7 @@ func ProvideHandlers( gatewayHandler *GatewayHandler, openaiGatewayHandler *OpenAIGatewayHandler, soraGatewayHandler *SoraGatewayHandler, + soraClientHandler *SoraClientHandler, settingHandler *SettingHandler, totpHandler *TotpHandler, _ *service.IdempotencyCoordinator, @@ -92,6 +95,7 @@ func ProvideHandlers( Gateway: gatewayHandler, OpenAIGateway: openaiGatewayHandler, SoraGateway: soraGatewayHandler, + SoraClient: soraClientHandler, Setting: settingHandler, Totp: totpHandler, } @@ -119,6 +123,7 @@ var ProviderSet = wire.NewSet( admin.NewGroupHandler, admin.NewAccountHandler, admin.NewAnnouncementHandler, + admin.NewDataManagementHandler, admin.NewOAuthHandler, admin.NewOpenAIOAuthHandler, admin.NewGeminiOAuthHandler, diff --git a/backend/internal/pkg/antigravity/claude_types.go b/backend/internal/pkg/antigravity/claude_types.go index 1b94dad5..7cc68060 100644 --- a/backend/internal/pkg/antigravity/claude_types.go +++ b/backend/internal/pkg/antigravity/claude_types.go @@ -152,6 +152,7 @@ var claudeModels = []modelDef{ {ID: "claude-sonnet-4-5", DisplayName: "Claude Sonnet 4.5", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-sonnet-4-5-thinking", DisplayName: "Claude Sonnet 4.5 Thinking", CreatedAt: "2025-09-29T00:00:00Z"}, {ID: "claude-opus-4-6", DisplayName: "Claude Opus 4.6", CreatedAt: "2026-02-05T00:00:00Z"}, + {ID: "claude-opus-4-6-thinking", DisplayName: "Claude Opus 4.6 Thinking", CreatedAt: "2026-02-05T00:00:00Z"}, {ID: "claude-sonnet-4-6", DisplayName: "Claude Sonnet 4.6", CreatedAt: "2026-02-17T00:00:00Z"}, } @@ -165,6 +166,8 @@ var geminiModels = []modelDef{ {ID: "gemini-3-pro-high", DisplayName: "Gemini 3 Pro High", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3.1-pro-low", DisplayName: "Gemini 3.1 Pro Low", CreatedAt: "2026-02-19T00:00:00Z"}, {ID: "gemini-3.1-pro-high", DisplayName: "Gemini 3.1 Pro High", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: 
"gemini-3.1-flash-image", DisplayName: "Gemini 3.1 Flash Image", CreatedAt: "2026-02-19T00:00:00Z"}, + {ID: "gemini-3.1-flash-image-preview", DisplayName: "Gemini 3.1 Flash Image Preview", CreatedAt: "2026-02-19T00:00:00Z"}, {ID: "gemini-3-pro-preview", DisplayName: "Gemini 3 Pro Preview", CreatedAt: "2025-06-01T00:00:00Z"}, {ID: "gemini-3-pro-image", DisplayName: "Gemini 3 Pro Image", CreatedAt: "2025-06-01T00:00:00Z"}, } diff --git a/backend/internal/pkg/antigravity/claude_types_test.go b/backend/internal/pkg/antigravity/claude_types_test.go new file mode 100644 index 00000000..f7cb0a24 --- /dev/null +++ b/backend/internal/pkg/antigravity/claude_types_test.go @@ -0,0 +1,26 @@ +package antigravity + +import "testing" + +func TestDefaultModels_ContainsNewAndLegacyImageModels(t *testing.T) { + t.Parallel() + + models := DefaultModels() + byID := make(map[string]ClaudeModel, len(models)) + for _, m := range models { + byID[m.ID] = m + } + + requiredIDs := []string{ + "claude-opus-4-6-thinking", + "gemini-3.1-flash-image", + "gemini-3.1-flash-image-preview", + "gemini-3-pro-image", // legacy compatibility + } + + for _, id := range requiredIDs { + if _, ok := byID[id]; !ok { + t.Fatalf("expected model %q to be exposed in DefaultModels", id) + } + } +} diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index 32495827..0ff24a1f 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -70,7 +70,7 @@ type GeminiGenerationConfig struct { ImageConfig *GeminiImageConfig `json:"imageConfig,omitempty"` } -// GeminiImageConfig Gemini 图片生成配置(仅 gemini-3-pro-image 支持) +// GeminiImageConfig Gemini 图片生成配置(gemini-3-pro-image / gemini-3.1-flash-image 等图片模型支持) type GeminiImageConfig struct { AspectRatio string `json:"aspectRatio,omitempty"` // "1:1", "16:9", "9:16", "4:3", "3:4" ImageSize string `json:"imageSize,omitempty"` // "1K", "2K", "4K" diff --git 
a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index afffe9b1..18310655 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -53,7 +53,8 @@ const ( var defaultUserAgentVersion = "1.19.6" // defaultClientSecret 可通过环境变量 ANTIGRAVITY_OAUTH_CLIENT_SECRET 配置 -var defaultClientSecret = "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" +// 默认值使用占位符,生产环境请通过环境变量注入真实值。 +var defaultClientSecret = "GOCSPX-your-client-secret" func init() { // 从环境变量读取版本号,未设置则使用默认值 diff --git a/backend/internal/pkg/antigravity/oauth_test.go b/backend/internal/pkg/antigravity/oauth_test.go index 8417416a..2a2a52e9 100644 --- a/backend/internal/pkg/antigravity/oauth_test.go +++ b/backend/internal/pkg/antigravity/oauth_test.go @@ -612,14 +612,14 @@ func TestBuildAuthorizationURL_参数验证(t *testing.T) { expectedParams := map[string]string{ "client_id": ClientID, - "redirect_uri": RedirectURI, - "response_type": "code", - "scope": Scopes, - "state": state, - "code_challenge": codeChallenge, - "code_challenge_method": "S256", - "access_type": "offline", - "prompt": "consent", + "redirect_uri": RedirectURI, + "response_type": "code", + "scope": Scopes, + "state": state, + "code_challenge": codeChallenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", "include_granted_scopes": "true", } @@ -684,7 +684,7 @@ func TestConstants_值正确(t *testing.T) { if err != nil { t.Fatalf("getClientSecret 应返回默认值,但报错: %v", err) } - if secret != "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" { + if secret != "GOCSPX-your-client-secret" { t.Errorf("默认 client_secret 不匹配: got %s", secret) } if RedirectURI != "http://localhost:8085/callback" { diff --git a/backend/internal/pkg/errors/errors_test.go b/backend/internal/pkg/errors/errors_test.go index 1a1c842e..25e62907 100644 --- a/backend/internal/pkg/errors/errors_test.go +++ b/backend/internal/pkg/errors/errors_test.go @@ -166,3 +166,18 @@ func TestToHTTP(t *testing.T) { 
}) } } + +func TestToHTTP_MetadataDeepCopy(t *testing.T) { + md := map[string]string{"k": "v"} + appErr := BadRequest("BAD_REQUEST", "invalid").WithMetadata(md) + + code, body := ToHTTP(appErr) + require.Equal(t, http.StatusBadRequest, code) + require.Equal(t, "v", body.Metadata["k"]) + + md["k"] = "changed" + require.Equal(t, "v", body.Metadata["k"]) + + appErr.Metadata["k"] = "changed-again" + require.Equal(t, "v", body.Metadata["k"]) +} diff --git a/backend/internal/pkg/errors/http.go b/backend/internal/pkg/errors/http.go index 7b5560e3..420c69a3 100644 --- a/backend/internal/pkg/errors/http.go +++ b/backend/internal/pkg/errors/http.go @@ -16,6 +16,16 @@ func ToHTTP(err error) (statusCode int, body Status) { return http.StatusOK, Status{Code: int32(http.StatusOK)} } - cloned := Clone(appErr) - return int(cloned.Code), cloned.Status + body = Status{ + Code: appErr.Code, + Reason: appErr.Reason, + Message: appErr.Message, + } + if appErr.Metadata != nil { + body.Metadata = make(map[string]string, len(appErr.Metadata)) + for k, v := range appErr.Metadata { + body.Metadata[k] = v + } + } + return int(appErr.Code), body } diff --git a/backend/internal/pkg/geminicli/constants.go b/backend/internal/pkg/geminicli/constants.go index 97234ffd..f5ee5735 100644 --- a/backend/internal/pkg/geminicli/constants.go +++ b/backend/internal/pkg/geminicli/constants.go @@ -39,7 +39,7 @@ const ( // They enable the "login without creating your own OAuth client" experience, but Google may // restrict which scopes are allowed for this client. GeminiCLIOAuthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - GeminiCLIOAuthClientSecret = "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" + GeminiCLIOAuthClientSecret = "GOCSPX-your-client-secret" // GeminiCLIOAuthClientSecretEnv is the environment variable name for the built-in client secret. 
GeminiCLIOAuthClientSecretEnv = "GEMINI_CLI_OAUTH_CLIENT_SECRET" diff --git a/backend/internal/pkg/httpclient/pool.go b/backend/internal/pkg/httpclient/pool.go index 76b7aa91..6ef3d714 100644 --- a/backend/internal/pkg/httpclient/pool.go +++ b/backend/internal/pkg/httpclient/pool.go @@ -32,6 +32,7 @@ const ( defaultMaxIdleConns = 100 // 最大空闲连接数 defaultMaxIdleConnsPerHost = 10 // 每个主机最大空闲连接数 defaultIdleConnTimeout = 90 * time.Second // 空闲连接超时时间(建议小于上游 LB 超时) + validatedHostTTL = 30 * time.Second // DNS Rebinding 校验缓存 TTL ) // Options 定义共享 HTTP 客户端的构建参数 @@ -53,6 +54,9 @@ type Options struct { // sharedClients 存储按配置参数缓存的 http.Client 实例 var sharedClients sync.Map +// 允许测试替换校验函数,生产默认指向真实实现。 +var validateResolvedIP = urlvalidator.ValidateResolvedIP + // GetClient 返回共享的 HTTP 客户端实例 // 性能优化:相同配置复用同一客户端,避免重复创建 Transport // 安全说明:代理配置失败时直接返回错误,不会回退到直连,避免 IP 关联风险 @@ -84,7 +88,7 @@ func buildClient(opts Options) (*http.Client, error) { var rt http.RoundTripper = transport if opts.ValidateResolvedIP && !opts.AllowPrivateHosts { - rt = &validatedTransport{base: transport} + rt = newValidatedTransport(transport) } return &http.Client{ Transport: rt, @@ -149,17 +153,56 @@ func buildClientKey(opts Options) string { } type validatedTransport struct { - base http.RoundTripper + base http.RoundTripper + validatedHosts sync.Map // map[string]time.Time, value 为过期时间 + now func() time.Time +} + +func newValidatedTransport(base http.RoundTripper) *validatedTransport { + return &validatedTransport{ + base: base, + now: time.Now, + } +} + +func (t *validatedTransport) isValidatedHost(host string, now time.Time) bool { + if t == nil { + return false + } + raw, ok := t.validatedHosts.Load(host) + if !ok { + return false + } + expireAt, ok := raw.(time.Time) + if !ok { + t.validatedHosts.Delete(host) + return false + } + if now.Before(expireAt) { + return true + } + t.validatedHosts.Delete(host) + return false } func (t *validatedTransport) RoundTrip(req *http.Request) (*http.Response, error) { 
if req != nil && req.URL != nil { - host := strings.TrimSpace(req.URL.Hostname()) + host := strings.ToLower(strings.TrimSpace(req.URL.Hostname())) if host != "" { - if err := urlvalidator.ValidateResolvedIP(host); err != nil { - return nil, err + now := time.Now() + if t != nil && t.now != nil { + now = t.now() + } + if !t.isValidatedHost(host, now) { + if err := validateResolvedIP(host); err != nil { + return nil, err + } + t.validatedHosts.Store(host, now.Add(validatedHostTTL)) } } } + if t == nil || t.base == nil { + return nil, fmt.Errorf("validated transport base is nil") + } return t.base.RoundTrip(req) } diff --git a/backend/internal/pkg/httpclient/pool_test.go b/backend/internal/pkg/httpclient/pool_test.go new file mode 100644 index 00000000..f945758a --- /dev/null +++ b/backend/internal/pkg/httpclient/pool_test.go @@ -0,0 +1,115 @@ +package httpclient + +import ( + "errors" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func TestValidatedTransport_CacheHostValidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(host string) error { + atomic.AddInt32(&validateCalls, 1) + require.Equal(t, "api.openai.com", host) + return nil + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730000000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, 
"https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(1), atomic.LoadInt32(&validateCalls)) + require.Equal(t, int32(2), atomic.LoadInt32(&baseCalls)) +} + +func TestValidatedTransport_ExpiredCacheTriggersRevalidation(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + var validateCalls int32 + validateResolvedIP = func(_ string) error { + atomic.AddInt32(&validateCalls, 1) + return nil + } + + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{}`)), + Header: make(http.Header), + }, nil + }) + + now := time.Unix(1730001000, 0) + transport := newValidatedTransport(base) + transport.now = func() time.Time { return now } + + req, err := http.NewRequest(http.MethodGet, "https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + now = now.Add(validatedHostTTL + time.Second) + _, err = transport.RoundTrip(req) + require.NoError(t, err) + + require.Equal(t, int32(2), atomic.LoadInt32(&validateCalls)) +} + +func TestValidatedTransport_ValidationErrorStopsRoundTrip(t *testing.T) { + originalValidate := validateResolvedIP + defer func() { validateResolvedIP = originalValidate }() + + expectedErr := errors.New("dns rebinding rejected") + validateResolvedIP = func(_ string) error { + return expectedErr + } + + var baseCalls int32 + base := roundTripFunc(func(_ *http.Request) (*http.Response, error) { + atomic.AddInt32(&baseCalls, 1) + return &http.Response{StatusCode: http.StatusOK, Body: io.NopCloser(strings.NewReader(`{}`))}, nil + }) + + transport := newValidatedTransport(base) + req, err := http.NewRequest(http.MethodGet, 
"https://api.openai.com/v1/responses", nil) + require.NoError(t, err) + + _, err = transport.RoundTrip(req) + require.ErrorIs(t, err, expectedErr) + require.Equal(t, int32(0), atomic.LoadInt32(&baseCalls)) +} diff --git a/backend/internal/pkg/httputil/body.go b/backend/internal/pkg/httputil/body.go new file mode 100644 index 00000000..69e99dc5 --- /dev/null +++ b/backend/internal/pkg/httputil/body.go @@ -0,0 +1,37 @@ +package httputil + +import ( + "bytes" + "io" + "net/http" +) + +const ( + requestBodyReadInitCap = 512 + requestBodyReadMaxInitCap = 1 << 20 +) + +// ReadRequestBodyWithPrealloc reads request body with preallocated buffer based on content length. +func ReadRequestBodyWithPrealloc(req *http.Request) ([]byte, error) { + if req == nil || req.Body == nil { + return nil, nil + } + + capHint := requestBodyReadInitCap + if req.ContentLength > 0 { + switch { + case req.ContentLength < int64(requestBodyReadInitCap): + capHint = requestBodyReadInitCap + case req.ContentLength > int64(requestBodyReadMaxInitCap): + capHint = requestBodyReadMaxInitCap + default: + capHint = int(req.ContentLength) + } + } + + buf := bytes.NewBuffer(make([]byte, 0, capHint)) + if _, err := io.Copy(buf, req.Body); err != nil { + return nil, err + } + return buf.Bytes(), nil +} diff --git a/backend/internal/pkg/ip/ip.go b/backend/internal/pkg/ip/ip.go index 3f05ac41..f6f77c86 100644 --- a/backend/internal/pkg/ip/ip.go +++ b/backend/internal/pkg/ip/ip.go @@ -67,6 +67,14 @@ func normalizeIP(ip string) string { // privateNets 预编译私有 IP CIDR 块,避免每次调用 isPrivateIP 时重复解析 var privateNets []*net.IPNet +// CompiledIPRules 表示预编译的 IP 匹配规则。 +// PatternCount 记录原始规则数量,用于保留“规则存在但全无效”时的行为语义。 +type CompiledIPRules struct { + CIDRs []*net.IPNet + IPs []net.IP + PatternCount int +} + func init() { for _, cidr := range []string{ "10.0.0.0/8", @@ -84,6 +92,53 @@ func init() { } } +// CompileIPRules 将 IP/CIDR 字符串规则预编译为可复用结构。 +// 非法规则会被忽略,但 PatternCount 会保留原始规则条数。 +func CompileIPRules(patterns []string) 
*CompiledIPRules { + compiled := &CompiledIPRules{ + CIDRs: make([]*net.IPNet, 0, len(patterns)), + IPs: make([]net.IP, 0, len(patterns)), + PatternCount: len(patterns), + } + for _, pattern := range patterns { + normalized := strings.TrimSpace(pattern) + if normalized == "" { + continue + } + if strings.Contains(normalized, "/") { + _, cidr, err := net.ParseCIDR(normalized) + if err != nil || cidr == nil { + continue + } + compiled.CIDRs = append(compiled.CIDRs, cidr) + continue + } + parsedIP := net.ParseIP(normalized) + if parsedIP == nil { + continue + } + compiled.IPs = append(compiled.IPs, parsedIP) + } + return compiled +} + +func matchesCompiledRules(parsedIP net.IP, rules *CompiledIPRules) bool { + if parsedIP == nil || rules == nil { + return false + } + for _, cidr := range rules.CIDRs { + if cidr.Contains(parsedIP) { + return true + } + } + for _, ruleIP := range rules.IPs { + if parsedIP.Equal(ruleIP) { + return true + } + } + return false +} + // isPrivateIP 检查 IP 是否为私有地址。 func isPrivateIP(ipStr string) bool { ip := net.ParseIP(ipStr) @@ -142,19 +197,32 @@ func MatchesAnyPattern(clientIP string, patterns []string) bool { // 2. 如果白名单不为空,IP 必须在白名单中 // 3. 如果白名单为空,允许访问(除非被黑名单拒绝) func CheckIPRestriction(clientIP string, whitelist, blacklist []string) (bool, string) { + return CheckIPRestrictionWithCompiledRules( + clientIP, + CompileIPRules(whitelist), + CompileIPRules(blacklist), + ) +} + +// CheckIPRestrictionWithCompiledRules 使用预编译规则检查 IP 是否允许访问。 +func CheckIPRestrictionWithCompiledRules(clientIP string, whitelist, blacklist *CompiledIPRules) (bool, string) { // 规范化 IP clientIP = normalizeIP(clientIP) if clientIP == "" { return false, "access denied" } + parsedIP := net.ParseIP(clientIP) + if parsedIP == nil { + return false, "access denied" + } // 1. 
检查黑名单 - if len(blacklist) > 0 && MatchesAnyPattern(clientIP, blacklist) { + if blacklist != nil && blacklist.PatternCount > 0 && matchesCompiledRules(parsedIP, blacklist) { return false, "access denied" } // 2. 检查白名单(如果设置了白名单,IP 必须在其中) - if len(whitelist) > 0 && !MatchesAnyPattern(clientIP, whitelist) { + if whitelist != nil && whitelist.PatternCount > 0 && !matchesCompiledRules(parsedIP, whitelist) { return false, "access denied" } diff --git a/backend/internal/pkg/ip/ip_test.go b/backend/internal/pkg/ip/ip_test.go index 3839403c..403b2d59 100644 --- a/backend/internal/pkg/ip/ip_test.go +++ b/backend/internal/pkg/ip/ip_test.go @@ -73,3 +73,24 @@ func TestGetTrustedClientIPUsesGinClientIP(t *testing.T) { require.Equal(t, 200, w.Code) require.Equal(t, "9.9.9.9", w.Body.String()) } + +func TestCheckIPRestrictionWithCompiledRules(t *testing.T) { + whitelist := CompileIPRules([]string{"10.0.0.0/8", "192.168.1.2"}) + blacklist := CompileIPRules([]string{"10.1.1.1"}) + + allowed, reason := CheckIPRestrictionWithCompiledRules("10.2.3.4", whitelist, blacklist) + require.True(t, allowed) + require.Equal(t, "", reason) + + allowed, reason = CheckIPRestrictionWithCompiledRules("10.1.1.1", whitelist, blacklist) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} + +func TestCheckIPRestrictionWithCompiledRules_InvalidWhitelistStillDenies(t *testing.T) { + // 与旧实现保持一致:白名单有配置但全无效时,最终应拒绝访问。 + invalidWhitelist := CompileIPRules([]string{"not-a-valid-pattern"}) + allowed, reason := CheckIPRestrictionWithCompiledRules("8.8.8.8", invalidWhitelist, nil) + require.False(t, allowed) + require.Equal(t, "access denied", reason) +} diff --git a/backend/internal/pkg/logger/logger.go b/backend/internal/pkg/logger/logger.go index 80d92517..3fca706e 100644 --- a/backend/internal/pkg/logger/logger.go +++ b/backend/internal/pkg/logger/logger.go @@ -10,6 +10,7 @@ import ( "path/filepath" "strings" "sync" + "sync/atomic" "time" "go.uber.org/zap" @@ -42,15 +43,19 @@ type 
LogEvent struct { var ( mu sync.RWMutex - global *zap.Logger - sugar *zap.SugaredLogger + global atomic.Pointer[zap.Logger] + sugar atomic.Pointer[zap.SugaredLogger] atomicLevel zap.AtomicLevel initOptions InitOptions - currentSink Sink + currentSink atomic.Value // sinkState stdLogUndo func() bootstrapOnce sync.Once ) +type sinkState struct { + sink Sink +} + func InitBootstrap() { bootstrapOnce.Do(func() { if err := Init(bootstrapOptions()); err != nil { @@ -72,9 +77,9 @@ func initLocked(options InitOptions) error { return err } - prev := global - global = zl - sugar = zl.Sugar() + prev := global.Load() + global.Store(zl) + sugar.Store(zl.Sugar()) atomicLevel = al initOptions = normalized @@ -115,24 +120,32 @@ func SetLevel(level string) error { func CurrentLevel() string { mu.RLock() defer mu.RUnlock() - if global == nil { + if global.Load() == nil { return "info" } return atomicLevel.Level().String() } func SetSink(sink Sink) { - mu.Lock() - defer mu.Unlock() - currentSink = sink + currentSink.Store(sinkState{sink: sink}) +} + +func loadSink() Sink { + v := currentSink.Load() + if v == nil { + return nil + } + state, ok := v.(sinkState) + if !ok { + return nil + } + return state.sink } // WriteSinkEvent 直接写入日志 sink,不经过全局日志级别门控。 // 用于需要“可观测性入库”与“业务输出级别”解耦的场景(例如 ops 系统日志索引)。 func WriteSinkEvent(level, component, message string, fields map[string]any) { - mu.RLock() - sink := currentSink - mu.RUnlock() + sink := loadSink() if sink == nil { return } @@ -168,19 +181,15 @@ func WriteSinkEvent(level, component, message string, fields map[string]any) { } func L() *zap.Logger { - mu.RLock() - defer mu.RUnlock() - if global != nil { - return global + if l := global.Load(); l != nil { + return l } return zap.NewNop() } func S() *zap.SugaredLogger { - mu.RLock() - defer mu.RUnlock() - if sugar != nil { - return sugar + if s := sugar.Load(); s != nil { + return s } return zap.NewNop().Sugar() } @@ -190,9 +199,7 @@ func With(fields ...zap.Field) *zap.Logger { } func Sync() { 
- mu.RLock() - l := global - mu.RUnlock() + l := global.Load() if l != nil { _ = l.Sync() } @@ -210,7 +217,11 @@ func bridgeStdLogLocked() { log.SetFlags(0) log.SetPrefix("") - log.SetOutput(newStdLogBridge(global.Named("stdlog"))) + base := global.Load() + if base == nil { + base = zap.NewNop() + } + log.SetOutput(newStdLogBridge(base.Named("stdlog"))) stdLogUndo = func() { log.SetOutput(prevWriter) @@ -220,7 +231,11 @@ func bridgeStdLogLocked() { } func bridgeSlogLocked() { - slog.SetDefault(slog.New(newSlogZapHandler(global.Named("slog")))) + base := global.Load() + if base == nil { + base = zap.NewNop() + } + slog.SetDefault(slog.New(newSlogZapHandler(base.Named("slog")))) } func buildLogger(options InitOptions) (*zap.Logger, zap.AtomicLevel, error) { @@ -363,9 +378,7 @@ func (s *sinkCore) Check(entry zapcore.Entry, ce *zapcore.CheckedEntry) *zapcore func (s *sinkCore) Write(entry zapcore.Entry, fields []zapcore.Field) error { // Only handle sink forwarding — the inner cores write via their own // Write methods (added to CheckedEntry by s.core.Check above). 
- mu.RLock() - sink := currentSink - mu.RUnlock() + sink := loadSink() if sink == nil { return nil } @@ -454,7 +467,7 @@ func inferStdLogLevel(msg string) Level { if strings.Contains(lower, " failed") || strings.Contains(lower, "error") || strings.Contains(lower, "panic") || strings.Contains(lower, "fatal") { return LevelError } - if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " retry") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { + if strings.Contains(lower, "warning") || strings.Contains(lower, "warn") || strings.Contains(lower, " queue full") || strings.Contains(lower, "fallback") { return LevelWarn } return LevelInfo @@ -467,9 +480,7 @@ func LegacyPrintf(component, format string, args ...any) { return } - mu.RLock() - initialized := global != nil - mu.RUnlock() + initialized := global.Load() != nil if !initialized { // 在日志系统未初始化前,回退到标准库 log,避免测试/工具链丢日志。 log.Print(msg) diff --git a/backend/internal/pkg/logger/slog_handler.go b/backend/internal/pkg/logger/slog_handler.go index 562b8341..602ca1e0 100644 --- a/backend/internal/pkg/logger/slog_handler.go +++ b/backend/internal/pkg/logger/slog_handler.go @@ -48,16 +48,15 @@ func (h *slogZapHandler) Handle(_ context.Context, record slog.Record) error { return true }) - entry := h.logger.With(fields...) switch { case record.Level >= slog.LevelError: - entry.Error(record.Message) + h.logger.Error(record.Message, fields...) case record.Level >= slog.LevelWarn: - entry.Warn(record.Message) + h.logger.Warn(record.Message, fields...) case record.Level <= slog.LevelDebug: - entry.Debug(record.Message) + h.logger.Debug(record.Message, fields...) default: - entry.Info(record.Message) + h.logger.Info(record.Message, fields...) 
} return nil } diff --git a/backend/internal/pkg/logger/stdlog_bridge_test.go b/backend/internal/pkg/logger/stdlog_bridge_test.go index a3f76fd7..4482a2ec 100644 --- a/backend/internal/pkg/logger/stdlog_bridge_test.go +++ b/backend/internal/pkg/logger/stdlog_bridge_test.go @@ -16,6 +16,7 @@ func TestInferStdLogLevel(t *testing.T) { {msg: "Warning: queue full", want: LevelWarn}, {msg: "Forward request failed: timeout", want: LevelError}, {msg: "[ERROR] upstream unavailable", want: LevelError}, + {msg: "[OpenAI WS Mode] reconnect_retry account_id=22 retry=1 max_retries=5", want: LevelInfo}, {msg: "service started", want: LevelInfo}, {msg: "debug: cache miss", want: LevelDebug}, } diff --git a/backend/internal/pkg/openai/oauth.go b/backend/internal/pkg/openai/oauth.go index e3b931be..8bdcbe16 100644 --- a/backend/internal/pkg/openai/oauth.go +++ b/backend/internal/pkg/openai/oauth.go @@ -36,10 +36,18 @@ const ( SessionTTL = 30 * time.Minute ) +const ( + // OAuthPlatformOpenAI uses OpenAI Codex-compatible OAuth client. + OAuthPlatformOpenAI = "openai" + // OAuthPlatformSora uses Sora OAuth client. + OAuthPlatformSora = "sora" +) + // OAuthSession stores OAuth flow state for OpenAI type OAuthSession struct { State string `json:"state"` CodeVerifier string `json:"code_verifier"` + ClientID string `json:"client_id,omitempty"` ProxyURL string `json:"proxy_url,omitempty"` RedirectURI string `json:"redirect_uri"` CreatedAt time.Time `json:"created_at"` @@ -174,13 +182,20 @@ func base64URLEncode(data []byte) string { // BuildAuthorizationURL builds the OpenAI OAuth authorization URL func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { + return BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, OAuthPlatformOpenAI) +} + +// BuildAuthorizationURLForPlatform builds authorization URL by platform. 
+func BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, platform string) string { if redirectURI == "" { redirectURI = DefaultRedirectURI } + clientID, codexFlow := OAuthClientConfigByPlatform(platform) + params := url.Values{} params.Set("response_type", "code") - params.Set("client_id", ClientID) + params.Set("client_id", clientID) params.Set("redirect_uri", redirectURI) params.Set("scope", DefaultScopes) params.Set("state", state) @@ -188,11 +203,25 @@ func BuildAuthorizationURL(state, codeChallenge, redirectURI string) string { params.Set("code_challenge_method", "S256") // OpenAI specific parameters params.Set("id_token_add_organizations", "true") - params.Set("codex_cli_simplified_flow", "true") + if codexFlow { + params.Set("codex_cli_simplified_flow", "true") + } return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) } +// OAuthClientConfigByPlatform returns oauth client_id and whether codex simplified flow should be enabled. +// Sora 授权流程复用 Codex CLI 的 client_id(支持 localhost redirect_uri), +// 但不启用 codex_cli_simplified_flow;拿到的 access_token 绑定同一 OpenAI 账号,对 Sora API 同样可用。 +func OAuthClientConfigByPlatform(platform string) (clientID string, codexFlow bool) { + switch strings.ToLower(strings.TrimSpace(platform)) { + case OAuthPlatformSora: + return ClientID, false + default: + return ClientID, true + } +} + // TokenRequest represents the token exchange request body type TokenRequest struct { GrantType string `json:"grant_type"` @@ -296,9 +325,11 @@ func (r *RefreshTokenRequest) ToFormData() string { return params.Encode() } -// ParseIDToken parses the ID Token JWT and extracts claims -// Note: This does NOT verify the signature - it only decodes the payload -// For production, you should verify the token signature using OpenAI's public keys +// ParseIDToken parses the ID Token JWT and extracts claims. 
+// 注意:当前仅解码 payload 并校验 exp,未验证 JWT 签名。 +// 生产环境如需用 ID Token 做授权决策,应通过 OpenAI 的 JWKS 端点验证签名: +// +// https://auth.openai.com/.well-known/jwks.json func ParseIDToken(idToken string) (*IDTokenClaims, error) { parts := strings.Split(idToken, ".") if len(parts) != 3 { @@ -329,6 +360,13 @@ func ParseIDToken(idToken string) (*IDTokenClaims, error) { return nil, fmt.Errorf("failed to parse JWT claims: %w", err) } + // 校验 ID Token 是否已过期(允许 2 分钟时钟偏差,防止因服务器时钟略有差异误判刚颁发的令牌) + const clockSkewTolerance = 120 // 秒 + now := time.Now().Unix() + if claims.Exp > 0 && now > claims.Exp+clockSkewTolerance { + return nil, fmt.Errorf("id_token has expired (exp: %d, now: %d, skew_tolerance: %ds)", claims.Exp, now, clockSkewTolerance) + } + return &claims, nil } diff --git a/backend/internal/pkg/openai/oauth_test.go b/backend/internal/pkg/openai/oauth_test.go index f1d616a6..2970addf 100644 --- a/backend/internal/pkg/openai/oauth_test.go +++ b/backend/internal/pkg/openai/oauth_test.go @@ -1,6 +1,7 @@ package openai import ( + "net/url" "sync" "testing" "time" @@ -41,3 +42,41 @@ func TestSessionStore_Stop_Concurrent(t *testing.T) { t.Fatal("stopCh 未关闭") } } + +func TestBuildAuthorizationURLForPlatform_OpenAI(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-1", "challenge-1", DefaultRedirectURI, OAuthPlatformOpenAI) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "true" { + t.Fatalf("codex flow mismatch: got=%q want=true", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} + +// TestBuildAuthorizationURLForPlatform_Sora 验证 Sora 平台复用 Codex CLI 的 client_id, +// 但不启用 codex_cli_simplified_flow。 +func 
TestBuildAuthorizationURLForPlatform_Sora(t *testing.T) { + authURL := BuildAuthorizationURLForPlatform("state-2", "challenge-2", DefaultRedirectURI, OAuthPlatformSora) + parsed, err := url.Parse(authURL) + if err != nil { + t.Fatalf("Parse URL failed: %v", err) + } + q := parsed.Query() + if got := q.Get("client_id"); got != ClientID { + t.Fatalf("client_id mismatch: got=%q want=%q (Sora should reuse Codex CLI client_id)", got, ClientID) + } + if got := q.Get("codex_cli_simplified_flow"); got != "" { + t.Fatalf("codex flow should be empty for sora, got=%q", got) + } + if got := q.Get("id_token_add_organizations"); got != "true" { + t.Fatalf("id_token_add_organizations mismatch: got=%q want=true", got) + } +} diff --git a/backend/internal/pkg/response/response_test.go b/backend/internal/pkg/response/response_test.go index 3c12f5f4..0debce5f 100644 --- a/backend/internal/pkg/response/response_test.go +++ b/backend/internal/pkg/response/response_test.go @@ -29,10 +29,10 @@ func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, P t.Helper() // 先用 raw json 解析,因为 Data 是 any 类型 var raw struct { - Code int `json:"code"` - Message string `json:"message"` - Reason string `json:"reason,omitempty"` - Data json.RawMessage `json:"data,omitempty"` + Code int `json:"code"` + Message string `json:"message"` + Reason string `json:"reason,omitempty"` + Data json.RawMessage `json:"data,omitempty"` } require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw)) diff --git a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go index 992f8b0a..4f25a34a 100644 --- a/backend/internal/pkg/tlsfingerprint/dialer.go +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -268,8 +268,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st "cipher_suites", len(spec.CipherSuites), "extensions", len(spec.Extensions), "compression_methods", spec.CompressionMethods, - "tls_vers_max", fmt.Sprintf("0x%04x", 
spec.TLSVersMax), - "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) + "tls_vers_max", spec.TLSVersMax, + "tls_vers_min", spec.TLSVersMin) if d.profile != nil { slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) @@ -294,8 +294,8 @@ func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr st state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_socks5_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -404,8 +404,8 @@ func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr stri state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_http_proxy_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil @@ -470,8 +470,8 @@ func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net. 
// Log successful handshake details state := tlsConn.ConnectionState() slog.Debug("tls_fingerprint_handshake_success", - "version", fmt.Sprintf("0x%04x", state.Version), - "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "version", state.Version, + "cipher_suite", state.CipherSuite, "alpn", state.NegotiatedProtocol) return tlsConn, nil diff --git a/backend/internal/pkg/usagestats/usage_log_types.go b/backend/internal/pkg/usagestats/usage_log_types.go index 2f6c7fe0..5f4e13f5 100644 --- a/backend/internal/pkg/usagestats/usage_log_types.go +++ b/backend/internal/pkg/usagestats/usage_log_types.go @@ -139,6 +139,7 @@ type UsageLogFilters struct { AccountID int64 GroupID int64 Model string + RequestType *int16 Stream *bool BillingType *int8 StartTime *time.Time diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 3f77a57e..4aa74928 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -50,11 +50,6 @@ type accountRepository struct { schedulerCache service.SchedulerCache } -type tempUnschedSnapshot struct { - until *time.Time - reason string -} - // NewAccountRepository 创建账户仓储实例。 // 这是对外暴露的构造函数,返回接口类型以便于依赖注入。 func NewAccountRepository(client *dbent.Client, sqlDB *sql.DB, schedulerCache service.SchedulerCache) service.AccountRepository { @@ -189,11 +184,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi accountIDs = append(accountIDs, acc.ID) } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } - groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -220,10 +210,6 @@ func (r *accountRepository) GetByIDs(ctx context.Context, ids []int64) ([]*servi if ags, ok := accountGroupsByAccount[entAcc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[entAcc.ID]; ok { - 
out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outByID[entAcc.ID] = out } @@ -611,6 +597,43 @@ func (r *accountRepository) syncSchedulerAccountSnapshot(ctx context.Context, ac } } +func (r *accountRepository) syncSchedulerAccountSnapshots(ctx context.Context, accountIDs []int64) { + if r == nil || r.schedulerCache == nil || len(accountIDs) == 0 { + return + } + + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, id := range accountIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + } + if len(uniqueIDs) == 0 { + return + } + + accounts, err := r.GetByIDs(ctx, uniqueIDs) + if err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot read failed: count=%d err=%v", len(uniqueIDs), err) + return + } + + for _, account := range accounts { + if account == nil { + continue + } + if err := r.schedulerCache.SetAccount(ctx, account); err != nil { + logger.LegacyPrintf("repository.account", "[Scheduler] batch sync account snapshot write failed: id=%d err=%v", account.ID, err) + } + } +} + func (r *accountRepository) ClearError(ctx context.Context, id int64) error { _, err := r.client.Account.Update(). Where(dbaccount.IDEQ(id)). 
@@ -1197,9 +1220,7 @@ func (r *accountRepository) BulkUpdate(ctx context.Context, ids []int64, updates shouldSync = true } if shouldSync { - for _, id := range ids { - r.syncSchedulerAccountSnapshot(ctx, id) - } + r.syncSchedulerAccountSnapshots(ctx, ids) } } return rows, nil @@ -1291,10 +1312,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if err != nil { return nil, err } - tempUnschedMap, err := r.loadTempUnschedStates(ctx, accountIDs) - if err != nil { - return nil, err - } groupsByAccount, groupIDsByAccount, accountGroupsByAccount, err := r.loadAccountGroups(ctx, accountIDs) if err != nil { return nil, err @@ -1320,10 +1337,6 @@ func (r *accountRepository) accountsToService(ctx context.Context, accounts []*d if ags, ok := accountGroupsByAccount[acc.ID]; ok { out.AccountGroups = ags } - if snap, ok := tempUnschedMap[acc.ID]; ok { - out.TempUnschedulableUntil = snap.until - out.TempUnschedulableReason = snap.reason - } outAccounts = append(outAccounts, *out) } @@ -1348,48 +1361,6 @@ func notExpiredPredicate(now time.Time) dbpredicate.Account { ) } -func (r *accountRepository) loadTempUnschedStates(ctx context.Context, accountIDs []int64) (map[int64]tempUnschedSnapshot, error) { - out := make(map[int64]tempUnschedSnapshot) - if len(accountIDs) == 0 { - return out, nil - } - - rows, err := r.sql.QueryContext(ctx, ` - SELECT id, temp_unschedulable_until, temp_unschedulable_reason - FROM accounts - WHERE id = ANY($1) - `, pq.Array(accountIDs)) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - for rows.Next() { - var id int64 - var until sql.NullTime - var reason sql.NullString - if err := rows.Scan(&id, &until, &reason); err != nil { - return nil, err - } - var untilPtr *time.Time - if until.Valid { - tmp := until.Time - untilPtr = &tmp - } - if reason.Valid { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: reason.String} - } else { - out[id] = tempUnschedSnapshot{until: untilPtr, reason: 
""} - } - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return out, nil -} - func (r *accountRepository) loadProxies(ctx context.Context, proxyIDs []int64) (map[int64]*service.Proxy, error) { proxyMap := make(map[int64]*service.Proxy) if len(proxyIDs) == 0 { @@ -1500,31 +1471,33 @@ func accountEntityToService(m *dbent.Account) *service.Account { rateMultiplier := m.RateMultiplier return &service.Account{ - ID: m.ID, - Name: m.Name, - Notes: m.Notes, - Platform: m.Platform, - Type: m.Type, - Credentials: copyJSONMap(m.Credentials), - Extra: copyJSONMap(m.Extra), - ProxyID: m.ProxyID, - Concurrency: m.Concurrency, - Priority: m.Priority, - RateMultiplier: &rateMultiplier, - Status: m.Status, - ErrorMessage: derefString(m.ErrorMessage), - LastUsedAt: m.LastUsedAt, - ExpiresAt: m.ExpiresAt, - AutoPauseOnExpired: m.AutoPauseOnExpired, - CreatedAt: m.CreatedAt, - UpdatedAt: m.UpdatedAt, - Schedulable: m.Schedulable, - RateLimitedAt: m.RateLimitedAt, - RateLimitResetAt: m.RateLimitResetAt, - OverloadUntil: m.OverloadUntil, - SessionWindowStart: m.SessionWindowStart, - SessionWindowEnd: m.SessionWindowEnd, - SessionWindowStatus: derefString(m.SessionWindowStatus), + ID: m.ID, + Name: m.Name, + Notes: m.Notes, + Platform: m.Platform, + Type: m.Type, + Credentials: copyJSONMap(m.Credentials), + Extra: copyJSONMap(m.Extra), + ProxyID: m.ProxyID, + Concurrency: m.Concurrency, + Priority: m.Priority, + RateMultiplier: &rateMultiplier, + Status: m.Status, + ErrorMessage: derefString(m.ErrorMessage), + LastUsedAt: m.LastUsedAt, + ExpiresAt: m.ExpiresAt, + AutoPauseOnExpired: m.AutoPauseOnExpired, + CreatedAt: m.CreatedAt, + UpdatedAt: m.UpdatedAt, + Schedulable: m.Schedulable, + RateLimitedAt: m.RateLimitedAt, + RateLimitResetAt: m.RateLimitResetAt, + OverloadUntil: m.OverloadUntil, + TempUnschedulableUntil: m.TempUnschedulableUntil, + TempUnschedulableReason: derefString(m.TempUnschedulableReason), + SessionWindowStart: m.SessionWindowStart, + 
SessionWindowEnd: m.SessionWindowEnd, + SessionWindowStatus: derefString(m.SessionWindowStatus), } } diff --git a/backend/internal/repository/account_repo_integration_test.go b/backend/internal/repository/account_repo_integration_test.go index 4f9d0152..fd48a5d4 100644 --- a/backend/internal/repository/account_repo_integration_test.go +++ b/backend/internal/repository/account_repo_integration_test.go @@ -500,6 +500,38 @@ func (s *AccountRepoSuite) TestClearRateLimit() { s.Require().Nil(got.OverloadUntil) } +func (s *AccountRepoSuite) TestTempUnschedulableFieldsLoadedByGetByIDAndGetByIDs() { + acc1 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-1"}) + acc2 := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-temp-2"}) + + until := time.Now().Add(15 * time.Minute).UTC().Truncate(time.Second) + reason := `{"rule":"429","matched_keyword":"too many requests"}` + s.Require().NoError(s.repo.SetTempUnschedulable(s.ctx, acc1.ID, until, reason)) + + gotByID, err := s.repo.GetByID(s.ctx, acc1.ID) + s.Require().NoError(err) + s.Require().NotNil(gotByID.TempUnschedulableUntil) + s.Require().WithinDuration(until, *gotByID.TempUnschedulableUntil, time.Second) + s.Require().Equal(reason, gotByID.TempUnschedulableReason) + + gotByIDs, err := s.repo.GetByIDs(s.ctx, []int64{acc2.ID, acc1.ID}) + s.Require().NoError(err) + s.Require().Len(gotByIDs, 2) + s.Require().Equal(acc2.ID, gotByIDs[0].ID) + s.Require().Nil(gotByIDs[0].TempUnschedulableUntil) + s.Require().Equal("", gotByIDs[0].TempUnschedulableReason) + s.Require().Equal(acc1.ID, gotByIDs[1].ID) + s.Require().NotNil(gotByIDs[1].TempUnschedulableUntil) + s.Require().WithinDuration(until, *gotByIDs[1].TempUnschedulableUntil, time.Second) + s.Require().Equal(reason, gotByIDs[1].TempUnschedulableReason) + + s.Require().NoError(s.repo.ClearTempUnschedulable(s.ctx, acc1.ID)) + cleared, err := s.repo.GetByID(s.ctx, acc1.ID) + s.Require().NoError(err) + s.Require().Nil(cleared.TempUnschedulableUntil) 
+ s.Require().Equal("", cleared.TempUnschedulableReason) +} + // --- UpdateLastUsed --- func (s *AccountRepoSuite) TestUpdateLastUsed() { diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index cdccd4fc..a9faf388 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -445,20 +445,22 @@ func userEntityToService(u *dbent.User) *service.User { return nil } return &service.User{ - ID: u.ID, - Email: u.Email, - Username: u.Username, - Notes: u.Notes, - PasswordHash: u.PasswordHash, - Role: u.Role, - Balance: u.Balance, - Concurrency: u.Concurrency, - Status: u.Status, - TotpSecretEncrypted: u.TotpSecretEncrypted, - TotpEnabled: u.TotpEnabled, - TotpEnabledAt: u.TotpEnabledAt, - CreatedAt: u.CreatedAt, - UpdatedAt: u.UpdatedAt, + ID: u.ID, + Email: u.Email, + Username: u.Username, + Notes: u.Notes, + PasswordHash: u.PasswordHash, + Role: u.Role, + Balance: u.Balance, + Concurrency: u.Concurrency, + Status: u.Status, + SoraStorageQuotaBytes: u.SoraStorageQuotaBytes, + SoraStorageUsedBytes: u.SoraStorageUsedBytes, + TotpSecretEncrypted: u.TotpSecretEncrypted, + TotpEnabled: u.TotpEnabled, + TotpEnabledAt: u.TotpEnabledAt, + CreatedAt: u.CreatedAt, + UpdatedAt: u.UpdatedAt, } } @@ -486,6 +488,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { SoraImagePrice540: g.SoraImagePrice540, SoraVideoPricePerRequest: g.SoraVideoPricePerRequest, SoraVideoPricePerRequestHD: g.SoraVideoPricePerRequestHd, + SoraStorageQuotaBytes: g.SoraStorageQuotaBytes, DefaultValidityDays: g.DefaultValidityDays, ClaudeCodeOnly: g.ClaudeCodeOnly, FallbackGroupID: g.FallbackGroupID, diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index e047bff0..a2552715 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -227,6 +227,43 @@ func (c *concurrencyCache) 
GetAccountConcurrency(ctx context.Context, accountID return result, nil } +func (c *concurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + if len(accountIDs) == 0 { + return map[int64]int{}, nil + } + + now, err := c.rdb.Time(ctx).Result() + if err != nil { + return nil, fmt.Errorf("redis TIME: %w", err) + } + cutoffTime := now.Unix() - int64(c.slotTTLSeconds) + + pipe := c.rdb.Pipeline() + type accountCmd struct { + accountID int64 + zcardCmd *redis.IntCmd + } + cmds := make([]accountCmd, 0, len(accountIDs)) + for _, accountID := range accountIDs { + slotKey := accountSlotKeyPrefix + strconv.FormatInt(accountID, 10) + pipe.ZRemRangeByScore(ctx, slotKey, "-inf", strconv.FormatInt(cutoffTime, 10)) + cmds = append(cmds, accountCmd{ + accountID: accountID, + zcardCmd: pipe.ZCard(ctx, slotKey), + }) + } + + if _, err := pipe.Exec(ctx); err != nil && !errors.Is(err, redis.Nil) { + return nil, fmt.Errorf("pipeline exec: %w", err) + } + + result := make(map[int64]int, len(accountIDs)) + for _, cmd := range cmds { + result[cmd.accountID] = int(cmd.zcardCmd.Val()) + } + return result, nil +} + // User slot operations func (c *concurrencyCache) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { diff --git a/backend/internal/repository/gateway_cache_integration_test.go b/backend/internal/repository/gateway_cache_integration_test.go index 2fdaa3d1..0eebc33f 100644 --- a/backend/internal/repository/gateway_cache_integration_test.go +++ b/backend/internal/repository/gateway_cache_integration_test.go @@ -104,7 +104,6 @@ func (s *GatewayCacheSuite) TestGetSessionAccountID_CorruptedValue() { require.False(s.T(), errors.Is(err, redis.Nil), "expected parsing error, not redis.Nil") } - func TestGatewayCacheSuite(t *testing.T) { suite.Run(t, new(GatewayCacheSuite)) } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index 
fd239996..e9b4902a 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -4,6 +4,8 @@ import ( "context" "database/sql" "errors" + "fmt" + "strings" dbent "github.com/Wei-Shaw/sub2api/ent" "github.com/Wei-Shaw/sub2api/ent/apikey" @@ -56,7 +58,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetNillableFallbackGroupID(groupIn.FallbackGroupID). SetNillableFallbackGroupIDOnInvalidRequest(groupIn.FallbackGroupIDOnInvalidRequest). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). + SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -121,7 +124,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetDefaultValidityDays(groupIn.DefaultValidityDays). SetClaudeCodeOnly(groupIn.ClaudeCodeOnly). SetModelRoutingEnabled(groupIn.ModelRoutingEnabled). - SetMcpXMLInject(groupIn.MCPXMLInject) + SetMcpXMLInject(groupIn.MCPXMLInject). 
+ SetSoraStorageQuotaBytes(groupIn.SoraStorageQuotaBytes) // 处理 FallbackGroupID:nil 时清除,否则设置 if groupIn.FallbackGroupID != nil { @@ -281,6 +285,54 @@ func (r *groupRepository) ExistsByName(ctx context.Context, name string) (bool, return r.client.Group.Query().Where(group.NameEQ(name)).Exist(ctx) } +// ExistsByIDs 批量检查分组是否存在(仅检查未软删除记录)。 +// 返回结构:map[groupID]exists。 +func (r *groupRepository) ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) { + result := make(map[int64]bool, len(ids)) + if len(ids) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(ids)) + seen := make(map[int64]struct{}, len(ids)) + for _, id := range ids { + if id <= 0 { + continue + } + if _, ok := seen[id]; ok { + continue + } + seen[id] = struct{}{} + uniqueIDs = append(uniqueIDs, id) + result[id] = false + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT id + FROM groups + WHERE id = ANY($1) AND deleted_at IS NULL + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var id int64 + if err := rows.Scan(&id); err != nil { + return nil, err + } + result[id] = true + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + func (r *groupRepository) GetAccountCount(ctx context.Context, groupID int64) (int64, error) { var count int64 if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM account_groups WHERE group_id = $1", []any{groupID}, &count); err != nil { @@ -512,22 +564,72 @@ func (r *groupRepository) UpdateSortOrders(ctx context.Context, updates []servic return nil } - // 使用事务批量更新 - tx, err := r.client.Tx(ctx) + // 去重后保留最后一次排序值,避免重复 ID 造成 CASE 分支冲突。 + sortOrderByID := make(map[int64]int, len(updates)) + groupIDs := make([]int64, 0, len(updates)) + for _, u := range updates { + if u.ID <= 0 { + continue + } + if _, exists := sortOrderByID[u.ID]; !exists { + groupIDs = append(groupIDs, u.ID) + 
} + sortOrderByID[u.ID] = u.SortOrder + } + if len(groupIDs) == 0 { + return nil + } + + // 与旧实现保持一致:任何不存在/已删除的分组都返回 not found,且不执行更新。 + var existingCount int + if err := scanSingleRow( + ctx, + r.sql, + `SELECT COUNT(*) FROM groups WHERE deleted_at IS NULL AND id = ANY($1)`, + []any{pq.Array(groupIDs)}, + &existingCount, + ); err != nil { + return err + } + if existingCount != len(groupIDs) { + return service.ErrGroupNotFound + } + + args := make([]any, 0, len(groupIDs)*2+1) + caseClauses := make([]string, 0, len(groupIDs)) + placeholder := 1 + for _, id := range groupIDs { + caseClauses = append(caseClauses, fmt.Sprintf("WHEN $%d THEN $%d", placeholder, placeholder+1)) + args = append(args, id, sortOrderByID[id]) + placeholder += 2 + } + args = append(args, pq.Array(groupIDs)) + + query := fmt.Sprintf(` + UPDATE groups + SET sort_order = CASE id + %s + ELSE sort_order + END + WHERE deleted_at IS NULL AND id = ANY($%d) + `, strings.Join(caseClauses, "\n\t\t\t"), placeholder) + + result, err := r.sql.ExecContext(ctx, query, args...) 
if err != nil { return err } - defer func() { _ = tx.Rollback() }() - - for _, u := range updates { - if _, err := tx.Group.UpdateOneID(u.ID).SetSortOrder(u.SortOrder).Save(ctx); err != nil { - return translatePersistenceError(err, service.ErrGroupNotFound, nil) - } - } - - if err := tx.Commit(); err != nil { + affected, err := result.RowsAffected() + if err != nil { return err } + if affected != int64(len(groupIDs)) { + return service.ErrGroupNotFound + } + for _, id := range groupIDs { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventGroupChanged, nil, &id, nil); err != nil { + logger.LegacyPrintf("repository.group", "[SchedulerOutbox] enqueue group sort update failed: group=%d err=%v", id, err) + } + } return nil } diff --git a/backend/internal/repository/group_repo_integration_test.go b/backend/internal/repository/group_repo_integration_test.go index c31a9ec4..4a849a46 100644 --- a/backend/internal/repository/group_repo_integration_test.go +++ b/backend/internal/repository/group_repo_integration_test.go @@ -352,6 +352,81 @@ func (s *GroupRepoSuite) TestListWithFilters_Search() { }) } +func (s *GroupRepoSuite) TestUpdateSortOrders_BatchCaseWhen() { + g1 := &service.Group{ + Name: "sort-g1", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g2 := &service.Group{ + Name: "sort-g2", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + g3 := &service.Group{ + Name: "sort-g3", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + s.Require().NoError(s.repo.Create(s.ctx, g2)) + s.Require().NoError(s.repo.Create(s.ctx, g3)) + + err 
:= s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 30}, + {ID: g2.ID, SortOrder: 10}, + {ID: g3.ID, SortOrder: 20}, + {ID: g2.ID, SortOrder: 15}, // 重复 ID 应以最后一次为准 + }) + s.Require().NoError(err) + + got1, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + got2, err := s.repo.GetByID(s.ctx, g2.ID) + s.Require().NoError(err) + got3, err := s.repo.GetByID(s.ctx, g3.ID) + s.Require().NoError(err) + s.Require().Equal(30, got1.SortOrder) + s.Require().Equal(15, got2.SortOrder) + s.Require().Equal(20, got3.SortOrder) +} + +func (s *GroupRepoSuite) TestUpdateSortOrders_MissingGroupNoPartialUpdate() { + g1 := &service.Group{ + Name: "sort-no-partial", + Platform: service.PlatformAnthropic, + RateMultiplier: 1.0, + IsExclusive: false, + Status: service.StatusActive, + SubscriptionType: service.SubscriptionTypeStandard, + } + s.Require().NoError(s.repo.Create(s.ctx, g1)) + + before, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + beforeSort := before.SortOrder + + err = s.repo.UpdateSortOrders(s.ctx, []service.GroupSortOrderUpdate{ + {ID: g1.ID, SortOrder: 99}, + {ID: 99999999, SortOrder: 1}, + }) + s.Require().Error(err) + s.Require().ErrorIs(err, service.ErrGroupNotFound) + + after, err := s.repo.GetByID(s.ctx, g1.ID) + s.Require().NoError(err) + s.Require().Equal(beforeSort, after.SortOrder) +} + func (s *GroupRepoSuite) TestListWithFilters_AccountCount() { g1 := &service.Group{ Name: "g1", diff --git a/backend/internal/repository/idempotency_repo_integration_test.go b/backend/internal/repository/idempotency_repo_integration_test.go index 23b52726..f163c2f0 100644 --- a/backend/internal/repository/idempotency_repo_integration_test.go +++ b/backend/internal/repository/idempotency_repo_integration_test.go @@ -147,4 +147,3 @@ func TestIdempotencyRepo_StatusTransition_ToSucceeded(t *testing.T) { require.Equal(t, `{"ok":true}`, *got.ResponseBody) require.Nil(t, got.LockedUntil) } - diff --git 
a/backend/internal/repository/migrations_runner.go b/backend/internal/repository/migrations_runner.go index 5912e50f..a60ba294 100644 --- a/backend/internal/repository/migrations_runner.go +++ b/backend/internal/repository/migrations_runner.go @@ -50,6 +50,23 @@ CREATE TABLE IF NOT EXISTS atlas_schema_revisions ( // 任何稳定的 int64 值都可以,只要不与同一数据库中的其他锁冲突即可。 const migrationsAdvisoryLockID int64 = 694208311321144027 const migrationsLockRetryInterval = 500 * time.Millisecond +const nonTransactionalMigrationSuffix = "_notx.sql" + +type migrationChecksumCompatibilityRule struct { + fileChecksum string + acceptedDBChecksum map[string]struct{} +} + +// migrationChecksumCompatibilityRules 仅用于兼容历史上误修改过的迁移文件 checksum。 +// 规则必须同时匹配「迁移名 + 当前文件 checksum + 历史库 checksum」才会放行,避免放宽全局校验。 +var migrationChecksumCompatibilityRules = map[string]migrationChecksumCompatibilityRule{ + "054_drop_legacy_cache_columns.sql": { + fileChecksum: "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + acceptedDBChecksum: map[string]struct{}{ + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4": {}, + }, + }, +} // ApplyMigrations 将嵌入的 SQL 迁移文件应用到指定的数据库。 // @@ -147,6 +164,10 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { if rowErr == nil { // 迁移已应用,验证校验和是否匹配 if existing != checksum { + // 兼容特定历史误改场景(仅白名单规则),其余仍保持严格不可变约束。 + if isMigrationChecksumCompatible(name, existing, checksum) { + continue + } // 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。 // 正确的做法是创建新的迁移文件来进行变更。 return fmt.Errorf( @@ -165,8 +186,34 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error { return fmt.Errorf("check migration %s: %w", name, rowErr) } - // 迁移未应用,在事务中执行。 - // 使用事务确保迁移的原子性:要么完全成功,要么完全回滚。 + nonTx, err := validateMigrationExecutionMode(name, content) + if err != nil { + return fmt.Errorf("validate migration %s: %w", name, err) + } + + if nonTx { + // *_notx.sql:用于 CREATE/DROP INDEX CONCURRENTLY 场景,必须非事务执行。 + // 逐条语句执行,避免将多条 CONCURRENTLY 语句放入同一个隐式事务块。 + 
statements := splitSQLStatements(content) + for i, stmt := range statements { + trimmed := strings.TrimSpace(stmt) + if trimmed == "" { + continue + } + if stripSQLLineComment(trimmed) == "" { + continue + } + if _, err := db.ExecContext(ctx, trimmed); err != nil { + return fmt.Errorf("apply migration %s (non-tx statement %d): %w", name, i+1, err) + } + } + if _, err := db.ExecContext(ctx, "INSERT INTO schema_migrations (filename, checksum) VALUES ($1, $2)", name, checksum); err != nil { + return fmt.Errorf("record migration %s (non-tx): %w", name, err) + } + continue + } + + // 默认迁移在事务中执行,确保原子性:要么完全成功,要么完全回滚。 tx, err := db.BeginTx(ctx, nil) if err != nil { return fmt.Errorf("begin migration %s: %w", name, err) @@ -268,6 +315,84 @@ func latestMigrationBaseline(fsys fs.FS) (string, string, string, error) { return version, version, hash, nil } +func isMigrationChecksumCompatible(name, dbChecksum, fileChecksum string) bool { + rule, ok := migrationChecksumCompatibilityRules[name] + if !ok { + return false + } + if rule.fileChecksum != fileChecksum { + return false + } + _, ok = rule.acceptedDBChecksum[dbChecksum] + return ok +} + +func validateMigrationExecutionMode(name, content string) (bool, error) { + normalizedName := strings.ToLower(strings.TrimSpace(name)) + upperContent := strings.ToUpper(content) + nonTx := strings.HasSuffix(normalizedName, nonTransactionalMigrationSuffix) + + if !nonTx { + if strings.Contains(upperContent, "CONCURRENTLY") { + return false, errors.New("CONCURRENTLY statements must be placed in *_notx.sql migrations") + } + return false, nil + } + + if strings.Contains(upperContent, "BEGIN") || strings.Contains(upperContent, "COMMIT") || strings.Contains(upperContent, "ROLLBACK") { + return false, errors.New("*_notx.sql must not contain transaction control statements (BEGIN/COMMIT/ROLLBACK)") + } + + statements := splitSQLStatements(content) + for _, stmt := range statements { + normalizedStmt := 
strings.ToUpper(stripSQLLineComment(strings.TrimSpace(stmt))) + if normalizedStmt == "" { + continue + } + + if strings.Contains(normalizedStmt, "CONCURRENTLY") { + isCreateIndex := strings.Contains(normalizedStmt, "CREATE") && strings.Contains(normalizedStmt, "INDEX") + isDropIndex := strings.Contains(normalizedStmt, "DROP") && strings.Contains(normalizedStmt, "INDEX") + if !isCreateIndex && !isDropIndex { + return false, errors.New("*_notx.sql currently only supports CREATE/DROP INDEX CONCURRENTLY statements") + } + if isCreateIndex && !strings.Contains(normalizedStmt, "IF NOT EXISTS") { + return false, errors.New("CREATE INDEX CONCURRENTLY in *_notx.sql must include IF NOT EXISTS for idempotency") + } + if isDropIndex && !strings.Contains(normalizedStmt, "IF EXISTS") { + return false, errors.New("DROP INDEX CONCURRENTLY in *_notx.sql must include IF EXISTS for idempotency") + } + continue + } + + return false, errors.New("*_notx.sql must not mix non-CONCURRENTLY SQL statements") + } + + return true, nil +} + +func splitSQLStatements(content string) []string { + parts := strings.Split(content, ";") + out := make([]string, 0, len(parts)) + for _, part := range parts { + if strings.TrimSpace(part) == "" { + continue + } + out = append(out, part) + } + return out +} + +func stripSQLLineComment(s string) string { + lines := strings.Split(s, "\n") + for i, line := range lines { + if idx := strings.Index(line, "--"); idx >= 0 { + lines[i] = line[:idx] + } + } + return strings.TrimSpace(strings.Join(lines, "\n")) +} + // pgAdvisoryLock 获取 PostgreSQL Advisory Lock。 // Advisory Lock 是一种轻量级的锁机制,不与任何特定的数据库对象关联。 // 它非常适合用于应用层面的分布式锁场景,如迁移序列化。 diff --git a/backend/internal/repository/migrations_runner_checksum_test.go b/backend/internal/repository/migrations_runner_checksum_test.go new file mode 100644 index 00000000..54f5b0ec --- /dev/null +++ b/backend/internal/repository/migrations_runner_checksum_test.go @@ -0,0 +1,36 @@ +package repository + +import ( + "testing" + + 
"github.com/stretchr/testify/require" +) + +func TestIsMigrationChecksumCompatible(t *testing.T) { + t.Run("054历史checksum可兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.True(t, ok) + }) + + t.Run("054在未知文件checksum下不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "054_drop_legacy_cache_columns.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "0000000000000000000000000000000000000000000000000000000000000000", + ) + require.False(t, ok) + }) + + t.Run("非白名单迁移不兼容", func(t *testing.T) { + ok := isMigrationChecksumCompatible( + "001_init.sql", + "182c193f3359946cf094090cd9e57d5c3fd9abaffbc1e8fc378646b8a6fa12b4", + "82de761156e03876653e7a6a4eee883cd927847036f779b0b9f34c42a8af7a7d", + ) + require.False(t, ok) + }) +} diff --git a/backend/internal/repository/migrations_runner_extra_test.go b/backend/internal/repository/migrations_runner_extra_test.go new file mode 100644 index 00000000..9f8a94c6 --- /dev/null +++ b/backend/internal/repository/migrations_runner_extra_test.go @@ -0,0 +1,368 @@ +package repository + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "errors" + "io/fs" + "strings" + "testing" + "testing/fstest" + "time" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestApplyMigrations_NilDB(t *testing.T) { + err := ApplyMigrations(context.Background(), nil) + require.Error(t, err) + require.Contains(t, err.Error(), "nil sql db") +} + +func TestApplyMigrations_DelegatesToApplyMigrationsFS(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnError(errors.New("lock failed")) + + err = ApplyMigrations(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestLatestMigrationBaseline(t *testing.T) { + t.Run("empty_fs_returns_baseline", func(t *testing.T) { + version, description, hash, err := latestMigrationBaseline(fstest.MapFS{}) + require.NoError(t, err) + require.Equal(t, "baseline", version) + require.Equal(t, "baseline", description) + require.Equal(t, "", hash) + }) + + t.Run("uses_latest_sorted_sql_file", func(t *testing.T) { + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "010_final.sql": &fstest.MapFile{ + Data: []byte("CREATE TABLE t2(id int);"), + }, + } + version, description, hash, err := latestMigrationBaseline(fsys) + require.NoError(t, err) + require.Equal(t, "010_final", version) + require.Equal(t, "010_final", description) + require.Len(t, hash, 64) + }) + + t.Run("read_file_error", func(t *testing.T) { + fsys := fstest.MapFS{ + "010_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + _, _, _, err := latestMigrationBaseline(fsys) + require.Error(t, err) + }) +} + +func TestIsMigrationChecksumCompatible_AdditionalCases(t *testing.T) { + require.False(t, isMigrationChecksumCompatible("unknown.sql", "db", "file")) + + var ( + name string + rule migrationChecksumCompatibilityRule + ) + for n, r := range migrationChecksumCompatibilityRules { + name = n + rule = r + break + } + require.NotEmpty(t, name) + + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", "file-not-match")) + require.False(t, isMigrationChecksumCompatible(name, "db-not-accepted", rule.fileChecksum)) + + var accepted string + for checksum := range rule.acceptedDBChecksum { + accepted = checksum + break + } + require.NotEmpty(t, accepted) + require.True(t, isMigrationChecksumCompatible(name, accepted, rule.fileChecksum)) +} + 
+func TestEnsureAtlasBaselineAligned(t *testing.T) { + t.Run("skip_when_no_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("create_atlas_and_insert_baseline_when_empty", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("002_next", "002_next", 1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t1(id int);")}, + "002_next.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t2(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_checking_legacy_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). 
+ WillReturnError(errors.New("exists failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "check schema_migrations") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_counting_atlas_rows", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnError(errors.New("count failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "count atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_creating_atlas_table", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(false)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS atlas_schema_revisions"). 
+ WillReturnError(errors.New("create failed")) + + err = ensureAtlasBaselineAligned(context.Background(), db, fstest.MapFS{}) + require.Error(t, err) + require.Contains(t, err.Error(), "create atlas_schema_revisions") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("error_when_inserting_baseline", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + mock.ExpectExec("INSERT INTO atlas_schema_revisions"). + WithArgs("001_init", "001_init", 1, sqlmock.AnyArg()). + WillReturnError(errors.New("insert failed")) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = ensureAtlasBaselineAligned(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "insert atlas baseline") + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func TestApplyMigrationsFS_ChecksumMismatchRejected(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_init.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow("mismatched-checksum")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_init.sql": &fstest.MapFile{Data: []byte("CREATE TABLE t(id int);")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "checksum mismatch") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_CheckMigrationQueryError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_err.sql"). + WillReturnError(errors.New("query failed")) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_err.sql": &fstest.MapFile{Data: []byte("SELECT 1;")}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "check migration 001_err.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_SkipEmptyAndAlreadyApplied(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + + alreadySQL := "CREATE TABLE t(id int);" + checksum := migrationChecksum(alreadySQL) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_already.sql"). + WillReturnRows(sqlmock.NewRows([]string{"checksum"}).AddRow(checksum)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "000_empty.sql": &fstest.MapFile{Data: []byte(" \n\t ")}, + "001_already.sql": &fstest.MapFile{Data: []byte(alreadySQL)}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_ReadMigrationError(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_bad.sql": &fstest.MapFile{Mode: fs.ModeDir}, + } + err = applyMigrationsFS(context.Background(), db, fsys) + require.Error(t, err) + require.Contains(t, err.Error(), "read migration 001_bad.sql") + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestPgAdvisoryLockAndUnlock_ErrorBranches(t *testing.T) { + t.Run("context_cancelled_while_not_locked", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Millisecond) + defer cancel() + err = pgAdvisoryLock(ctx, db) + require.Error(t, err) + require.Contains(t, err.Error(), "acquire migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("unlock_exec_error", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnError(errors.New("unlock failed")) + + err = pgAdvisoryUnlock(context.Background(), db) + require.Error(t, err) + require.Contains(t, err.Error(), "release migrations lock") + require.NoError(t, mock.ExpectationsWereMet()) + }) + + t.Run("acquire_lock_after_retry", func(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(false)) + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + + ctx, cancel := context.WithTimeout(context.Background(), migrationsLockRetryInterval*3) + defer cancel() + start := time.Now() + err = pgAdvisoryLock(ctx, db) + require.NoError(t, err) + require.GreaterOrEqual(t, time.Since(start), migrationsLockRetryInterval) + require.NoError(t, mock.ExpectationsWereMet()) + }) +} + +func migrationChecksum(content string) string { + sum := sha256.Sum256([]byte(strings.TrimSpace(content))) + return hex.EncodeToString(sum[:]) +} diff --git a/backend/internal/repository/migrations_runner_notx_test.go b/backend/internal/repository/migrations_runner_notx_test.go new file mode 100644 index 00000000..db1183cd --- /dev/null +++ b/backend/internal/repository/migrations_runner_notx_test.go @@ -0,0 +1,164 @@ +package repository + +import ( + "context" + "database/sql" + "testing" + "testing/fstest" + + sqlmock "github.com/DATA-DOG/go-sqlmock" + "github.com/stretchr/testify/require" +) + +func TestValidateMigrationExecutionMode(t *testing.T) { + t.Run("事务迁移包含CONCURRENTLY会被拒绝", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求CREATE使用IF NOT 
EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY idx_a ON t(a);") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移要求DROP使用IF EXISTS", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_drop_idx_notx.sql", "DROP INDEX CONCURRENTLY idx_a;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止事务控制语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "BEGIN; CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); COMMIT;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移禁止混用非CONCURRENTLY语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", "CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); UPDATE t SET a = 1;") + require.False(t, nonTx) + require.Error(t, err) + }) + + t.Run("notx迁移允许幂等并发索引语句", func(t *testing.T) { + nonTx, err := validateMigrationExecutionMode("001_add_idx_notx.sql", ` +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_a ON t(a); +DROP INDEX CONCURRENTLY IF EXISTS idx_b; +`) + require.True(t, nonTx) + require.NoError(t, err) + }) +} + +func TestApplyMigrationsFS_NonTransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). 
+ WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_idx_notx.sql": &fstest.MapFile{ + Data: []byte("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a);"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_NonTransactionalMigration_MultiStatements(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_multi_idx_notx.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t\\(a\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t\\(b\\)"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_multi_idx_notx.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). 
+ WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_multi_idx_notx.sql": &fstest.MapFile{ + Data: []byte(` +-- first +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_a ON t(a); +-- second +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_t_b ON t(b); +`), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestApplyMigrationsFS_TransactionalMigration(t *testing.T) { + db, mock, err := sqlmock.New() + require.NoError(t, err) + defer func() { _ = db.Close() }() + + prepareMigrationsBootstrapExpectations(mock) + mock.ExpectQuery("SELECT checksum FROM schema_migrations WHERE filename = \\$1"). + WithArgs("001_add_col.sql"). + WillReturnError(sql.ErrNoRows) + mock.ExpectBegin() + mock.ExpectExec("ALTER TABLE t ADD COLUMN name TEXT"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec("INSERT INTO schema_migrations \\(filename, checksum\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs("001_add_col.sql", sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + mock.ExpectCommit() + mock.ExpectExec("SELECT pg_advisory_unlock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnResult(sqlmock.NewResult(0, 1)) + + fsys := fstest.MapFS{ + "001_add_col.sql": &fstest.MapFile{ + Data: []byte("ALTER TABLE t ADD COLUMN name TEXT;"), + }, + } + + err = applyMigrationsFS(context.Background(), db, fsys) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func prepareMigrationsBootstrapExpectations(mock sqlmock.Sqlmock) { + mock.ExpectQuery("SELECT pg_try_advisory_lock\\(\\$1\\)"). + WithArgs(migrationsAdvisoryLockID). + WillReturnRows(sqlmock.NewRows([]string{"pg_try_advisory_lock"}).AddRow(true)) + mock.ExpectExec("CREATE TABLE IF NOT EXISTS schema_migrations"). + WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("schema_migrations"). 
+ WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT EXISTS \\("). + WithArgs("atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(true)) + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM atlas_schema_revisions"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) +} diff --git a/backend/internal/repository/migrations_schema_integration_test.go b/backend/internal/repository/migrations_schema_integration_test.go index f50d2b26..72422d18 100644 --- a/backend/internal/repository/migrations_schema_integration_test.go +++ b/backend/internal/repository/migrations_schema_integration_test.go @@ -42,6 +42,8 @@ func TestMigrationsRunner_IsIdempotent_AndSchemaIsUpToDate(t *testing.T) { // usage_logs: billing_type used by filters/stats requireColumn(t, tx, "usage_logs", "billing_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "request_type", "smallint", 0, false) + requireColumn(t, tx, "usage_logs", "openai_ws_mode", "boolean", 0, false) // settings table should exist var settingsRegclass sql.NullString diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index 088e7d7f..3e155971 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -22,16 +22,20 @@ type openaiOAuthService struct { tokenURL string } -func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { +func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { client := createOpenAIReqClient(proxyURL) if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID + } formData := url.Values{} 
formData.Set("grant_type", "authorization_code") - formData.Set("client_id", openai.ClientID) + formData.Set("client_id", clientID) formData.Set("code", code) formData.Set("redirect_uri", redirectURI) formData.Set("code_verifier", codeVerifier) @@ -61,36 +65,12 @@ func (s *openaiOAuthService) RefreshToken(ctx context.Context, refreshToken, pro } func (s *openaiOAuthService) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { - if strings.TrimSpace(clientID) != "" { - return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, strings.TrimSpace(clientID)) + // 调用方应始终传入正确的 client_id;为兼容旧数据,未指定时默认使用 OpenAI ClientID + clientID = strings.TrimSpace(clientID) + if clientID == "" { + clientID = openai.ClientID } - - clientIDs := []string{ - openai.ClientID, - openai.SoraClientID, - } - seen := make(map[string]struct{}, len(clientIDs)) - var lastErr error - for _, clientID := range clientIDs { - clientID = strings.TrimSpace(clientID) - if clientID == "" { - continue - } - if _, ok := seen[clientID]; ok { - continue - } - seen[clientID] = struct{}{} - - tokenResp, err := s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) - if err == nil { - return tokenResp, nil - } - lastErr = err - } - if lastErr != nil { - return nil, lastErr - } - return nil, infraerrors.New(http.StatusBadGateway, "OPENAI_OAUTH_TOKEN_REFRESH_FAILED", "token refresh failed") + return s.refreshTokenWithClientID(ctx, refreshToken, proxyURL, clientID) } func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL, clientID string) (*openai.TokenResponse, error) { diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index 5938272a..44fa291b 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -81,7 +81,7 @@ func (s 
*OpenAIOAuthServiceSuite) TestExchangeCode_DefaultRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) })) - resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "") + resp, err := s.svc.ExchangeCode(s.ctx, "code", "ver", "", "", "") require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -136,7 +136,9 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FormFields() { require.Equal(s.T(), "rt2", resp.RefreshToken) } -func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { +// TestRefreshToken_DefaultsToOpenAIClientID 验证未指定 client_id 时默认使用 OpenAI ClientID, +// 且只发送一次请求(不再盲猜多个 client_id)。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_DefaultsToOpenAIClientID() { var seenClientIDs []string s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if err := r.ParseForm(); err != nil { @@ -145,11 +147,27 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { } clientID := r.PostForm.Get("client_id") seenClientIDs = append(seenClientIDs, clientID) - if clientID == openai.ClientID { + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","refresh_token":"rt","token_type":"bearer","expires_in":3600}`) + })) + + resp, err := s.svc.RefreshToken(s.ctx, "rt", "") + require.NoError(s.T(), err, "RefreshToken") + require.Equal(s.T(), "at", resp.AccessToken) + // 只发送了一次请求,使用默认的 OpenAI ClientID + require.Equal(s.T(), []string{openai.ClientID}, seenClientIDs) +} + +// TestRefreshToken_UseSoraClientID 验证显式传入 Sora ClientID 时直接使用,不回退。 +func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseSoraClientID() { + var seenClientIDs []string + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := r.ParseForm(); err != nil { w.WriteHeader(http.StatusBadRequest) - _, _ = io.WriteString(w, "invalid_grant") return } + clientID := 
r.PostForm.Get("client_id") + seenClientIDs = append(seenClientIDs, clientID) if clientID == openai.SoraClientID { w.Header().Set("Content-Type", "application/json") _, _ = io.WriteString(w, `{"access_token":"at-sora","refresh_token":"rt-sora","token_type":"bearer","expires_in":3600}`) @@ -158,11 +176,10 @@ func (s *OpenAIOAuthServiceSuite) TestRefreshToken_FallbackToSoraClientID() { w.WriteHeader(http.StatusBadRequest) })) - resp, err := s.svc.RefreshToken(s.ctx, "rt", "") - require.NoError(s.T(), err, "RefreshToken") + resp, err := s.svc.RefreshTokenWithClientID(s.ctx, "rt", "", openai.SoraClientID) + require.NoError(s.T(), err, "RefreshTokenWithClientID") require.Equal(s.T(), "at-sora", resp.AccessToken) - require.Equal(s.T(), "rt-sora", resp.RefreshToken) - require.Equal(s.T(), []string{openai.ClientID, openai.SoraClientID}, seenClientIDs) + require.Equal(s.T(), []string{openai.SoraClientID}, seenClientIDs) } func (s *OpenAIOAuthServiceSuite) TestRefreshToken_UseProvidedClientID() { @@ -196,7 +213,7 @@ func (s *OpenAIOAuthServiceSuite) TestNonSuccessStatus_IncludesBody() { _, _ = io.WriteString(w, "bad") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "status 400") require.ErrorContains(s.T(), err, "bad") @@ -206,7 +223,7 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) s.srv.Close() - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.Error(s.T(), err) require.ErrorContains(s.T(), err, "request failed") } @@ -223,7 +240,7 @@ func (s *OpenAIOAuthServiceSuite) TestContextCancel() { done := make(chan error, 1) go func() { - _, err := 
s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(ctx, "code", "ver", openai.DefaultRedirectURI, "", "") done <- err }() @@ -249,7 +266,30 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UsesProvidedRedirectURI() { _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", want, "", "") + require.NoError(s.T(), err, "ExchangeCode") + select { + case msg := <-errCh: + require.Fail(s.T(), msg) + default: + } +} + +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_UseProvidedClientID() { + wantClientID := openai.SoraClientID + errCh := make(chan string, 1) + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _ = r.ParseForm() + if got := r.PostForm.Get("client_id"); got != wantClientID { + errCh <- "client_id mismatch" + w.WriteHeader(http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _, _ = io.WriteString(w, `{"access_token":"at","token_type":"bearer","expires_in":1}`) + })) + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", wantClientID) require.NoError(s.T(), err, "ExchangeCode") select { case msg := <-errCh: @@ -267,7 +307,7 @@ func (s *OpenAIOAuthServiceSuite) TestTokenURL_CanBeOverriddenWithQuery() { })) s.svc.tokenURL = s.srv.URL + "?x=1" - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") require.NoError(s.T(), err, "ExchangeCode") select { case <-s.received: @@ -283,7 +323,7 @@ func (s *OpenAIOAuthServiceSuite) TestExchangeCode_SuccessButInvalidJSON() { _, _ = io.WriteString(w, "not-valid-json") })) - _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "") + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", 
openai.DefaultRedirectURI, "", "") require.Error(s.T(), err, "expected error for invalid JSON response") } diff --git a/backend/internal/repository/ops_repo_dashboard.go b/backend/internal/repository/ops_repo_dashboard.go index 85791a9a..b43d6706 100644 --- a/backend/internal/repository/ops_repo_dashboard.go +++ b/backend/internal/repository/ops_repo_dashboard.go @@ -12,6 +12,11 @@ import ( "github.com/Wei-Shaw/sub2api/internal/service" ) +const ( + opsRawLatencyQueryTimeout = 2 * time.Second + opsRawPeakQueryTimeout = 1500 * time.Millisecond +) + func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { if r == nil || r.db == nil { return nil, fmt.Errorf("nil ops repository") @@ -45,15 +50,24 @@ func (r *opsRepository) GetDashboardOverview(ctx context.Context, filter *servic func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *service.OpsDashboardFilter) (*service.OpsDashboardOverview, error) { start := filter.StartTime.UTC() end := filter.EndTime.UTC() + degraded := false successCount, tokenConsumed, err := r.queryUsageCounts(ctx, filter, start, end) if err != nil { return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -75,20 +89,40 @@ func (r *opsRepository) getDashboardOverviewRaw(ctx context.Context, filter *ser qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if 
isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -230,26 +264,45 @@ func (r *opsRepository) getDashboardOverviewPreaggregated(ctx context.Context, f sla := safeDivideFloat64(float64(successCount), float64(requestCountSLA)) errorRate := safeDivideFloat64(float64(errorCountSLA), float64(requestCountSLA)) upstreamErrorRate := safeDivideFloat64(float64(upstreamExcl), float64(requestCountSLA)) + degraded := false // Keep "current" rates as raw, to preserve realtime semantics. qpsCurrent, tpsCurrent, err := r.queryCurrentRates(ctx, filter, end) if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } - // NOTE: peak still uses raw logs (minute granularity). This is typically cheaper than percentile_cont - // and keeps semantics consistent across modes. 
- qpsPeak, err := r.queryPeakQPS(ctx, filter, start, end) + peakCtx, cancelPeak := context.WithTimeout(ctx, opsRawPeakQueryTimeout) + qpsPeak, tpsPeak, err := r.queryPeakRates(peakCtx, filter, start, end) + cancelPeak() if err != nil { - return nil, err - } - tpsPeak, err := r.queryPeakTPS(ctx, filter, start, end) - if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + degraded = true + } else { + return nil, err + } } qpsAvg := roundTo1DP(float64(requestCountTotal) / windowSeconds) tpsAvg := roundTo1DP(float64(tokenConsumed) / windowSeconds) + if degraded { + if qpsCurrent <= 0 { + qpsCurrent = qpsAvg + } + if tpsCurrent <= 0 { + tpsCurrent = tpsAvg + } + if qpsPeak <= 0 { + qpsPeak = roundTo1DP(math.Max(qpsCurrent, qpsAvg)) + } + if tpsPeak <= 0 { + tpsPeak = roundTo1DP(math.Max(tpsCurrent, tpsAvg)) + } + } return &service.OpsDashboardOverview{ StartTime: start, @@ -577,9 +630,16 @@ func (r *opsRepository) queryRawPartial(ctx context.Context, filter *service.Ops return nil, err } - duration, ttft, err := r.queryUsageLatency(ctx, filter, start, end) + latencyCtx, cancelLatency := context.WithTimeout(ctx, opsRawLatencyQueryTimeout) + duration, ttft, err := r.queryUsageLatency(latencyCtx, filter, start, end) + cancelLatency() if err != nil { - return nil, err + if isQueryTimeoutErr(err) { + duration = service.OpsPercentiles{} + ttft = service.OpsPercentiles{} + } else { + return nil, err + } } errorTotal, businessLimited, errorCountSLA, upstreamExcl, upstream429, upstream529, err := r.queryErrorCounts(ctx, filter, start, end) @@ -735,68 +795,56 @@ FROM usage_logs ul } func (r *opsRepository) queryUsageLatency(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (duration service.OpsPercentiles, ttft service.OpsPercentiles, err error) { - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` + join, where, args, _ := buildUsageWhere(filter, start, end, 1) + q := ` SELECT - percentile_cont(0.50) WITHIN GROUP 
(ORDER BY duration_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) AS p99, - AVG(duration_ms) AS avg_ms, - MAX(duration_ms) AS max_ms + percentile_cont(0.50) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_p99, + AVG(duration_ms) FILTER (WHERE duration_ms IS NOT NULL) AS duration_avg, + MAX(duration_ms) AS duration_max, + percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p50, + percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p90, + percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p95, + percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_p99, + AVG(first_token_ms) FILTER (WHERE first_token_ms IS NOT NULL) AS ttft_avg, + MAX(first_token_ms) AS ttft_max FROM usage_logs ul ` + join + ` -` + where + ` -AND duration_ms IS NOT NULL` +` + where - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - duration.P50 = floatToIntPtr(p50) - duration.P90 = floatToIntPtr(p90) - duration.P95 = floatToIntPtr(p95) - duration.P99 = floatToIntPtr(p99) - duration.Avg = floatToIntPtr(avg) - if max.Valid 
{ - v := int(max.Int64) - duration.Max = &v - } + var dP50, dP90, dP95, dP99 sql.NullFloat64 + var dAvg sql.NullFloat64 + var dMax sql.NullInt64 + var tP50, tP90, tP95, tP99 sql.NullFloat64 + var tAvg sql.NullFloat64 + var tMax sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan( + &dP50, &dP90, &dP95, &dP99, &dAvg, &dMax, + &tP50, &tP90, &tP95, &tP99, &tAvg, &tMax, + ); err != nil { + return service.OpsPercentiles{}, service.OpsPercentiles{}, err } - { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - q := ` -SELECT - percentile_cont(0.50) WITHIN GROUP (ORDER BY first_token_ms) AS p50, - percentile_cont(0.90) WITHIN GROUP (ORDER BY first_token_ms) AS p90, - percentile_cont(0.95) WITHIN GROUP (ORDER BY first_token_ms) AS p95, - percentile_cont(0.99) WITHIN GROUP (ORDER BY first_token_ms) AS p99, - AVG(first_token_ms) AS avg_ms, - MAX(first_token_ms) AS max_ms -FROM usage_logs ul -` + join + ` -` + where + ` -AND first_token_ms IS NOT NULL` + duration.P50 = floatToIntPtr(dP50) + duration.P90 = floatToIntPtr(dP90) + duration.P95 = floatToIntPtr(dP95) + duration.P99 = floatToIntPtr(dP99) + duration.Avg = floatToIntPtr(dAvg) + if dMax.Valid { + v := int(dMax.Int64) + duration.Max = &v + } - var p50, p90, p95, p99 sql.NullFloat64 - var avg sql.NullFloat64 - var max sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&p50, &p90, &p95, &p99, &avg, &max); err != nil { - return service.OpsPercentiles{}, service.OpsPercentiles{}, err - } - ttft.P50 = floatToIntPtr(p50) - ttft.P90 = floatToIntPtr(p90) - ttft.P95 = floatToIntPtr(p95) - ttft.P99 = floatToIntPtr(p99) - ttft.Avg = floatToIntPtr(avg) - if max.Valid { - v := int(max.Int64) - ttft.Max = &v - } + ttft.P50 = floatToIntPtr(tP50) + ttft.P90 = floatToIntPtr(tP90) + ttft.P95 = floatToIntPtr(tP95) + ttft.P99 = floatToIntPtr(tP99) + ttft.Avg = floatToIntPtr(tAvg) + if tMax.Valid { + v := int(tMax.Int64) + ttft.Max = &v } return duration, ttft, nil @@ -854,20 +902,23 @@ func 
(r *opsRepository) queryCurrentRates(ctx context.Context, filter *service.O return qpsCurrent, tpsCurrent, nil } -func (r *opsRepository) queryPeakQPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { +func (r *opsRepository) queryPeakRates(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (qpsPeak float64, tpsPeak float64, err error) { usageJoin, usageWhere, usageArgs, next := buildUsageWhere(filter, start, end, 1) errorWhere, errorArgs, _ := buildErrorWhere(filter, start, end, next) q := ` WITH usage_buckets AS ( - SELECT date_trunc('minute', ul.created_at) AS bucket, COUNT(*) AS cnt + SELECT + date_trunc('minute', ul.created_at) AS bucket, + COUNT(*) AS req_cnt, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS token_cnt FROM usage_logs ul ` + usageJoin + ` ` + usageWhere + ` GROUP BY 1 ), error_buckets AS ( - SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS cnt + SELECT date_trunc('minute', created_at) AS bucket, COUNT(*) AS err_cnt FROM ops_error_logs ` + errorWhere + ` AND COALESCE(status_code, 0) >= 400 @@ -875,47 +926,33 @@ error_buckets AS ( ), combined AS ( SELECT COALESCE(u.bucket, e.bucket) AS bucket, - COALESCE(u.cnt, 0) + COALESCE(e.cnt, 0) AS total + COALESCE(u.req_cnt, 0) + COALESCE(e.err_cnt, 0) AS total_req, + COALESCE(u.token_cnt, 0) AS total_tokens FROM usage_buckets u FULL OUTER JOIN error_buckets e ON u.bucket = e.bucket ) -SELECT COALESCE(MAX(total), 0) FROM combined` +SELECT + COALESCE(MAX(total_req), 0) AS max_req_per_min, + COALESCE(MAX(total_tokens), 0) AS max_tokens_per_min +FROM combined` args := append(usageArgs, errorArgs...) 
- var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err + var maxReqPerMinute, maxTokensPerMinute sql.NullInt64 + if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxReqPerMinute, &maxTokensPerMinute); err != nil { + return 0, 0, err } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil + if maxReqPerMinute.Valid && maxReqPerMinute.Int64 > 0 { + qpsPeak = roundTo1DP(float64(maxReqPerMinute.Int64) / 60.0) } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil + if maxTokensPerMinute.Valid && maxTokensPerMinute.Int64 > 0 { + tpsPeak = roundTo1DP(float64(maxTokensPerMinute.Int64) / 60.0) + } + return qpsPeak, tpsPeak, nil } -func (r *opsRepository) queryPeakTPS(ctx context.Context, filter *service.OpsDashboardFilter, start, end time.Time) (float64, error) { - join, where, args, _ := buildUsageWhere(filter, start, end, 1) - - q := ` -SELECT COALESCE(MAX(tokens_per_min), 0) -FROM ( - SELECT - date_trunc('minute', ul.created_at) AS bucket, - COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) AS tokens_per_min - FROM usage_logs ul - ` + join + ` - ` + where + ` - GROUP BY 1 -) t` - - var maxPerMinute sql.NullInt64 - if err := r.db.QueryRowContext(ctx, q, args...).Scan(&maxPerMinute); err != nil { - return 0, err - } - if !maxPerMinute.Valid || maxPerMinute.Int64 <= 0 { - return 0, nil - } - return roundTo1DP(float64(maxPerMinute.Int64) / 60.0), nil +func isQueryTimeoutErr(err error) bool { + return errors.Is(err, context.DeadlineExceeded) } func buildUsageWhere(filter *service.OpsDashboardFilter, start, end time.Time, startIndex int) (join string, where string, args []any, nextIndex int) { diff --git a/backend/internal/repository/ops_repo_dashboard_timeout_test.go b/backend/internal/repository/ops_repo_dashboard_timeout_test.go new file mode 100644 index 00000000..76332ca0 --- /dev/null +++ 
b/backend/internal/repository/ops_repo_dashboard_timeout_test.go @@ -0,0 +1,22 @@ +package repository + +import ( + "context" + "fmt" + "testing" +) + +func TestIsQueryTimeoutErr(t *testing.T) { + if !isQueryTimeoutErr(context.DeadlineExceeded) { + t.Fatalf("context.DeadlineExceeded should be treated as query timeout") + } + if !isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.DeadlineExceeded)) { + t.Fatalf("wrapped context.DeadlineExceeded should be treated as query timeout") + } + if isQueryTimeoutErr(context.Canceled) { + t.Fatalf("context.Canceled should not be treated as query timeout") + } + if isQueryTimeoutErr(fmt.Errorf("wrapped: %w", context.Canceled)) { + t.Fatalf("wrapped context.Canceled should not be treated as query timeout") + } +} diff --git a/backend/internal/repository/sora_generation_repo.go b/backend/internal/repository/sora_generation_repo.go new file mode 100644 index 00000000..aaf3cb2f --- /dev/null +++ b/backend/internal/repository/sora_generation_repo.go @@ -0,0 +1,419 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +// soraGenerationRepository 实现 service.SoraGenerationRepository 接口。 +// 使用原生 SQL 操作 sora_generations 表。 +type soraGenerationRepository struct { + sql *sql.DB +} + +// NewSoraGenerationRepository 创建 Sora 生成记录仓储实例。 +func NewSoraGenerationRepository(sqlDB *sql.DB) service.SoraGenerationRepository { + return &soraGenerationRepository{sql: sqlDB} +} + +func (r *soraGenerationRepository) Create(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + err := r.sql.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, 
$5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt) + return err +} + +// CreatePendingWithLimit 在单事务内执行“并发上限检查 + 创建”,避免 count+create 竞态。 +func (r *soraGenerationRepository) CreatePendingWithLimit( + ctx context.Context, + gen *service.SoraGeneration, + activeStatuses []string, + maxActive int64, +) error { + if gen == nil { + return fmt.Errorf("generation is nil") + } + if maxActive <= 0 { + return r.Create(ctx, gen) + } + if len(activeStatuses) == 0 { + activeStatuses = []string{service.SoraGenStatusPending, service.SoraGenStatusGenerating} + } + + tx, err := r.sql.BeginTx(ctx, nil) + if err != nil { + return err + } + defer func() { _ = tx.Rollback() }() + + // 使用用户级 advisory lock 串行化并发创建,避免超限竞态。 + if _, err := tx.ExecContext(ctx, `SELECT pg_advisory_xact_lock($1)`, gen.UserID); err != nil { + return err + } + + placeholders := make([]string, len(activeStatuses)) + args := make([]any, 0, 1+len(activeStatuses)) + args = append(args, gen.UserID) + for i, s := range activeStatuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + countQuery := fmt.Sprintf( + `SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)`, + strings.Join(placeholders, ","), + ) + var activeCount int64 + if err := tx.QueryRowContext(ctx, countQuery, args...).Scan(&activeCount); err != nil { + return err + } + if activeCount >= maxActive { + return service.ErrSoraGenerationConcurrencyLimit + } + + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + if err := tx.QueryRowContext(ctx, ` + INSERT INTO sora_generations ( + user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, 
upstream_task_id, error_message + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13) + RETURNING id, created_at + `, + gen.UserID, gen.APIKeyID, gen.Model, gen.Prompt, gen.MediaType, + gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, gen.ErrorMessage, + ).Scan(&gen.ID, &gen.CreatedAt); err != nil { + return err + } + + return tx.Commit() +} + +func (r *soraGenerationRepository) GetByID(ctx context.Context, id int64) (*service.SoraGeneration, error) { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + err := r.sql.QueryRowContext(ctx, ` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, error_message, + created_at, completed_at + FROM sora_generations WHERE id = $1 + `, id).Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ) + if err != nil { + if err == sql.ErrNoRows { + return nil, fmt.Errorf("生成记录不存在") + } + return nil, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + return gen, nil +} + +func (r *soraGenerationRepository) Update(ctx context.Context, gen *service.SoraGeneration) error { + mediaURLsJSON, _ := json.Marshal(gen.MediaURLs) + s3KeysJSON, _ := json.Marshal(gen.S3ObjectKeys) + + var completedAt *time.Time + if gen.CompletedAt != nil { + completedAt = gen.CompletedAt + } + + _, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations SET + status = $2, media_url = $3, media_urls = $4, 
file_size_bytes = $5, + storage_type = $6, s3_object_keys = $7, upstream_task_id = $8, + error_message = $9, completed_at = $10 + WHERE id = $1 + `, + gen.ID, gen.Status, gen.MediaURL, mediaURLsJSON, gen.FileSizeBytes, + gen.StorageType, s3KeysJSON, gen.UpstreamTaskID, + gen.ErrorMessage, completedAt, + ) + return err +} + +// UpdateGeneratingIfPending 仅当状态为 pending 时更新为 generating。 +func (r *soraGenerationRepository) UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, upstream_task_id = $3 + WHERE id = $1 AND status = $4 + `, + id, service.SoraGenStatusGenerating, upstreamTaskID, service.SoraGenStatusPending, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCompletedIfActive 仅当状态为 pending/generating 时更新为 completed。 +func (r *soraGenerationRepository) UpdateCompletedIfActive( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, + completedAt time.Time, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + media_url = $3, + media_urls = $4, + file_size_bytes = $5, + storage_type = $6, + s3_object_keys = $7, + error_message = '', + completed_at = $8 + WHERE id = $1 AND status IN ($9, $10) + `, + id, service.SoraGenStatusCompleted, mediaURL, mediaURLsJSON, fileSizeBytes, + storageType, s3KeysJSON, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateFailedIfActive 仅当状态为 pending/generating 时更新为 failed。 +func 
(r *soraGenerationRepository) UpdateFailedIfActive( + ctx context.Context, + id int64, + errMsg string, + completedAt time.Time, +) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, + error_message = $3, + completed_at = $4 + WHERE id = $1 AND status IN ($5, $6) + `, + id, service.SoraGenStatusFailed, errMsg, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateCancelledIfActive 仅当状态为 pending/generating 时更新为 cancelled。 +func (r *soraGenerationRepository) UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) { + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET status = $2, completed_at = $3 + WHERE id = $1 AND status IN ($4, $5) + `, + id, service.SoraGenStatusCancelled, completedAt, service.SoraGenStatusPending, service.SoraGenStatusGenerating, + ) + if err != nil { + return false, err + } + affected, err := result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +// UpdateStorageIfCompleted 更新已完成记录的存储信息(用于手动保存,不重置 completed_at)。 +func (r *soraGenerationRepository) UpdateStorageIfCompleted( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, +) (bool, error) { + mediaURLsJSON, _ := json.Marshal(mediaURLs) + s3KeysJSON, _ := json.Marshal(s3Keys) + result, err := r.sql.ExecContext(ctx, ` + UPDATE sora_generations + SET media_url = $2, + media_urls = $3, + file_size_bytes = $4, + storage_type = $5, + s3_object_keys = $6 + WHERE id = $1 AND status = $7 + `, + id, mediaURL, mediaURLsJSON, fileSizeBytes, storageType, s3KeysJSON, service.SoraGenStatusCompleted, + ) + if err != nil { + return false, err + } + affected, err := 
result.RowsAffected() + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *soraGenerationRepository) Delete(ctx context.Context, id int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM sora_generations WHERE id = $1`, id) + return err +} + +func (r *soraGenerationRepository) List(ctx context.Context, params service.SoraGenerationListParams) ([]*service.SoraGeneration, int64, error) { + // 构建 WHERE 条件 + conditions := []string{"user_id = $1"} + args := []any{params.UserID} + argIdx := 2 + + if params.Status != "" { + // 支持逗号分隔的多状态 + statuses := strings.Split(params.Status, ",") + placeholders := make([]string, len(statuses)) + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("status IN (%s)", strings.Join(placeholders, ","))) + } + if params.StorageType != "" { + storageTypes := strings.Split(params.StorageType, ",") + placeholders := make([]string, len(storageTypes)) + for i, s := range storageTypes { + placeholders[i] = fmt.Sprintf("$%d", argIdx) + args = append(args, strings.TrimSpace(s)) + argIdx++ + } + conditions = append(conditions, fmt.Sprintf("storage_type IN (%s)", strings.Join(placeholders, ","))) + } + if params.MediaType != "" { + conditions = append(conditions, fmt.Sprintf("media_type = $%d", argIdx)) + args = append(args, params.MediaType) + argIdx++ + } + + whereClause := "WHERE " + strings.Join(conditions, " AND ") + + // 计数 + var total int64 + countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations %s", whereClause) + if err := r.sql.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil { + return nil, 0, err + } + + // 分页查询 + offset := (params.Page - 1) * params.PageSize + listQuery := fmt.Sprintf(` + SELECT id, user_id, api_key_id, model, prompt, media_type, + status, media_url, media_urls, file_size_bytes, + storage_type, s3_object_keys, upstream_task_id, 
error_message, + created_at, completed_at + FROM sora_generations %s + ORDER BY created_at DESC + LIMIT $%d OFFSET $%d + `, whereClause, argIdx, argIdx+1) + args = append(args, params.PageSize, offset) + + rows, err := r.sql.QueryContext(ctx, listQuery, args...) + if err != nil { + return nil, 0, err + } + defer func() { + _ = rows.Close() + }() + + var results []*service.SoraGeneration + for rows.Next() { + gen := &service.SoraGeneration{} + var mediaURLsJSON, s3KeysJSON []byte + var completedAt sql.NullTime + var apiKeyID sql.NullInt64 + + if err := rows.Scan( + &gen.ID, &gen.UserID, &apiKeyID, &gen.Model, &gen.Prompt, &gen.MediaType, + &gen.Status, &gen.MediaURL, &mediaURLsJSON, &gen.FileSizeBytes, + &gen.StorageType, &s3KeysJSON, &gen.UpstreamTaskID, &gen.ErrorMessage, + &gen.CreatedAt, &completedAt, + ); err != nil { + return nil, 0, err + } + + if apiKeyID.Valid { + gen.APIKeyID = &apiKeyID.Int64 + } + if completedAt.Valid { + gen.CompletedAt = &completedAt.Time + } + _ = json.Unmarshal(mediaURLsJSON, &gen.MediaURLs) + _ = json.Unmarshal(s3KeysJSON, &gen.S3ObjectKeys) + results = append(results, gen) + } + + return results, total, rows.Err() +} + +func (r *soraGenerationRepository) CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) { + if len(statuses) == 0 { + return 0, nil + } + + placeholders := make([]string, len(statuses)) + args := []any{userID} + for i, s := range statuses { + placeholders[i] = fmt.Sprintf("$%d", i+2) + args = append(args, s) + } + + var count int64 + query := fmt.Sprintf("SELECT COUNT(*) FROM sora_generations WHERE user_id = $1 AND status IN (%s)", strings.Join(placeholders, ",")) + err := r.sql.QueryRowContext(ctx, query, args...).Scan(&count) + return count, err +} diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go index 9c021357..1a25696e 100644 --- a/backend/internal/repository/usage_cleanup_repo.go +++ 
b/backend/internal/repository/usage_cleanup_repo.go @@ -362,7 +362,12 @@ func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) idx++ } } - if filters.Stream != nil { + if filters.RequestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(idx, *filters.RequestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + idx += len(conditionArgs) + } else if filters.Stream != nil { conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) args = append(args, *filters.Stream) idx++ diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go index 0ca30ec7..1ac7cca5 100644 --- a/backend/internal/repository/usage_cleanup_repo_test.go +++ b/backend/internal/repository/usage_cleanup_repo_test.go @@ -466,6 +466,38 @@ func TestBuildUsageCleanupWhere(t *testing.T) { require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) } +func TestBuildUsageCleanupWhereRequestTypePriority(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + Stream: &stream, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + +func TestBuildUsageCleanupWhereRequestTypeLegacyFallback(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + }) + + require.Equal(t, 
"created_at >= $1 AND created_at <= $2 AND (request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", where) + require.Equal(t, []any{start, end, requestType}, args) +} + func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) end := start.Add(24 * time.Hour) diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index ce67ba4d..4a08904d 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -22,7 +22,7 @@ import ( "github.com/lib/pq" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, stream, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, media_type, reasoning_effort, cache_ttl_overridden, created_at" // dateFormatWhitelist 将 granularity 参数映射为 PostgreSQL TO_CHAR 格式字符串,防止外部输入直接拼入 SQL var dateFormatWhitelist = map[string]string{ @@ -98,6 +98,8 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) log.RequestID = requestID rateMultiplier 
:= log.RateMultiplier + log.SyncRequestTypeAndLegacyFields() + requestType := int16(log.RequestType) query := ` INSERT INTO usage_logs ( @@ -123,7 +125,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rate_multiplier, account_rate_multiplier, billing_type, + request_type, stream, + openai_ws_mode, duration_ms, first_token_ms, user_agent, @@ -140,7 +144,7 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, - $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33 + $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -184,7 +188,9 @@ func (r *usageLogRepository) Create(ctx context.Context, log *service.UsageLog) rateMultiplier, log.AccountRateMultiplier, log.BillingType, + requestType, log.Stream, + log.OpenAIWSMode, duration, firstToken, userAgent, @@ -492,25 +498,46 @@ func (r *usageLogRepository) fillDashboardUsageStatsAggregated(ctx context.Conte } func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Context, stats *DashboardStats, startUTC, endUTC, todayUTC, now time.Time) error { - totalStatsQuery := ` + todayEnd := todayUTC.Add(24 * time.Hour) + combinedStatsQuery := ` + WITH scoped AS ( + SELECT + created_at, + input_tokens, + output_tokens, + cache_creation_tokens, + cache_read_tokens, + total_cost, + actual_cost, + COALESCE(duration_ms, 0) AS duration_ms + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) SELECT - COUNT(*) as total_requests, - COALESCE(SUM(input_tokens), 0) as total_input_tokens, - COALESCE(SUM(output_tokens), 0) as total_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, - 
COALESCE(SUM(total_cost), 0) as total_cost, - COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(SUM(COALESCE(duration_ms, 0)), 0) as total_duration_ms - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + COUNT(*) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz) AS total_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_actual_cost, + COALESCE(SUM(duration_ms) FILTER (WHERE created_at >= $1::timestamptz AND created_at < $2::timestamptz), 0) AS total_duration_ms, + COUNT(*) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz) AS today_requests, + COALESCE(SUM(input_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_input_tokens, + COALESCE(SUM(output_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_output_tokens, + COALESCE(SUM(cache_creation_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cache_read_tokens, + COALESCE(SUM(total_cost) FILTER 
(WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_cost, + COALESCE(SUM(actual_cost) FILTER (WHERE created_at >= $3::timestamptz AND created_at < $4::timestamptz), 0) AS today_actual_cost + FROM scoped ` var totalDurationMs int64 if err := scanSingleRow( ctx, r.sql, - totalStatsQuery, - []any{startUTC, endUTC}, + combinedStatsQuery, + []any{startUTC, endUTC, todayUTC, todayEnd}, &stats.TotalRequests, &stats.TotalInputTokens, &stats.TotalOutputTokens, @@ -519,32 +546,6 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co &stats.TotalCost, &stats.TotalActualCost, &totalDurationMs, - ); err != nil { - return err - } - stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens - if stats.TotalRequests > 0 { - stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) - } - - todayEnd := todayUTC.Add(24 * time.Hour) - todayStatsQuery := ` - SELECT - COUNT(*) as today_requests, - COALESCE(SUM(input_tokens), 0) as today_input_tokens, - COALESCE(SUM(output_tokens), 0) as today_output_tokens, - COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, - COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, - COALESCE(SUM(total_cost), 0) as today_cost, - COALESCE(SUM(actual_cost), 0) as today_actual_cost - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow( - ctx, - r.sql, - todayStatsQuery, - []any{todayUTC, todayEnd}, &stats.TodayRequests, &stats.TodayInputTokens, &stats.TodayOutputTokens, @@ -555,25 +556,28 @@ func (r *usageLogRepository) fillDashboardUsageStatsFromUsageLogs(ctx context.Co ); err != nil { return err } - stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens - - activeUsersQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE 
created_at >= $1 AND created_at < $2 - ` - if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd}, &stats.ActiveUsers); err != nil { - return err + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + if stats.TotalRequests > 0 { + stats.AverageDurationMs = float64(totalDurationMs) / float64(stats.TotalRequests) } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + hourStart := now.UTC().Truncate(time.Hour) hourEnd := hourStart.Add(time.Hour) - hourlyActiveQuery := ` - SELECT COUNT(DISTINCT user_id) as active_users - FROM usage_logs - WHERE created_at >= $1 AND created_at < $2 + activeUsersQuery := ` + WITH scoped AS ( + SELECT user_id, created_at + FROM usage_logs + WHERE created_at >= LEAST($1::timestamptz, $3::timestamptz) + AND created_at < GREATEST($2::timestamptz, $4::timestamptz) + ) + SELECT + COUNT(DISTINCT CASE WHEN created_at >= $1::timestamptz AND created_at < $2::timestamptz THEN user_id END) AS active_users, + COUNT(DISTINCT CASE WHEN created_at >= $3::timestamptz AND created_at < $4::timestamptz THEN user_id END) AS hourly_active_users + FROM scoped ` - if err := scanSingleRow(ctx, r.sql, hourlyActiveQuery, []any{hourStart, hourEnd}, &stats.HourlyActiveUsers); err != nil { + if err := scanSingleRow(ctx, r.sql, activeUsersQuery, []any{todayUTC, todayEnd, hourStart, hourEnd}, &stats.ActiveUsers, &stats.HourlyActiveUsers); err != nil { return err } @@ -968,6 +972,61 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc return result, nil } +// GetGeminiUsageTotalsBatch 批量聚合 Gemini 账号在窗口内的 Pro/Flash 请求与用量。 +// 模型分类规则与 service.geminiModelClassFromName 一致:model 包含 flash/lite 视为 flash,其余视为 pro。 +func (r *usageLogRepository) GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) 
(map[int64]service.GeminiUsageTotals, error) { + result := make(map[int64]service.GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + query := ` + SELECT + account_id, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 1 ELSE 0 END), 0) AS flash_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE 1 END), 0) AS pro_requests, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) ELSE 0 END), 0) AS flash_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE (input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens) END), 0) AS pro_tokens, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN actual_cost ELSE 0 END), 0) AS flash_cost, + COALESCE(SUM(CASE WHEN LOWER(COALESCE(model, '')) LIKE '%flash%' OR LOWER(COALESCE(model, '')) LIKE '%lite%' THEN 0 ELSE actual_cost END), 0) AS pro_cost + FROM usage_logs + WHERE account_id = ANY($1) AND created_at >= $2 AND created_at < $3 + GROUP BY account_id + ` + rows, err := r.sql.QueryContext(ctx, query, pq.Array(accountIDs), startTime, endTime) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var accountID int64 + var totals service.GeminiUsageTotals + if err := rows.Scan( + &accountID, + &totals.FlashRequests, + &totals.ProRequests, + &totals.FlashTokens, + &totals.ProTokens, + &totals.FlashCost, + &totals.ProCost, + ); err != nil { + return nil, err + } + result[accountID] = totals + } + if err := rows.Err(); err != nil { + return nil, err + } + + for _, accountID := range accountIDs { + if _, ok := 
result[accountID]; !ok { + result[accountID] = service.GeminiUsageTotals{} + } + } + return result, nil +} + // TrendDataPoint represents a single point in trend data type TrendDataPoint = usagestats.TrendDataPoint @@ -1399,10 +1458,7 @@ func (r *usageLogRepository) ListWithFilters(ctx context.Context, params paginat conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -1598,7 +1654,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional filters -func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { +func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []TrendDataPoint, err error) { dateFormat := safeDateFormat(granularity) query := fmt.Sprintf(` @@ -1636,10 +1692,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND model = $%d", len(args)+1) args = append(args, model) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, 
requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1667,7 +1720,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start } // GetModelStatsWithFilters returns model statistics with optional filters -func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) { +func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) (results []ModelStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { @@ -1704,10 +1757,7 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start query += fmt.Sprintf(" AND group_id = $%d", len(args)+1) args = append(args, groupID) } - if stream != nil { - query += fmt.Sprintf(" AND stream = $%d", len(args)+1) - args = append(args, *stream) - } + query, args = appendRequestTypeOrStreamQueryFilter(query, args, requestType, stream) if billingType != nil { query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1) args = append(args, int16(*billingType)) @@ -1794,10 +1844,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us conditions = append(conditions, fmt.Sprintf("model = $%d", len(args)+1)) args = append(args, filters.Model) } - if filters.Stream != nil { - conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) - args = append(args, *filters.Stream) - } + conditions, args = appendRequestTypeOrStreamWhereCondition(conditions, args, filters.RequestType, filters.Stream) if filters.BillingType != nil { 
conditions = append(conditions, fmt.Sprintf("billing_type = $%d", len(args)+1)) args = append(args, int16(*filters.BillingType)) @@ -2017,7 +2064,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID } } - models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) + models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil, nil) if err != nil { models = []ModelStat{} } @@ -2267,7 +2314,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e rateMultiplier float64 accountRateMultiplier sql.NullFloat64 billingType int16 + requestTypeRaw int16 stream bool + openaiWSMode bool durationMs sql.NullInt64 firstTokenMs sql.NullInt64 userAgent sql.NullString @@ -2304,7 +2353,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &rateMultiplier, &accountRateMultiplier, &billingType, + &requestTypeRaw, &stream, + &openaiWSMode, &durationMs, &firstTokenMs, &userAgent, @@ -2340,11 +2391,16 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e RateMultiplier: rateMultiplier, AccountRateMultiplier: nullFloat64Ptr(accountRateMultiplier), BillingType: int8(billingType), - Stream: stream, + RequestType: service.RequestTypeFromInt16(requestTypeRaw), ImageCount: imageCount, CacheTTLOverridden: cacheTTLOverridden, CreatedAt: createdAt, } + // 先回填 legacy 字段,再基于 legacy + request_type 计算最终请求类型,保证历史数据兼容。 + log.Stream = stream + log.OpenAIWSMode = openaiWSMode + log.RequestType = log.EffectiveRequestType() + log.Stream, log.OpenAIWSMode = service.ApplyLegacyRequestFields(log.RequestType, stream, openaiWSMode) if requestID.Valid { log.RequestID = requestID.String @@ -2438,6 +2494,50 @@ func buildWhere(conditions []string) string { return "WHERE " + strings.Join(conditions, " AND ") } +func appendRequestTypeOrStreamWhereCondition(conditions []string, args []any, requestType *int16, stream *bool) 
([]string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + conditions = append(conditions, condition) + args = append(args, conditionArgs...) + return conditions, args + } + if stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", len(args)+1)) + args = append(args, *stream) + } + return conditions, args +} + +func appendRequestTypeOrStreamQueryFilter(query string, args []any, requestType *int16, stream *bool) (string, []any) { + if requestType != nil { + condition, conditionArgs := buildRequestTypeFilterCondition(len(args)+1, *requestType) + query += " AND " + condition + args = append(args, conditionArgs...) + return query, args + } + if stream != nil { + query += fmt.Sprintf(" AND stream = $%d", len(args)+1) + args = append(args, *stream) + } + return query, args +} + +// buildRequestTypeFilterCondition 在 request_type 过滤时兼容 legacy 字段,避免历史数据漏查。 +func buildRequestTypeFilterCondition(startArgIndex int, requestType int16) (string, []any) { + normalized := service.RequestTypeFromInt16(requestType) + requestTypeArg := int16(normalized) + switch normalized { + case service.RequestTypeSync: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = FALSE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeStream: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND stream = TRUE AND openai_ws_mode = FALSE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + case service.RequestTypeWSV2: + return fmt.Sprintf("(request_type = $%d OR (request_type = %d AND openai_ws_mode = TRUE))", startArgIndex, int16(service.RequestTypeUnknown)), []any{requestTypeArg} + default: + return fmt.Sprintf("request_type = $%d", startArgIndex), []any{requestTypeArg} + } +} + func nullInt64(v *int64) sql.NullInt64 { if v == nil { return sql.NullInt64{} diff 
--git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 8cb3aab1..4d50f7de 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -130,6 +130,62 @@ func (s *UsageLogRepoSuite) TestGetByID_ReturnsAccountRateMultiplier() { s.Require().InEpsilon(0.5, *got.AccountRateMultiplier, 0.0001) } +func (s *UsageLogRepoSuite) TestGetByID_ReturnsOpenAIWSMode() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-ws@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-ws", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-ws"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1.0, + ActualCost: 1.0, + OpenAIWSMode: true, + CreatedAt: timezone.Today().Add(3 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().True(got.OpenAIWSMode) +} + +func (s *UsageLogRepoSuite) TestGetByID_ReturnsRequestTypeAndLegacyFallback() { + user := mustCreateUser(s.T(), s.client, &service.User{Email: "getbyid-request-type@test.com"}) + apiKey := mustCreateApiKey(s.T(), s.client, &service.APIKey{UserID: user.ID, Key: "sk-getbyid-request-type", Name: "k"}) + account := mustCreateAccount(s.T(), s.client, &service.Account{Name: "acc-getbyid-request-type"}) + + log := &service.UsageLog{ + UserID: user.ID, + APIKeyID: apiKey.ID, + AccountID: account.ID, + RequestID: uuid.New().String(), + Model: "gpt-5.3-codex", + RequestType: service.RequestTypeWSV2, + Stream: true, + OpenAIWSMode: false, + InputTokens: 10, + OutputTokens: 20, + TotalCost: 
1.0, + ActualCost: 1.0, + CreatedAt: timezone.Today().Add(4 * time.Hour), + } + _, err := s.repo.Create(s.ctx, log) + s.Require().NoError(err) + + got, err := s.repo.GetByID(s.ctx, log.ID) + s.Require().NoError(err) + s.Require().Equal(service.RequestTypeWSV2, got.RequestType) + s.Require().True(got.Stream) + s.Require().True(got.OpenAIWSMode) +} + // --- Delete --- func (s *UsageLogRepoSuite) TestDelete() { @@ -944,17 +1000,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { endTime := base.Add(48 * time.Hour) // Test with user filter - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().Len(trend, 2) // Test with apiKey filter - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().Len(trend, 2) // Test with both filters - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().Len(trend, 2) } @@ -971,7 +1027,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", 
nil, nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().Len(trend, 2) } @@ -1017,17 +1073,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { endTime := base.Add(2 * time.Hour) // Test with user filter - stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().Len(stats, 2) // Test with apiKey filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().Len(stats, 2) // Test with account filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().Len(stats, 2) } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go new file mode 100644 index 00000000..95cf2a2d --- /dev/null +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -0,0 +1,327 @@ +package repository + +import ( + "context" + "database/sql" + "fmt" + "reflect" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { + db, mock := newSQLMock(t) + repo := 
&usageLogRepository{sql: db} + + createdAt := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) + log := &service.UsageLog{ + UserID: 1, + APIKeyID: 2, + AccountID: 3, + RequestID: "req-1", + Model: "gpt-5", + InputTokens: 10, + OutputTokens: 20, + TotalCost: 1, + ActualCost: 1, + BillingType: service.BillingTypeBalance, + RequestType: service.RequestTypeWSV2, + Stream: false, + OpenAIWSMode: false, + CreatedAt: createdAt, + } + + mock.ExpectQuery("INSERT INTO usage_logs"). + WithArgs( + log.UserID, + log.APIKeyID, + log.AccountID, + log.RequestID, + log.Model, + sqlmock.AnyArg(), // group_id + sqlmock.AnyArg(), // subscription_id + log.InputTokens, + log.OutputTokens, + log.CacheCreationTokens, + log.CacheReadTokens, + log.CacheCreation5mTokens, + log.CacheCreation1hTokens, + log.InputCost, + log.OutputCost, + log.CacheCreationCost, + log.CacheReadCost, + log.TotalCost, + log.ActualCost, + log.RateMultiplier, + log.AccountRateMultiplier, + log.BillingType, + int16(service.RequestTypeWSV2), + true, + true, + sqlmock.AnyArg(), // duration_ms + sqlmock.AnyArg(), // first_token_ms + sqlmock.AnyArg(), // user_agent + sqlmock.AnyArg(), // ip_address + log.ImageCount, + sqlmock.AnyArg(), // image_size + sqlmock.AnyArg(), // media_type + sqlmock.AnyArg(), // reasoning_effort + log.CacheTTLOverridden, + createdAt, + ). 
+ WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) + + inserted, err := repo.Create(context.Background(), log) + require.NoError(t, err) + require.True(t, inserted) + require.Equal(t, int64(99), log.ID) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryListWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeWSV2) + stream := false + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(requestType). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + mock.ExpectQuery("SELECT .* FROM usage_logs WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\) ORDER BY id DESC LIMIT \\$2 OFFSET \\$3"). + WithArgs(requestType, 20, 0). + WillReturnRows(sqlmock.NewRows([]string{"id"})) + + logs, page, err := repo.ListWithFilters(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}, filters) + require.NoError(t, err) + require.Empty(t, logs) + require.NotNil(t, page) + require.Equal(t, int64(0), page.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetUsageTrendWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeStream) + stream := true + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE\\)\\)"). 
+ WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"date", "requests", "input_tokens", "output_tokens", "cache_tokens", "total_tokens", "cost", "actual_cost"})) + + trend, err := repo.GetUsageTrendWithFilters(context.Background(), start, end, "day", 0, 0, 0, 0, "", &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, trend) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetModelStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + start := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(service.RequestTypeWSV2) + stream := false + + mock.ExpectQuery("AND \\(request_type = \\$3 OR \\(request_type = 0 AND openai_ws_mode = TRUE\\)\\)"). + WithArgs(start, end, requestType). + WillReturnRows(sqlmock.NewRows([]string{"model", "requests", "input_tokens", "output_tokens", "total_tokens", "cost", "actual_cost"})) + + stats, err := repo.GetModelStatsWithFilters(context.Background(), start, end, 0, 0, 0, 0, &requestType, &stream, nil) + require.NoError(t, err) + require.Empty(t, stats) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageLogRepositoryGetStatsWithFiltersRequestTypePriority(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageLogRepository{sql: db} + + requestType := int16(service.RequestTypeSync) + stream := true + filters := usagestats.UsageLogFilters{ + RequestType: &requestType, + Stream: &stream, + } + + mock.ExpectQuery("FROM usage_logs\\s+WHERE \\(request_type = \\$1 OR \\(request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE\\)\\)"). + WithArgs(requestType). 
+ WillReturnRows(sqlmock.NewRows([]string{ + "total_requests", + "total_input_tokens", + "total_output_tokens", + "total_cache_tokens", + "total_cost", + "total_actual_cost", + "total_account_cost", + "avg_duration_ms", + }).AddRow(int64(1), int64(2), int64(3), int64(4), 1.2, 1.0, 1.2, 20.0)) + + stats, err := repo.GetStatsWithFilters(context.Background(), filters) + require.NoError(t, err) + require.Equal(t, int64(1), stats.TotalRequests) + require.Equal(t, int64(9), stats.TotalTokens) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestBuildRequestTypeFilterConditionLegacyFallback(t *testing.T) { + tests := []struct { + name string + request int16 + wantWhere string + wantArg int16 + }{ + { + name: "sync_with_legacy_fallback", + request: int16(service.RequestTypeSync), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = FALSE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeSync), + }, + { + name: "stream_with_legacy_fallback", + request: int16(service.RequestTypeStream), + wantWhere: "(request_type = $3 OR (request_type = 0 AND stream = TRUE AND openai_ws_mode = FALSE))", + wantArg: int16(service.RequestTypeStream), + }, + { + name: "ws_v2_with_legacy_fallback", + request: int16(service.RequestTypeWSV2), + wantWhere: "(request_type = $3 OR (request_type = 0 AND openai_ws_mode = TRUE))", + wantArg: int16(service.RequestTypeWSV2), + }, + { + name: "invalid_request_type_normalized_to_unknown", + request: int16(99), + wantWhere: "request_type = $3", + wantArg: int16(service.RequestTypeUnknown), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + where, args := buildRequestTypeFilterCondition(3, tt.request) + require.Equal(t, tt.wantWhere, where) + require.Equal(t, []any{tt.wantArg}, args) + }) + } +} + +type usageLogScannerStub struct { + values []any +} + +func (s usageLogScannerStub) Scan(dest ...any) error { + if len(dest) != len(s.values) { + return fmt.Errorf("scan arg count mismatch: got 
%d want %d", len(dest), len(s.values)) + } + for i := range dest { + dv := reflect.ValueOf(dest[i]) + if dv.Kind() != reflect.Ptr { + return fmt.Errorf("dest[%d] is not pointer", i) + } + dv.Elem().Set(reflect.ValueOf(s.values[i])) + } + return nil +} + +func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { + t.Run("request_type_ws_v2_overrides_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(1), // id + int64(10), // user_id + int64(20), // api_key_id + int64(30), // account_id + sql.NullString{Valid: true, String: "req-1"}, + "gpt-5", // model + sql.NullInt64{}, // group_id + sql.NullInt64{}, // subscription_id + 1, // input_tokens + 2, // output_tokens + 3, // cache_creation_tokens + 4, // cache_read_tokens + 5, // cache_creation_5m_tokens + 6, // cache_creation_1h_tokens + 0.1, // input_cost + 0.2, // output_cost + 0.3, // cache_creation_cost + 0.4, // cache_read_cost + 1.0, // total_cost + 0.9, // actual_cost + 1.0, // rate_multiplier + sql.NullFloat64{}, // account_rate_multiplier + int16(service.BillingTypeBalance), + int16(service.RequestTypeWSV2), + false, // legacy stream + false, // legacy openai ws + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) + }) + + t.Run("request_type_unknown_falls_back_to_legacy", func(t *testing.T) { + now := time.Now().UTC() + log, err := scanUsageLog(usageLogScannerStub{values: []any{ + int64(2), + int64(11), + int64(21), + int64(31), + sql.NullString{Valid: true, String: "req-2"}, + "gpt-5", + sql.NullInt64{}, + sql.NullInt64{}, + 1, 2, 3, 4, 5, 6, + 0.1, 0.2, 0.3, 0.4, 1.0, 0.9, + 1.0, + sql.NullFloat64{}, + int16(service.BillingTypeBalance), + 
int16(service.RequestTypeUnknown), + true, + false, + sql.NullInt64{}, + sql.NullInt64{}, + sql.NullString{}, + sql.NullString{}, + 0, + sql.NullString{}, + sql.NullString{}, + sql.NullString{}, + false, + now, + }}) + require.NoError(t, err) + require.Equal(t, service.RequestTypeStream, log.RequestType) + require.True(t, log.Stream) + require.False(t, log.OpenAIWSMode) + }) +} diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eb65403b..e3b11096 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" ) type userGroupRateRepository struct { @@ -41,6 +42,59 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) return result, nil } +// GetByUserIDs 批量获取多个用户的专属分组倍率。 +// 返回结构:map[userID]map[groupID]rate +func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { + result := make(map[int64]map[int64]float64, len(userIDs)) + if len(userIDs) == 0 { + return result, nil + } + + uniqueIDs := make([]int64, 0, len(userIDs)) + seen := make(map[int64]struct{}, len(userIDs)) + for _, userID := range userIDs { + if userID <= 0 { + continue + } + if _, exists := seen[userID]; exists { + continue + } + seen[userID] = struct{}{} + uniqueIDs = append(uniqueIDs, userID) + result[userID] = make(map[int64]float64) + } + if len(uniqueIDs) == 0 { + return result, nil + } + + rows, err := r.sql.QueryContext(ctx, ` + SELECT user_id, group_id, rate_multiplier + FROM user_group_rate_multipliers + WHERE user_id = ANY($1) + `, pq.Array(uniqueIDs)) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + for rows.Next() { + var userID int64 + var groupID int64 + var rate float64 + if err := rows.Scan(&userID, &groupID, &rate); err != nil { 
+ return nil, err + } + if _, ok := result[userID]; !ok { + result[userID] = make(map[int64]float64) + } + result[userID][groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + // GetByUserAndGroup 获取用户在特定分组的专属倍率 func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` @@ -65,33 +119,43 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID // 分离需要删除和需要 upsert 的记录 var toDelete []int64 - toUpsert := make(map[int64]float64) + upsertGroupIDs := make([]int64, 0, len(rates)) + upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { toDelete = append(toDelete, groupID) } else { - toUpsert[groupID] = *rate + upsertGroupIDs = append(upsertGroupIDs, groupID) + upsertRates = append(upsertRates, *rate) } } // 删除指定的记录 - for _, groupID := range toDelete { - _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, - userID, groupID) - if err != nil { + if len(toDelete) > 0 { + if _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, + userID, pq.Array(toDelete)); err != nil { return err } } // Upsert 记录 now := time.Now() - for groupID, rate := range toUpsert { + if len(upsertGroupIDs) > 0 { _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) - VALUES ($1, $2, $3, $4, $4) - ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 - `, userID, groupID, rate, now) + SELECT + $1::bigint, + data.group_id, + data.rate_multiplier, + $2::timestamptz, + $2::timestamptz + FROM unnest($3::bigint[], $4::double precision[]) AS data(group_id, rate_multiplier) + ON CONFLICT (user_id, 
group_id) + DO UPDATE SET + rate_multiplier = EXCLUDED.rate_multiplier, + updated_at = EXCLUDED.updated_at + `, userID, now, pq.Array(upsertGroupIDs), pq.Array(upsertRates)) if err != nil { return err } diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index 17674291..bc00e64d 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -61,6 +61,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). Save(ctx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -143,6 +144,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalance(userIn.Balance). SetConcurrency(userIn.Concurrency). SetStatus(userIn.Status). + SetSoraStorageQuotaBytes(userIn.SoraStorageQuotaBytes). + SetSoraStorageUsedBytes(userIn.SoraStorageUsedBytes). 
Save(ctx) if err != nil { return translatePersistenceError(err, service.ErrUserNotFound, service.ErrEmailExists) @@ -363,6 +366,65 @@ func (r *userRepository) UpdateConcurrency(ctx context.Context, id int64, amount return nil } +// AddSoraStorageUsageWithQuota 原子累加 Sora 存储用量,并在有配额时校验不超额。 +func (r *userRepository) AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) { + if deltaBytes <= 0 { + user, err := r.GetByID(ctx, userID) + if err != nil { + return 0, err + } + return user.SoraStorageUsedBytes, nil + } + var newUsed int64 + err := scanSingleRow(ctx, r.sql, ` + UPDATE users + SET sora_storage_used_bytes = sora_storage_used_bytes + $2 + WHERE id = $1 + AND ($3 = 0 OR sora_storage_used_bytes + $2 <= $3) + RETURNING sora_storage_used_bytes + `, []any{userID, deltaBytes, effectiveQuota}, &newUsed) + if err == nil { + return newUsed, nil + } + if errors.Is(err, sql.ErrNoRows) { + // 区分用户不存在和配额冲突 + exists, existsErr := r.client.User.Query().Where(dbuser.IDEQ(userID)).Exist(ctx) + if existsErr != nil { + return 0, existsErr + } + if !exists { + return 0, service.ErrUserNotFound + } + return 0, service.ErrSoraStorageQuotaExceeded + } + return 0, err +} + +// ReleaseSoraStorageUsageAtomic 原子释放 Sora 存储用量,并保证不低于 0。 +func (r *userRepository) ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) { + if deltaBytes <= 0 { + user, err := r.GetByID(ctx, userID) + if err != nil { + return 0, err + } + return user.SoraStorageUsedBytes, nil + } + var newUsed int64 + err := scanSingleRow(ctx, r.sql, ` + UPDATE users + SET sora_storage_used_bytes = GREATEST(sora_storage_used_bytes - $2, 0) + WHERE id = $1 + RETURNING sora_storage_used_bytes + `, []any{userID, deltaBytes}, &newUsed) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, service.ErrUserNotFound + } + return 0, err + } + return newUsed, nil +} + func (r *userRepository) ExistsByEmail(ctx 
context.Context, email string) (bool, error) { return r.client.User.Query().Where(dbuser.EmailEQ(email)).Exist(ctx) } diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 76897bc1..c98086e0 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -186,11 +186,12 @@ func TestAPIContracts(t *testing.T) { "image_price_1k": null, "image_price_2k": null, "image_price_4k": null, - "sora_image_price_360": null, - "sora_image_price_540": null, - "sora_video_price_per_request": null, - "sora_video_price_per_request_hd": null, - "claude_code_only": false, + "sora_image_price_360": null, + "sora_image_price_540": null, + "sora_storage_quota_bytes": 0, + "sora_video_price_per_request": null, + "sora_video_price_per_request_hd": null, + "claude_code_only": false, "fallback_group_id": null, "fallback_group_id_on_invalid_request": null, "created_at": "2025-01-02T03:04:05Z", @@ -384,10 +385,12 @@ func TestAPIContracts(t *testing.T) { "user_id": 1, "api_key_id": 100, "account_id": 200, - "request_id": "req_123", - "model": "claude-3", - "group_id": null, - "subscription_id": null, + "request_id": "req_123", + "model": "claude-3", + "request_type": "stream", + "openai_ws_mode": false, + "group_id": null, + "subscription_id": null, "input_tokens": 10, "output_tokens": 20, "cache_creation_tokens": 1, @@ -500,11 +503,12 @@ func TestAPIContracts(t *testing.T) { "fallback_model_anthropic": "claude-3-5-sonnet-20241022", "fallback_model_antigravity": "gemini-2.5-pro", "fallback_model_gemini": "gemini-2.5-pro", - "fallback_model_openai": "gpt-4o", - "enable_identity_patch": true, - "identity_patch_prompt": "", - "invitation_code_enabled": false, - "home_content": "", + "fallback_model_openai": "gpt-4o", + "enable_identity_patch": true, + "identity_patch_prompt": "", + "sora_client_enabled": false, + "invitation_code_enabled": false, + "home_content": "", "hide_ccs_import_button": 
false, "purchase_subscription_enabled": false, "purchase_subscription_url": "" @@ -619,7 +623,7 @@ func newContractDeps(t *testing.T) *contractDeps { authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) - adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) + adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil, nil) adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { @@ -1555,11 +1559,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/middleware/api_key_auth.go 
b/backend/internal/server/middleware/api_key_auth.go index 8fa3517a..19f97239 100644 --- a/backend/internal/server/middleware/api_key_auth.go +++ b/backend/internal/server/middleware/api_key_auth.go @@ -97,7 +97,7 @@ func apiKeyAuthWithSubscription(apiKeyService *service.APIKeyService, subscripti // 注意:错误信息故意模糊,避免暴露具体的 IP 限制机制 if len(apiKey.IPWhitelist) > 0 || len(apiKey.IPBlacklist) > 0 { clientIP := ip.GetTrustedClientIP(c) - allowed, _ := ip.CheckIPRestriction(clientIP, apiKey.IPWhitelist, apiKey.IPBlacklist) + allowed, _ := ip.CheckIPRestrictionWithCompiledRules(clientIP, apiKey.CompiledIPWhitelist, apiKey.CompiledIPBlacklist) if !allowed { AbortWithError(c, 403, "ACCESS_DENIED", "Access denied") return diff --git a/backend/internal/server/middleware/api_key_auth_google.go b/backend/internal/server/middleware/api_key_auth_google.go index 9da1b1c6..84d93edc 100644 --- a/backend/internal/server/middleware/api_key_auth_google.go +++ b/backend/internal/server/middleware/api_key_auth_google.go @@ -80,17 +80,25 @@ func APIKeyAuthWithSubscriptionGoogle(apiKeyService *service.APIKeyService, subs abortWithGoogleError(c, 403, "No active subscription found for this group") return } - if err := subscriptionService.ValidateSubscription(c.Request.Context(), subscription); err != nil { - abortWithGoogleError(c, 403, err.Error()) - return - } - _ = subscriptionService.CheckAndActivateWindow(c.Request.Context(), subscription) - _ = subscriptionService.CheckAndResetWindows(c.Request.Context(), subscription) - if err := subscriptionService.CheckUsageLimits(c.Request.Context(), subscription, apiKey.Group, 0); err != nil { - abortWithGoogleError(c, 429, err.Error()) + + needsMaintenance, err := subscriptionService.ValidateAndCheckLimits(subscription, apiKey.Group) + if err != nil { + status := 403 + if errors.Is(err, service.ErrDailyLimitExceeded) || + errors.Is(err, service.ErrWeeklyLimitExceeded) || + errors.Is(err, service.ErrMonthlyLimitExceeded) { + status = 429 + } + 
abortWithGoogleError(c, status, err.Error()) return } + c.Set(string(ContextKeySubscription), subscription) + + if needsMaintenance { + maintenanceCopy := *subscription + subscriptionService.DoWindowMaintenance(&maintenanceCopy) + } } else { if apiKey.User.Balance <= 0 { abortWithGoogleError(c, 403, "Insufficient account balance") diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index e4e0e253..2124c86c 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -23,6 +23,15 @@ type fakeAPIKeyRepo struct { updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error } +type fakeGoogleSubscriptionRepo struct { + getActive func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) + updateStatus func(ctx context.Context, subscriptionID int64, status string) error + activateWindow func(ctx context.Context, id int64, start time.Time) error + resetDaily func(ctx context.Context, id int64, start time.Time) error + resetWeekly func(ctx context.Context, id int64, start time.Time) error + resetMonthly func(ctx context.Context, id int64, start time.Time) error +} + func (f fakeAPIKeyRepo) Create(ctx context.Context, key *service.APIKey) error { return errors.New("not implemented") } @@ -87,6 +96,85 @@ func (f fakeAPIKeyRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt tim return nil } +func (f fakeGoogleSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByID(ctx context.Context, id int64) (*service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) GetByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + return nil, errors.New("not 
implemented") +} +func (f fakeGoogleSubscriptionRepo) GetActiveByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if f.getActive != nil { + return f.getActive(ctx, userID, groupID) + } + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Update(ctx context.Context, sub *service.UserSubscription) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) Delete(ctx context.Context, id int64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + return nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) List(ctx context.Context, params pagination.PaginationParams, userID, groupID *int64, status, sortBy, sortOrder string) ([]service.UserSubscription, *pagination.PaginationResult, error) { + return nil, nil, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExistsByUserIDAndGroupID(ctx context.Context, userID, groupID int64) (bool, error) { + return false, errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ExtendExpiry(ctx context.Context, subscriptionID int64, newExpiresAt time.Time) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) UpdateStatus(ctx context.Context, subscriptionID int64, status string) error { + if f.updateStatus != nil { + return f.updateStatus(ctx, subscriptionID, status) + } + return errors.New("not implemented") 
+} +func (f fakeGoogleSubscriptionRepo) UpdateNotes(ctx context.Context, subscriptionID int64, notes string) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ActivateWindows(ctx context.Context, id int64, start time.Time) error { + if f.activateWindow != nil { + return f.activateWindow(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetDailyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetDaily != nil { + return f.resetDaily(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetWeeklyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetWeekly != nil { + return f.resetWeekly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) ResetMonthlyUsage(ctx context.Context, id int64, start time.Time) error { + if f.resetMonthly != nil { + return f.resetMonthly(ctx, id, start) + } + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) IncrementUsage(ctx context.Context, id int64, costUSD float64) error { + return errors.New("not implemented") +} +func (f fakeGoogleSubscriptionRepo) BatchUpdateExpiredStatus(ctx context.Context) (int64, error) { + return 0, errors.New("not implemented") +} + type googleErrorResponse struct { Error struct { Code int `json:"code"` @@ -505,3 +593,85 @@ func TestApiKeyAuthWithSubscriptionGoogle_TouchesLastUsedInStandardMode(t *testi require.Equal(t, http.StatusOK, rec.Code) require.Equal(t, 1, touchCalls) } + +func TestApiKeyAuthWithSubscriptionGoogle_SubscriptionLimitExceededReturns429(t *testing.T) { + gin.SetMode(gin.TestMode) + + limit := 1.0 + group := &service.Group{ + ID: 77, + Name: "gemini-sub", + Status: service.StatusActive, + Platform: service.PlatformGemini, + Hydrated: true, + SubscriptionType: service.SubscriptionTypeSubscription, + DailyLimitUSD: &limit, + } + user := 
&service.User{ + ID: 999, + Role: service.RoleUser, + Status: service.StatusActive, + Balance: 10, + Concurrency: 3, + } + apiKey := &service.APIKey{ + ID: 501, + UserID: user.ID, + Key: "google-sub-limit", + Status: service.StatusActive, + User: user, + Group: group, + } + apiKey.GroupID = &group.ID + + apiKeyService := newTestAPIKeyService(fakeAPIKeyRepo{ + getByKey: func(ctx context.Context, key string) (*service.APIKey, error) { + if key != apiKey.Key { + return nil, service.ErrAPIKeyNotFound + } + clone := *apiKey + return &clone, nil + }, + }) + + now := time.Now() + sub := &service.UserSubscription{ + ID: 601, + UserID: user.ID, + GroupID: group.ID, + Status: service.SubscriptionStatusActive, + ExpiresAt: now.Add(24 * time.Hour), + DailyWindowStart: &now, + DailyUsageUSD: 10, + } + subscriptionService := service.NewSubscriptionService(nil, fakeGoogleSubscriptionRepo{ + getActive: func(ctx context.Context, userID, groupID int64) (*service.UserSubscription, error) { + if userID != user.ID || groupID != group.ID { + return nil, service.ErrSubscriptionNotFound + } + clone := *sub + return &clone, nil + }, + updateStatus: func(ctx context.Context, subscriptionID int64, status string) error { return nil }, + activateWindow: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetDaily: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetWeekly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + resetMonthly: func(ctx context.Context, id int64, start time.Time) error { return nil }, + }, nil, nil, &config.Config{RunMode: config.RunModeStandard}) + + r := gin.New() + r.Use(APIKeyAuthWithSubscriptionGoogle(apiKeyService, subscriptionService, &config.Config{RunMode: config.RunModeStandard})) + r.GET("/v1beta/test", func(c *gin.Context) { c.JSON(200, gin.H{"ok": true}) }) + + req := httptest.NewRequest(http.MethodGet, "/v1beta/test", nil) + req.Header.Set("x-goog-api-key", apiKey.Key) + 
rec := httptest.NewRecorder() + r.ServeHTTP(rec, req) + + require.Equal(t, http.StatusTooManyRequests, rec.Code) + var resp googleErrorResponse + require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) + require.Equal(t, http.StatusTooManyRequests, resp.Error.Code) + require.Equal(t, "RESOURCE_EXHAUSTED", resp.Error.Status) + require.Contains(t, resp.Error.Message, "daily usage limit exceeded") +} diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 67b19c09..f061db90 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -54,6 +54,10 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { c.Header("X-Content-Type-Options", "nosniff") c.Header("X-Frame-Options", "DENY") c.Header("Referrer-Policy", "strict-origin-when-cross-origin") + if isAPIRoutePath(c) { + c.Next() + return + } if cfg.Enabled { // Generate nonce for this request @@ -73,6 +77,18 @@ func SecurityHeaders(cfg config.CSPConfig) gin.HandlerFunc { } } +func isAPIRoutePath(c *gin.Context) bool { + if c == nil || c.Request == nil || c.Request.URL == nil { + return false + } + path := c.Request.URL.Path + return strings.HasPrefix(path, "/v1/") || + strings.HasPrefix(path, "/v1beta/") || + strings.HasPrefix(path, "/antigravity/") || + strings.HasPrefix(path, "/sora/") || + strings.HasPrefix(path, "/responses") +} + // enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. // This allows the application to work correctly even if the config file has an older CSP policy. 
func enhanceCSPPolicy(policy string) string { diff --git a/backend/internal/server/middleware/security_headers_test.go b/backend/internal/server/middleware/security_headers_test.go index 43462b82..5a779825 100644 --- a/backend/internal/server/middleware/security_headers_test.go +++ b/backend/internal/server/middleware/security_headers_test.go @@ -131,6 +131,26 @@ func TestSecurityHeaders(t *testing.T) { assert.Contains(t, csp, CloudflareInsightsDomain) }) + t.Run("api_route_skips_csp_nonce_generation", func(t *testing.T) { + cfg := config.CSPConfig{ + Enabled: true, + Policy: "default-src 'self'; script-src 'self' __CSP_NONCE__", + } + middleware := SecurityHeaders(cfg) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + + middleware(c) + + assert.Equal(t, "nosniff", w.Header().Get("X-Content-Type-Options")) + assert.Equal(t, "DENY", w.Header().Get("X-Frame-Options")) + assert.Equal(t, "strict-origin-when-cross-origin", w.Header().Get("Referrer-Policy")) + assert.Empty(t, w.Header().Get("Content-Security-Policy")) + assert.Empty(t, GetNonceFromContext(c)) + }) + t.Run("csp_enabled_with_nonce_placeholder", func(t *testing.T) { cfg := config.CSPConfig{ Enabled: true, diff --git a/backend/internal/server/router.go b/backend/internal/server/router.go index fb91bc0e..07b51f23 100644 --- a/backend/internal/server/router.go +++ b/backend/internal/server/router.go @@ -75,6 +75,7 @@ func registerRoutes( // 注册各模块路由 routes.RegisterAuthRoutes(v1, h, jwtAuth, redisClient) routes.RegisterUserRoutes(v1, h, jwtAuth) + routes.RegisterSoraClientRoutes(v1, h, jwtAuth) routes.RegisterAdminRoutes(v1, h, adminAuth) routes.RegisterGatewayRoutes(r, h, apiKeyAuth, apiKeyService, subscriptionService, opsService, cfg) } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 36efacc8..ffc20473 100644 --- a/backend/internal/server/routes/admin.go +++ 
b/backend/internal/server/routes/admin.go @@ -55,6 +55,9 @@ func RegisterAdminRoutes( // 系统设置 registerSettingsRoutes(admin, h) + // 数据管理 + registerDataManagementRoutes(admin, h) + // 运维监控(Ops) registerOpsRoutes(admin, h) @@ -231,6 +234,7 @@ func registerAccountRoutes(admin *gin.RouterGroup, h *handler.Handlers) { accounts.POST("/:id/clear-error", h.Admin.Account.ClearError) accounts.GET("/:id/usage", h.Admin.Account.GetUsage) accounts.GET("/:id/today-stats", h.Admin.Account.GetTodayStats) + accounts.POST("/today-stats/batch", h.Admin.Account.GetBatchTodayStats) accounts.POST("/:id/clear-rate-limit", h.Admin.Account.ClearRateLimit) accounts.GET("/:id/temp-unschedulable", h.Admin.Account.GetTempUnschedulable) accounts.DELETE("/:id/temp-unschedulable", h.Admin.Account.ClearTempUnschedulable) @@ -370,6 +374,38 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) { // 流超时处理配置 adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings) adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings) + // Sora S3 存储配置 + adminSettings.GET("/sora-s3", h.Admin.Setting.GetSoraS3Settings) + adminSettings.PUT("/sora-s3", h.Admin.Setting.UpdateSoraS3Settings) + adminSettings.POST("/sora-s3/test", h.Admin.Setting.TestSoraS3Connection) + adminSettings.GET("/sora-s3/profiles", h.Admin.Setting.ListSoraS3Profiles) + adminSettings.POST("/sora-s3/profiles", h.Admin.Setting.CreateSoraS3Profile) + adminSettings.PUT("/sora-s3/profiles/:profile_id", h.Admin.Setting.UpdateSoraS3Profile) + adminSettings.DELETE("/sora-s3/profiles/:profile_id", h.Admin.Setting.DeleteSoraS3Profile) + adminSettings.POST("/sora-s3/profiles/:profile_id/activate", h.Admin.Setting.SetActiveSoraS3Profile) + } +} + +func registerDataManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + dataManagement := admin.Group("/data-management") + { + dataManagement.GET("/agent/health", h.Admin.DataManagement.GetAgentHealth) + dataManagement.GET("/config", 
h.Admin.DataManagement.GetConfig) + dataManagement.PUT("/config", h.Admin.DataManagement.UpdateConfig) + dataManagement.GET("/sources/:source_type/profiles", h.Admin.DataManagement.ListSourceProfiles) + dataManagement.POST("/sources/:source_type/profiles", h.Admin.DataManagement.CreateSourceProfile) + dataManagement.PUT("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.UpdateSourceProfile) + dataManagement.DELETE("/sources/:source_type/profiles/:profile_id", h.Admin.DataManagement.DeleteSourceProfile) + dataManagement.POST("/sources/:source_type/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveSourceProfile) + dataManagement.POST("/s3/test", h.Admin.DataManagement.TestS3) + dataManagement.GET("/s3/profiles", h.Admin.DataManagement.ListS3Profiles) + dataManagement.POST("/s3/profiles", h.Admin.DataManagement.CreateS3Profile) + dataManagement.PUT("/s3/profiles/:profile_id", h.Admin.DataManagement.UpdateS3Profile) + dataManagement.DELETE("/s3/profiles/:profile_id", h.Admin.DataManagement.DeleteS3Profile) + dataManagement.POST("/s3/profiles/:profile_id/activate", h.Admin.DataManagement.SetActiveS3Profile) + dataManagement.POST("/backups", h.Admin.DataManagement.CreateBackupJob) + dataManagement.GET("/backups", h.Admin.DataManagement.ListBackupJobs) + dataManagement.GET("/backups/:job_id", h.Admin.DataManagement.GetBackupJob) } } diff --git a/backend/internal/server/routes/gateway.go b/backend/internal/server/routes/gateway.go index 930c8b9e..6bd91b85 100644 --- a/backend/internal/server/routes/gateway.go +++ b/backend/internal/server/routes/gateway.go @@ -43,6 +43,7 @@ func RegisterGatewayRoutes( gateway.GET("/usage", h.Gateway.Usage) // OpenAI Responses API gateway.POST("/responses", h.OpenAIGateway.Responses) + gateway.GET("/responses", h.OpenAIGateway.ResponsesWebSocket) // 明确阻止旧协议入口:OpenAI 仅支持 Responses API,避免客户端误解为会自动路由到其它平台。 gateway.POST("/chat/completions", func(c *gin.Context) { c.JSON(http.StatusBadRequest, gin.H{ @@ -69,6 
+70,7 @@ func RegisterGatewayRoutes( // OpenAI Responses API(不带v1前缀的别名) r.POST("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.Responses) + r.GET("/responses", bodyLimit, clientRequestID, opsErrorLogger, gin.HandlerFunc(apiKeyAuth), h.OpenAIGateway.ResponsesWebSocket) // Antigravity 模型列表 r.GET("/antigravity/models", gin.HandlerFunc(apiKeyAuth), h.Gateway.AntigravityModels) diff --git a/backend/internal/server/routes/sora_client.go b/backend/internal/server/routes/sora_client.go new file mode 100644 index 00000000..40ae0436 --- /dev/null +++ b/backend/internal/server/routes/sora_client.go @@ -0,0 +1,33 @@ +package routes + +import ( + "github.com/Wei-Shaw/sub2api/internal/handler" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + + "github.com/gin-gonic/gin" +) + +// RegisterSoraClientRoutes 注册 Sora 客户端 API 路由(需要用户认证)。 +func RegisterSoraClientRoutes( + v1 *gin.RouterGroup, + h *handler.Handlers, + jwtAuth middleware.JWTAuthMiddleware, +) { + if h.SoraClient == nil { + return + } + + authenticated := v1.Group("/sora") + authenticated.Use(gin.HandlerFunc(jwtAuth)) + { + authenticated.POST("/generate", h.SoraClient.Generate) + authenticated.GET("/generations", h.SoraClient.ListGenerations) + authenticated.GET("/generations/:id", h.SoraClient.GetGeneration) + authenticated.DELETE("/generations/:id", h.SoraClient.DeleteGeneration) + authenticated.POST("/generations/:id/cancel", h.SoraClient.CancelGeneration) + authenticated.POST("/generations/:id/save", h.SoraClient.SaveToStorage) + authenticated.GET("/quota", h.SoraClient.GetQuota) + authenticated.GET("/models", h.SoraClient.GetModels) + authenticated.GET("/storage-status", h.SoraClient.GetStorageStatus) + } +} diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 50fdac88..1864eb54 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,6 +3,8 @@ package service import ( 
"encoding/json" + "hash/fnv" + "reflect" "sort" "strconv" "strings" @@ -50,6 +52,14 @@ type Account struct { AccountGroups []AccountGroup GroupIDs []int64 Groups []*Group + + // model_mapping 热路径缓存(非持久化字段) + modelMappingCache map[string]string + modelMappingCacheReady bool + modelMappingCacheCredentialsPtr uintptr + modelMappingCacheRawPtr uintptr + modelMappingCacheRawLen int + modelMappingCacheRawSig uint64 } type TempUnschedulableRule struct { @@ -349,6 +359,39 @@ func parseTempUnschedInt(value any) int { } func (a *Account) GetModelMapping() map[string]string { + credentialsPtr := mapPtr(a.Credentials) + rawMapping, _ := a.Credentials["model_mapping"].(map[string]any) + rawPtr := mapPtr(rawMapping) + rawLen := len(rawMapping) + rawSig := uint64(0) + rawSigReady := false + + if a.modelMappingCacheReady && + a.modelMappingCacheCredentialsPtr == credentialsPtr && + a.modelMappingCacheRawPtr == rawPtr && + a.modelMappingCacheRawLen == rawLen { + rawSig = modelMappingSignature(rawMapping) + rawSigReady = true + if a.modelMappingCacheRawSig == rawSig { + return a.modelMappingCache + } + } + + mapping := a.resolveModelMapping(rawMapping) + if !rawSigReady { + rawSig = modelMappingSignature(rawMapping) + } + + a.modelMappingCache = mapping + a.modelMappingCacheReady = true + a.modelMappingCacheCredentialsPtr = credentialsPtr + a.modelMappingCacheRawPtr = rawPtr + a.modelMappingCacheRawLen = rawLen + a.modelMappingCacheRawSig = rawSig + return mapping +} + +func (a *Account) resolveModelMapping(rawMapping map[string]any) map[string]string { if a.Credentials == nil { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { @@ -356,32 +399,31 @@ func (a *Account) GetModelMapping() map[string]string { } return nil } - raw, ok := a.Credentials["model_mapping"] - if !ok || raw == nil { + if len(rawMapping) == 0 { // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping } return nil } - if m, ok := 
raw.(map[string]any); ok { - result := make(map[string]string) - for k, v := range m { - if s, ok := v.(string); ok { - result[k] = s - } - } - if len(result) > 0 { - if a.Platform == domain.PlatformAntigravity { - ensureAntigravityDefaultPassthroughs(result, []string{ - "gemini-3-flash", - "gemini-3.1-pro-high", - "gemini-3.1-pro-low", - }) - } - return result + + result := make(map[string]string) + for k, v := range rawMapping { + if s, ok := v.(string); ok { + result[k] = s } } + if len(result) > 0 { + if a.Platform == domain.PlatformAntigravity { + ensureAntigravityDefaultPassthroughs(result, []string{ + "gemini-3-flash", + "gemini-3.1-pro-high", + "gemini-3.1-pro-low", + }) + } + return result + } + // Antigravity 平台使用默认映射 if a.Platform == domain.PlatformAntigravity { return domain.DefaultAntigravityModelMapping @@ -389,6 +431,37 @@ func (a *Account) GetModelMapping() map[string]string { return nil } +func mapPtr(m map[string]any) uintptr { + if m == nil { + return 0 + } + return reflect.ValueOf(m).Pointer() +} + +func modelMappingSignature(rawMapping map[string]any) uint64 { + if len(rawMapping) == 0 { + return 0 + } + keys := make([]string, 0, len(rawMapping)) + for k := range rawMapping { + keys = append(keys, k) + } + sort.Strings(keys) + + h := fnv.New64a() + for _, k := range keys { + _, _ = h.Write([]byte(k)) + _, _ = h.Write([]byte{0}) + if v, ok := rawMapping[k].(string); ok { + _, _ = h.Write([]byte(v)) + } else { + _, _ = h.Write([]byte{1}) + } + _, _ = h.Write([]byte{0xff}) + } + return h.Sum64() +} + func ensureAntigravityDefaultPassthrough(mapping map[string]string, model string) { if mapping == nil || model == "" { return @@ -742,6 +815,159 @@ func (a *Account) IsOpenAIPassthroughEnabled() bool { return false } +// IsOpenAIResponsesWebSocketV2Enabled 返回 OpenAI 账号是否开启 Responses WebSocket v2。 +// +// 分类型新字段: +// - OAuth 账号:accounts.extra.openai_oauth_responses_websockets_v2_enabled +// - API Key 
账号:accounts.extra.openai_apikey_responses_websockets_v2_enabled +// +// 兼容字段: +// - accounts.extra.responses_websockets_v2_enabled +// - accounts.extra.openai_ws_enabled(历史开关) +// +// 优先级: +// 1. 按账号类型读取分类型字段 +// 2. 分类型字段缺失时,回退兼容字段 +func (a *Account) IsOpenAIResponsesWebSocketV2Enabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + if a.IsOpenAIOAuth() { + if enabled, ok := a.Extra["openai_oauth_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if a.IsOpenAIApiKey() { + if enabled, ok := a.Extra["openai_apikey_responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + } + if enabled, ok := a.Extra["responses_websockets_v2_enabled"].(bool); ok { + return enabled + } + if enabled, ok := a.Extra["openai_ws_enabled"].(bool); ok { + return enabled + } + return false +} + +const ( + OpenAIWSIngressModeOff = "off" + OpenAIWSIngressModeShared = "shared" + OpenAIWSIngressModeDedicated = "dedicated" +) + +func normalizeOpenAIWSIngressMode(mode string) string { + switch strings.ToLower(strings.TrimSpace(mode)) { + case OpenAIWSIngressModeOff: + return OpenAIWSIngressModeOff + case OpenAIWSIngressModeShared: + return OpenAIWSIngressModeShared + case OpenAIWSIngressModeDedicated: + return OpenAIWSIngressModeDedicated + default: + return "" + } +} + +func normalizeOpenAIWSIngressDefaultMode(mode string) string { + if normalized := normalizeOpenAIWSIngressMode(mode); normalized != "" { + return normalized + } + return OpenAIWSIngressModeShared +} + +// ResolveOpenAIResponsesWebSocketV2Mode 返回账号在 WSv2 ingress 下的有效模式(off/shared/dedicated)。 +// +// 优先级: +// 1. 分类型 mode 新字段(string) +// 2. 分类型 enabled 旧字段(bool) +// 3. 兼容 enabled 旧字段(bool) +// 4. 
defaultMode(非法时回退 shared) +func (a *Account) ResolveOpenAIResponsesWebSocketV2Mode(defaultMode string) string { + resolvedDefault := normalizeOpenAIWSIngressDefaultMode(defaultMode) + if a == nil || !a.IsOpenAI() { + return OpenAIWSIngressModeOff + } + if a.Extra == nil { + return resolvedDefault + } + + resolveModeString := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + mode, ok := raw.(string) + if !ok { + return "", false + } + normalized := normalizeOpenAIWSIngressMode(mode) + if normalized == "" { + return "", false + } + return normalized, true + } + resolveBoolMode := func(key string) (string, bool) { + raw, ok := a.Extra[key] + if !ok { + return "", false + } + enabled, ok := raw.(bool) + if !ok { + return "", false + } + if enabled { + return OpenAIWSIngressModeShared, true + } + return OpenAIWSIngressModeOff, true + } + + if a.IsOpenAIOAuth() { + if mode, ok := resolveModeString("openai_oauth_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_oauth_responses_websockets_v2_enabled"); ok { + return mode + } + } + if a.IsOpenAIApiKey() { + if mode, ok := resolveModeString("openai_apikey_responses_websockets_v2_mode"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_apikey_responses_websockets_v2_enabled"); ok { + return mode + } + } + if mode, ok := resolveBoolMode("responses_websockets_v2_enabled"); ok { + return mode + } + if mode, ok := resolveBoolMode("openai_ws_enabled"); ok { + return mode + } + return resolvedDefault +} + +// IsOpenAIWSForceHTTPEnabled 返回账号级“强制 HTTP”开关。 +// 字段:accounts.extra.openai_ws_force_http。 +func (a *Account) IsOpenAIWSForceHTTPEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_force_http"].(bool) + return ok && enabled +} + +// IsOpenAIWSAllowStoreRecoveryEnabled 返回账号级 store 恢复开关。 +// 字段:accounts.extra.openai_ws_allow_store_recovery。 +func (a 
*Account) IsOpenAIWSAllowStoreRecoveryEnabled() bool { + if a == nil || !a.IsOpenAI() || a.Extra == nil { + return false + } + enabled, ok := a.Extra["openai_ws_allow_store_recovery"].(bool) + return ok && enabled +} + // IsOpenAIOAuthPassthroughEnabled 兼容旧接口,等价于 OAuth 账号的 IsOpenAIPassthroughEnabled。 func (a *Account) IsOpenAIOAuthPassthroughEnabled() bool { return a != nil && a.IsOpenAIOAuth() && a.IsOpenAIPassthroughEnabled() diff --git a/backend/internal/service/account_openai_passthrough_test.go b/backend/internal/service/account_openai_passthrough_test.go index 59f8cd8c..a85c68ec 100644 --- a/backend/internal/service/account_openai_passthrough_test.go +++ b/backend/internal/service/account_openai_passthrough_test.go @@ -134,3 +134,161 @@ func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) { require.False(t, otherPlatform.IsCodexCLIOnlyEnabled()) }) } + +func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) { + t.Run("OAuth使用OAuth专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("API Key使用API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型新键优先于兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + 
"openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + "openai_ws_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("分类型键缺失时回退兼容键", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) + + t.Run("非OpenAI账号默认关闭", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled()) + }) +} + +func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) { + t.Run("default fallback to shared", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{}, + } + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("")) + require.Equal(t, OpenAIWSIngressModeShared, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid")) + }) + + t.Run("oauth mode field has highest priority", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + "openai_oauth_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": false, + }, + } + require.Equal(t, OpenAIWSIngressModeDedicated, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + }) + + t.Run("legacy enabled maps to shared", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeShared, 
account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff)) + }) + + t.Run("legacy disabled maps to off", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": false, + "responses_websockets_v2_enabled": true, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeShared)) + }) + + t.Run("non openai always off", func(t *testing.T) { + account := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated)) + }) +} + +func TestAccount_OpenAIWSExtraFlags(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_force_http": true, + "openai_ws_allow_store_recovery": true, + }, + } + require.True(t, account.IsOpenAIWSForceHTTPEnabled()) + require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled()) + + off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}} + require.False(t, off.IsOpenAIWSForceHTTPEnabled()) + require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled()) + + var nilAccount *Account + require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled()) + + nonOpenAI := &Account{ + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_ws_allow_store_recovery": true, + }, + } + require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled()) +} diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index b301049f..a3707184 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -119,6 +119,10 
@@ type AccountService struct { groupRepo GroupRepository } +type groupExistenceBatchChecker interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + // NewAccountService 创建账号服务实例 func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService { return &AccountService{ @@ -131,11 +135,8 @@ func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) { // 验证分组是否存在(如果指定了分组) if len(req.GroupIDs) > 0 { - for _, groupID := range req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil { + return nil, err } } @@ -256,11 +257,8 @@ func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccount // 先验证分组是否存在(在任何写操作之前) if req.GroupIDs != nil { - for _, groupID := range *req.GroupIDs { - _, err := s.groupRepo.GetByID(ctx, groupID) - if err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil { + return nil, err } } @@ -300,6 +298,39 @@ func (s *AccountService) Delete(ctx context.Context, id int64) error { return nil } +func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return fmt.Errorf("group repository not configured") + } + + if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok { + existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + if !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for 
_, groupID := range groupIDs { + _, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + // UpdateStatus 更新账号状态 func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error { account, err := s.accountRepo.GetByID(ctx, id) diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index a507efb4..c55e418d 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -598,9 +598,102 @@ func ceilSeconds(d time.Duration) int { return sec } +// testSoraAPIKeyAccountConnection 测试 Sora apikey 类型账号的连通性。 +// 向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性和 API Key 有效性。 +func (s *AccountTestService) testSoraAPIKeyAccountConnection(c *gin.Context, account *Account) error { + ctx := c.Request.Context() + + apiKey := account.GetCredential("api_key") + if apiKey == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 api_key 凭证") + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + return s.sendErrorAndEnd(c, "Sora apikey 账号缺少 base_url") + } + + // 验证 base_url 格式 + normalizedBaseURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("base_url 无效: %s", err.Error())) + } + upstreamURL := strings.TrimSuffix(normalizedBaseURL, "/") + "/sora/v1/chat/completions" + + // 设置 SSE 头 + c.Writer.Header().Set("Content-Type", "text/event-stream") + c.Writer.Header().Set("Cache-Control", "no-cache") + c.Writer.Header().Set("Connection", "keep-alive") + c.Writer.Header().Set("X-Accel-Buffering", "no") + c.Writer.Flush() + + if wait, ok := s.acquireSoraTestPermit(account.ID); !ok { + msg := fmt.Sprintf("Sora 账号测试过于频繁,请 %d 秒后重试", ceilSeconds(wait)) + return s.sendErrorAndEnd(c, msg) + } + + s.sendEvent(c, TestEvent{Type: "test_start", Model: "sora-upstream"}) + + // 构建轻量级 prompt-enhance 请求作为连通性测试 + testPayload 
:= map[string]any{ + "model": "prompt-enhance-short-10s", + "messages": []map[string]string{{"role": "user", "content": "test"}}, + "stream": false, + } + payloadBytes, _ := json.Marshal(testPayload) + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(payloadBytes)) + if err != nil { + return s.sendErrorAndEnd(c, "构建测试请求失败") + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+apiKey) + + // 获取代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游连接失败: %s", err.Error())) + } + defer func() { _ = resp.Body.Close() }() + + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + + if resp.StatusCode == http.StatusOK { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效 (HTTP %d)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden { + return s.sendErrorAndEnd(c, fmt.Sprintf("上游认证失败 (HTTP %d),请检查 API Key 是否正确", resp.StatusCode)) + } + + // 其他错误但能连通(如 400 参数错误)也算连通性测试通过 + if resp.StatusCode == http.StatusBadRequest { + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("上游连接成功 (%s)", upstreamURL)}) + s.sendEvent(c, TestEvent{Type: "content", Text: fmt.Sprintf("API Key 有效(上游返回 %d,参数校验错误属正常)", resp.StatusCode)}) + s.sendEvent(c, TestEvent{Type: "test_complete", Success: true}) + return nil + } + + return s.sendErrorAndEnd(c, fmt.Sprintf("上游返回异常 HTTP %d: %s", resp.StatusCode, truncateSoraErrorBody(respBody, 256))) +} + // testSoraAccountConnection 测试 Sora 账号的连接 -// 调用 /backend/me 接口验证 access_token 有效性(不需要 
Sentinel Token) +// OAuth 类型:调用 /backend/me 接口验证 access_token 有效性 +// APIKey 类型:向上游 base_url 发送轻量级 prompt-enhance 请求验证连通性 func (s *AccountTestService) testSoraAccountConnection(c *gin.Context, account *Account) error { + // apikey 类型走独立测试流程 + if account.Type == AccountTypeAPIKey { + return s.testSoraAPIKeyAccountConnection(c, account) + } + ctx := c.Request.Context() recorder := &soraProbeRecorder{} diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index a363a790..13a13856 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -9,7 +9,9 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "golang.org/x/sync/errgroup" ) type UsageLogRepository interface { @@ -33,8 +35,8 @@ type UsageLogRepository interface { // Admin dashboard stats GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) - GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) - GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) + GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) + GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) GetAPIKeyUsageTrend(ctx context.Context, startTime, 
endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) @@ -62,6 +64,10 @@ type UsageLogRepository interface { GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error) } +type accountWindowStatsBatchReader interface { + GetAccountWindowStatsBatch(ctx context.Context, accountIDs []int64, startTime time.Time) (map[int64]*usagestats.AccountStats, error) +} + // apiUsageCache 缓存从 Anthropic API 获取的使用率数据(utilization, resets_at) type apiUsageCache struct { response *ClaudeUsageResponse @@ -297,7 +303,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou } dayStart := geminiDailyWindowStart(now) - stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini usage stats failed: %w", err) } @@ -319,7 +325,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou // Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m) minuteStart := now.Truncate(time.Minute) minuteResetAt := minuteStart.Add(time.Minute) - minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil) + minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err) } @@ -440,6 +446,78 @@ func (s *AccountUsageService) GetTodayStats(ctx 
context.Context, accountID int64 }, nil } +// GetTodayStatsBatch 批量获取账号今日统计,优先走批量 SQL,失败时回退单账号查询。 +func (s *AccountUsageService) GetTodayStatsBatch(ctx context.Context, accountIDs []int64) (map[int64]*WindowStats, error) { + uniqueIDs := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, exists := seen[accountID]; exists { + continue + } + seen[accountID] = struct{}{} + uniqueIDs = append(uniqueIDs, accountID) + } + + result := make(map[int64]*WindowStats, len(uniqueIDs)) + if len(uniqueIDs) == 0 { + return result, nil + } + + startTime := timezone.Today() + if batchReader, ok := s.usageLogRepo.(accountWindowStatsBatchReader); ok { + statsByAccount, err := batchReader.GetAccountWindowStatsBatch(ctx, uniqueIDs, startTime) + if err == nil { + for _, accountID := range uniqueIDs { + result[accountID] = windowStatsFromAccountStats(statsByAccount[accountID]) + } + return result, nil + } + } + + var mu sync.Mutex + g, gctx := errgroup.WithContext(ctx) + g.SetLimit(8) + + for _, accountID := range uniqueIDs { + id := accountID + g.Go(func() error { + stats, err := s.usageLogRepo.GetAccountWindowStats(gctx, id, startTime) + if err != nil { + return nil + } + mu.Lock() + result[id] = windowStatsFromAccountStats(stats) + mu.Unlock() + return nil + }) + } + + _ = g.Wait() + + for _, accountID := range uniqueIDs { + if _, ok := result[accountID]; !ok { + result[accountID] = &WindowStats{} + } + } + return result, nil +} + +func windowStatsFromAccountStats(stats *usagestats.AccountStats) *WindowStats { + if stats == nil { + return &WindowStats{} + } + return &WindowStats{ + Requests: stats.Requests, + Tokens: stats.Tokens, + Cost: stats.Cost, + StandardCost: stats.StandardCost, + UserCost: stats.UserCost, + } +} + func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountID int64, startTime, endTime time.Time) 
(*usagestats.AccountUsageStatsResponse, error) { stats, err := s.usageLogRepo.GetAccountUsageStats(ctx, accountID, startTime, endTime) if err != nil { diff --git a/backend/internal/service/account_wildcard_test.go b/backend/internal/service/account_wildcard_test.go index 6a9acc68..7782f948 100644 --- a/backend/internal/service/account_wildcard_test.go +++ b/backend/internal/service/account_wildcard_test.go @@ -314,3 +314,72 @@ func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped) } } + +func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) { + account := &Account{ + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-a", + }, + }, + } + + first := account.GetModelMapping() + if first["claude-3-5-sonnet"] != "upstream-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + account.Credentials = map[string]any{ + "model_mapping": map[string]any{ + "claude-3-5-sonnet": "upstream-b", + }, + } + second := account.GetModelMapping() + if second["claude-3-5-sonnet"] != "upstream-b" { + t.Fatalf("expected cache invalidated after credentials replace, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", + } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if len(first) != 1 { + t.Fatalf("unexpected first mapping length: %d", len(first)) + } + + rawMapping["claude-opus"] = "opus-b" + second := account.GetModelMapping() + if second["claude-opus"] != "opus-b" { + t.Fatalf("expected cache invalidated after mapping len change, got: %v", second) + } +} + +func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) { + rawMapping := map[string]any{ + "claude-sonnet": "sonnet-a", 
+ } + account := &Account{ + Credentials: map[string]any{ + "model_mapping": rawMapping, + }, + } + + first := account.GetModelMapping() + if first["claude-sonnet"] != "sonnet-a" { + t.Fatalf("unexpected first mapping: %v", first) + } + + rawMapping["claude-sonnet"] = "sonnet-b" + second := account.GetModelMapping() + if second["claude-sonnet"] != "sonnet-b" { + t.Fatalf("expected cache invalidated after in-place value change, got: %v", second) + } +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 47339661..52a07b01 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -83,13 +83,14 @@ type AdminService interface { // CreateUserInput represents input for creating a new user via admin operations. type CreateUserInput struct { - Email string - Password string - Username string - Notes string - Balance float64 - Concurrency int - AllowedGroups []int64 + Email string + Password string + Username string + Notes string + Balance float64 + Concurrency int + AllowedGroups []int64 + SoraStorageQuotaBytes int64 } type UpdateUserInput struct { @@ -103,7 +104,8 @@ type UpdateUserInput struct { AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 // map[groupID]*rate,nil 表示删除该分组的专属倍率 - GroupRates map[int64]*float64 + GroupRates map[int64]*float64 + SoraStorageQuotaBytes *int64 } type CreateGroupInput struct { @@ -135,6 +137,8 @@ type CreateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes []string + // Sora 存储配额 + SoraStorageQuotaBytes int64 // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -169,6 +173,8 @@ type UpdateGroupInput struct { MCPXMLInject *bool // 支持的模型系列(仅 antigravity 平台使用) SupportedModelScopes *[]string + // Sora 存储配额 + SoraStorageQuotaBytes *int64 // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -402,6 +408,14 @@ type adminServiceImpl struct { 
authCacheInvalidator APIKeyAuthCacheInvalidator } +type userGroupRateBatchReader interface { + GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) +} + +type groupExistenceBatchReader interface { + ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error) +} + // NewAdminService creates a new AdminService func NewAdminService( userRepo UserRepository, @@ -442,18 +456,43 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi } // 批量加载用户专属分组倍率 if s.userGroupRateRepo != nil && len(users) > 0 { - for i := range users { - rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) - if err != nil { - logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) - continue + if batchRepo, ok := s.userGroupRateRepo.(userGroupRateBatchReader); ok { + userIDs := make([]int64, 0, len(users)) + for i := range users { + userIDs = append(userIDs, users[i].ID) } - users[i].GroupRates = rates + ratesByUser, err := batchRepo.GetByUserIDs(ctx, userIDs) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates in batch: err=%v", err) + s.loadUserGroupRatesOneByOne(ctx, users) + } else { + for i := range users { + if rates, ok := ratesByUser[users[i].ID]; ok { + users[i].GroupRates = rates + } + } + } + } else { + s.loadUserGroupRatesOneByOne(ctx, users) } } return users, result.Total, nil } +func (s *adminServiceImpl) loadUserGroupRatesOneByOne(ctx context.Context, users []User) { + if s.userGroupRateRepo == nil { + return + } + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } +} + func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { user, err := s.userRepo.GetByID(ctx, id) if err 
!= nil { @@ -473,14 +512,15 @@ func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { user := &User{ - Email: input.Email, - Username: input.Username, - Notes: input.Notes, - Role: RoleUser, // Always create as regular user, never admin - Balance: input.Balance, - Concurrency: input.Concurrency, - Status: StatusActive, - AllowedGroups: input.AllowedGroups, + Email: input.Email, + Username: input.Username, + Notes: input.Notes, + Role: RoleUser, // Always create as regular user, never admin + Balance: input.Balance, + Concurrency: input.Concurrency, + Status: StatusActive, + AllowedGroups: input.AllowedGroups, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := user.SetPassword(input.Password); err != nil { return nil, err @@ -534,6 +574,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.AllowedGroups = *input.AllowedGroups } + if input.SoraStorageQuotaBytes != nil { + user.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } + if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } @@ -820,6 +864,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn ModelRouting: input.ModelRouting, MCPXMLInject: mcpXMLInject, SupportedModelScopes: input.SupportedModelScopes, + SoraStorageQuotaBytes: input.SoraStorageQuotaBytes, } if err := s.groupRepo.Create(ctx, group); err != nil { return nil, err @@ -982,6 +1027,9 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.SoraVideoPricePerRequestHD != nil { group.SoraVideoPricePerRequestHD = normalizePrice(input.SoraVideoPricePerRequestHD) } + if input.SoraStorageQuotaBytes != nil { + group.SoraStorageQuotaBytes = *input.SoraStorageQuotaBytes + } // Claude Code 客户端限制 if input.ClaudeCodeOnly != nil { @@ -1188,6 +1236,18 @@ func (s *adminServiceImpl) CreateAccount(ctx 
context.Context, input *CreateAccou } } + // Sora apikey 账号的 base_url 必填校验 + if input.Platform == PlatformSora && input.Type == AccountTypeAPIKey { + baseURL, _ := input.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + account := &Account{ Name: input.Name, Notes: normalizeAccountNotes(input.Notes), @@ -1301,12 +1361,22 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U account.AutoPauseOnExpired = *input.AutoPauseOnExpired } + // Sora apikey 账号的 base_url 必填校验 + if account.Platform == PlatformSora && account.Type == AccountTypeAPIKey { + baseURL, _ := account.Credentials["base_url"].(string) + baseURL = strings.TrimSpace(baseURL) + if baseURL == "" { + return nil, errors.New("sora apikey 账号必须设置 base_url") + } + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + return nil, errors.New("base_url 必须以 http:// 或 https:// 开头") + } + } + // 先验证分组是否存在(在任何写操作之前) if input.GroupIDs != nil { - for _, groupID := range *input.GroupIDs { - if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { - return nil, fmt.Errorf("get group: %w", err) - } + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err } // 检查混合渠道风险(除非用户已确认) @@ -1348,11 +1418,18 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp if len(input.AccountIDs) == 0 { return result, nil } + if input.GroupIDs != nil { + if err := s.validateGroupIDsExist(ctx, *input.GroupIDs); err != nil { + return nil, err + } + } needMixedChannelCheck := input.GroupIDs != nil && !input.SkipMixedChannelCheck // 预加载账号平台信息(混合渠道检查或 Sora 同步需要)。 platformByID := map[int64]string{} + groupAccountsByID := map[int64][]Account{} + groupNameByID := 
map[int64]string{} if needMixedChannelCheck { accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs) if err != nil { @@ -1366,6 +1443,13 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } } } + + loadedAccounts, loadedNames, err := s.preloadMixedChannelRiskData(ctx, *input.GroupIDs) + if err != nil { + return nil, err + } + groupAccountsByID = loadedAccounts + groupNameByID = loadedNames } if input.RateMultiplier != nil { @@ -1409,11 +1493,12 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp // Handle group bindings per account (requires individual operations). for _, accountID := range input.AccountIDs { entry := BulkUpdateAccountResult{AccountID: accountID} + platform := "" if input.GroupIDs != nil { // 检查混合渠道风险(除非用户已确认) if !input.SkipMixedChannelCheck { - platform := platformByID[accountID] + platform = platformByID[accountID] if platform == "" { account, err := s.accountRepo.GetByID(ctx, accountID) if err != nil { @@ -1426,7 +1511,7 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp } platform = account.Platform } - if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil { + if err := s.checkMixedChannelRiskWithPreloaded(accountID, platform, *input.GroupIDs, groupAccountsByID, groupNameByID); err != nil { entry.Success = false entry.Error = err.Error() result.Failed++ @@ -1444,6 +1529,9 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp result.Results = append(result.Results, entry) continue } + if !input.SkipMixedChannelCheck && platform != "" { + updateMixedChannelPreloadedAccounts(groupAccountsByID, *input.GroupIDs, accountID, platform) + } } entry.Success = true @@ -2115,6 +2203,135 @@ func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAcc return nil } +func (s *adminServiceImpl) preloadMixedChannelRiskData(ctx context.Context, groupIDs []int64) 
(map[int64][]Account, map[int64]string, error) { + accountsByGroup := make(map[int64][]Account) + groupNameByID := make(map[int64]string) + if len(groupIDs) == 0 { + return accountsByGroup, groupNameByID, nil + } + + seen := make(map[int64]struct{}, len(groupIDs)) + for _, groupID := range groupIDs { + if groupID <= 0 { + continue + } + if _, ok := seen[groupID]; ok { + continue + } + seen[groupID] = struct{}{} + + accounts, err := s.accountRepo.ListByGroup(ctx, groupID) + if err != nil { + return nil, nil, fmt.Errorf("get accounts in group %d: %w", groupID, err) + } + accountsByGroup[groupID] = accounts + + group, err := s.groupRepo.GetByID(ctx, groupID) + if err != nil { + continue + } + if group != nil { + groupNameByID[groupID] = group.Name + } + } + + return accountsByGroup, groupNameByID, nil +} + +func (s *adminServiceImpl) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + if s.groupRepo == nil { + return errors.New("group repository not configured") + } + + if batchReader, ok := s.groupRepo.(groupExistenceBatchReader); ok { + existsByID, err := batchReader.ExistsByIDs(ctx, groupIDs) + if err != nil { + return fmt.Errorf("check groups exists: %w", err) + } + for _, groupID := range groupIDs { + if groupID <= 0 || !existsByID[groupID] { + return fmt.Errorf("get group: %w", ErrGroupNotFound) + } + } + return nil + } + + for _, groupID := range groupIDs { + if _, err := s.groupRepo.GetByID(ctx, groupID); err != nil { + return fmt.Errorf("get group: %w", err) + } + } + return nil +} + +func (s *adminServiceImpl) checkMixedChannelRiskWithPreloaded(currentAccountID int64, currentAccountPlatform string, groupIDs []int64, accountsByGroup map[int64][]Account, groupNameByID map[int64]string) error { + currentPlatform := getAccountPlatform(currentAccountPlatform) + if currentPlatform == "" { + return nil + } + + for _, groupID := range groupIDs { + accounts := accountsByGroup[groupID] + for _, account := 
range accounts { + if currentAccountID > 0 && account.ID == currentAccountID { + continue + } + + otherPlatform := getAccountPlatform(account.Platform) + if otherPlatform == "" { + continue + } + + if currentPlatform != otherPlatform { + groupName := fmt.Sprintf("Group %d", groupID) + if name := strings.TrimSpace(groupNameByID[groupID]); name != "" { + groupName = name + } + + return &MixedChannelError{ + GroupID: groupID, + GroupName: groupName, + CurrentPlatform: currentPlatform, + OtherPlatform: otherPlatform, + } + } + } + } + + return nil +} + +func updateMixedChannelPreloadedAccounts(accountsByGroup map[int64][]Account, groupIDs []int64, accountID int64, platform string) { + if len(groupIDs) == 0 || accountID <= 0 || platform == "" { + return + } + for _, groupID := range groupIDs { + if groupID <= 0 { + continue + } + accounts := accountsByGroup[groupID] + found := false + for i := range accounts { + if accounts[i].ID != accountID { + continue + } + accounts[i].Platform = platform + found = true + break + } + if !found { + accounts = append(accounts, Account{ + ID: accountID, + Platform: platform, + }) + } + accountsByGroup[groupID] = accounts + } +} + // CheckMixedChannelRisk checks whether target groups contain mixed channels for the current account platform. 
func (s *adminServiceImpl) CheckMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error { return s.checkMixedChannelRisk(ctx, currentAccountID, currentAccountPlatform, groupIDs) diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index 0dccacbb..647a84a9 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -15,6 +15,7 @@ type accountRepoStubForBulkUpdate struct { bulkUpdateErr error bulkUpdateIDs []int64 bindGroupErrByID map[int64]error + bindGroupsCalls []int64 getByIDsAccounts []*Account getByIDsErr error getByIDsCalled bool @@ -22,6 +23,8 @@ type accountRepoStubForBulkUpdate struct { getByIDAccounts map[int64]*Account getByIDErrByID map[int64]error getByIDCalled []int64 + listByGroupData map[int64][]Account + listByGroupErr map[int64]error } func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) { @@ -33,6 +36,7 @@ func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64 } func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error { + s.bindGroupsCalls = append(s.bindGroupsCalls, accountID) if err, ok := s.bindGroupErrByID[accountID]; ok { return err } @@ -59,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac return nil, errors.New("account not found") } +func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { + if err, ok := s.listByGroupErr[groupID]; ok { + return nil, err + } + if rows, ok := s.listByGroupData[groupID]; ok { + return rows, nil + } + return nil, nil +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t 
*testing.T) { repo := &accountRepoStubForBulkUpdate{} @@ -86,7 +100,10 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { 2: errors.New("bind failed"), }, } - svc := &adminServiceImpl{accountRepo: repo} + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}}, + } groupIDs := []int64{10} schedulable := false @@ -105,3 +122,51 @@ func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) { require.ElementsMatch(t, []int64{2}, result.FailedIDs) require.Len(t, result.Results, 3) } + +func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{} + svc := &adminServiceImpl{accountRepo: repo} + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.Nil(t, result) + require.Error(t, err) + require.Contains(t, err.Error(), "group repository not configured") +} + +func TestAdminService_BulkUpdateAccounts_MixedChannelCheckUsesUpdatedSnapshot(t *testing.T) { + repo := &accountRepoStubForBulkUpdate{ + getByIDsAccounts: []*Account{ + {ID: 1, Platform: PlatformAnthropic}, + {ID: 2, Platform: PlatformAntigravity}, + }, + listByGroupData: map[int64][]Account{ + 10: {}, + }, + } + svc := &adminServiceImpl{ + accountRepo: repo, + groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "目标分组"}}, + } + + groupIDs := []int64{10} + input := &BulkUpdateAccountsInput{ + AccountIDs: []int64{1, 2}, + GroupIDs: &groupIDs, + } + + result, err := svc.BulkUpdateAccounts(context.Background(), input) + require.NoError(t, err) + require.Equal(t, 1, result.Success) + require.Equal(t, 1, result.Failed) + require.ElementsMatch(t, []int64{1}, result.SuccessIDs) + require.ElementsMatch(t, []int64{2}, result.FailedIDs) + require.Len(t, result.Results, 2) + require.Contains(t, 
result.Results[1].Error, "mixed channel") + require.Equal(t, []int64{1}, repo.bindGroupsCalls) +} diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go new file mode 100644 index 00000000..8b50530a --- /dev/null +++ b/backend/internal/service/admin_service_list_users_test.go @@ -0,0 +1,106 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type userRepoStubForListUsers struct { + userRepoStub + users []User + err error +} + +func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) { + if s.err != nil { + return nil, nil, s.err + } + out := make([]User, len(s.users)) + copy(out, s.users) + return out, &pagination.PaginationResult{ + Total: int64(len(out)), + Page: params.Page, + PageSize: params.PageSize, + }, nil +} + +type userGroupRateRepoStubForListUsers struct { + batchCalls int + singleCall []int64 + + batchErr error + batchData map[int64]map[int64]float64 + + singleErr map[int64]error + singleData map[int64]map[int64]float64 +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) { + s.batchCalls++ + if s.batchErr != nil { + return nil, s.batchErr + } + return s.batchData, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) { + s.singleCall = append(s.singleCall, userID) + if err, ok := s.singleErr[userID]; ok { + return nil, err + } + if rates, ok := s.singleData[userID]; ok { + return rates, nil + } + return map[int64]float64{}, nil +} + +func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) { + panic("unexpected 
GetByUserAndGroup call") +} + +func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { + panic("unexpected SyncUserGroupRates call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, groupID int64) error { + panic("unexpected DeleteByGroupID call") +} + +func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error { + panic("unexpected DeleteByUserID call") +} + +func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) { + userRepo := &userRepoStubForListUsers{ + users: []User{ + {ID: 101, Username: "u1"}, + {ID: 202, Username: "u2"}, + }, + } + rateRepo := &userGroupRateRepoStubForListUsers{ + batchErr: errors.New("batch unavailable"), + singleData: map[int64]map[int64]float64{ + 101: {11: 1.1}, + 202: {22: 2.2}, + }, + } + svc := &adminServiceImpl{ + userRepo: userRepo, + userGroupRateRepo: rateRepo, + } + + users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{}) + require.NoError(t, err) + require.Equal(t, int64(2), total) + require.Len(t, users, 2) + require.Equal(t, 1, rateRepo.batchCalls) + require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall) + require.Equal(t, 1.1, users[0].GroupRates[11]) + require.Equal(t, 2.2, users[1].GroupRates[22]) +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 2bd6195a..96ff3354 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -21,7 +21,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" "github.com/google/uuid" @@ -2291,7 +2290,7 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool { // 
isSingleAccountRetry 检查 context 中是否设置了单账号退避重试标记 func isSingleAccountRetry(ctx context.Context) bool { - v, _ := ctx.Value(ctxkey.SingleAccountRetry).(bool) + v, _ := SingleAccountRetryFromContext(ctx) return v } diff --git a/backend/internal/service/api_key.go b/backend/internal/service/api_key.go index fe1b3a5d..07523597 100644 --- a/backend/internal/service/api_key.go +++ b/backend/internal/service/api_key.go @@ -1,6 +1,10 @@ package service -import "time" +import ( + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ip" +) // API Key status constants const ( @@ -19,11 +23,14 @@ type APIKey struct { Status string IPWhitelist []string IPBlacklist []string - LastUsedAt *time.Time - CreatedAt time.Time - UpdatedAt time.Time - User *User - Group *Group + // 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。 + CompiledIPWhitelist *ip.CompiledIPRules `json:"-"` + CompiledIPBlacklist *ip.CompiledIPRules `json:"-"` + LastUsedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time + User *User + Group *Group // Quota fields Quota float64 // Quota limit in USD (0 = unlimited) diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 77a75674..30eb8d74 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -298,5 +298,6 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho SupportedModelScopes: snapshot.Group.SupportedModelScopes, } } + s.compileAPIKeyIPRules(apiKey) return apiKey } diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index c5e1cfab..0d073077 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -158,6 +158,14 @@ func NewAPIKeyService( return svc } +func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) { + if apiKey == nil { + return + } + apiKey.CompiledIPWhitelist = 
ip.CompileIPRules(apiKey.IPWhitelist) + apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist) +} + // GenerateKey 生成随机API Key func (s *APIKeyService) GenerateKey() (string, error) { // 生成32字节随机数据 @@ -332,6 +340,7 @@ func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIK } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -363,6 +372,7 @@ func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -375,6 +385,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -391,6 +402,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } else { @@ -402,6 +414,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro if err != nil { return nil, fmt.Errorf("get api key: %w", err) } + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } } @@ -411,6 +424,7 @@ func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, erro return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } @@ -510,6 +524,7 @@ func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req } s.InvalidateAuthCacheByKey(ctx, apiKey.Key) + s.compileAPIKeyIPRules(apiKey) return apiKey, nil } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 73f59dd0..eae7bd53 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -308,6 +308,17 @@ func (s *AuthService) 
SendVerifyCodeAsync(ctx context.Context, email string) (*S }, nil } +// VerifyTurnstileForRegister 在注册场景下验证 Turnstile。 +// 当邮箱验证开启且已提交验证码时,说明验证码发送阶段已完成 Turnstile 校验, +// 此处跳过二次校验,避免一次性 token 在注册提交时重复使用导致误报失败。 +func (s *AuthService) VerifyTurnstileForRegister(ctx context.Context, token, remoteIP, verifyCode string) error { + if s.IsEmailVerifyEnabled(ctx) && strings.TrimSpace(verifyCode) != "" { + logger.LegacyPrintf("service.auth", "%s", "[Auth] Email verify flow detected, skip duplicate Turnstile check on register") + return nil + } + return s.VerifyTurnstile(ctx, token, remoteIP) +} + // VerifyTurnstile 验证Turnstile token func (s *AuthService) VerifyTurnstile(ctx context.Context, token string, remoteIP string) error { required := s.cfg != nil && s.cfg.Server.Mode == "release" && s.cfg.Turnstile.Required diff --git a/backend/internal/service/auth_service_turnstile_register_test.go b/backend/internal/service/auth_service_turnstile_register_test.go new file mode 100644 index 00000000..7dd9edca --- /dev/null +++ b/backend/internal/service/auth_service_turnstile_register_test.go @@ -0,0 +1,96 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type turnstileVerifierSpy struct { + called int + lastToken string + result *TurnstileVerifyResponse + err error +} + +func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) { + s.called++ + s.lastToken = token + if s.err != nil { + return nil, s.err + } + if s.result != nil { + return s.result, nil + } + return &TurnstileVerifyResponse{Success: true}, nil +} + +func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService { + cfg := &config.Config{ + Server: config.ServerConfig{ + Mode: "release", + }, + Turnstile: config.TurnstileConfig{ + Required: true, + }, + } + + settingService := 
NewSettingService(&settingRepoStub{values: settings}, cfg) + turnstileService := NewTurnstileService(settingService, verifier) + + return NewAuthService( + &userRepoStub{}, + nil, // redeemRepo + nil, // refreshTokenCache + cfg, + settingService, + nil, // emailService + turnstileService, + nil, // emailQueueService + nil, // promoService + ) +} + +func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + SettingKeyRegistrationEnabled: "true", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 0, verifier.called) +} + +func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "true", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "") + require.ErrorIs(t, err, ErrTurnstileVerificationFailed) +} + +func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) { + verifier := &turnstileVerifierSpy{} + service := newAuthServiceForRegisterTurnstileTest(map[string]string{ + SettingKeyEmailVerifyEnabled: "false", + SettingKeyTurnstileEnabled: "true", + SettingKeyTurnstileSecretKey: "secret", + }, verifier) + + err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456") + require.NoError(t, err) + require.Equal(t, 1, verifier.called) + require.Equal(t, "turnstile-token", verifier.lastToken) +} diff --git 
a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index a560930b..1a76f5f6 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -3,6 +3,7 @@ package service import ( "context" "fmt" + "strconv" "sync" "sync/atomic" "time" @@ -10,6 +11,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "golang.org/x/sync/singleflight" ) // 错误定义 @@ -58,6 +60,7 @@ const ( cacheWriteBufferSize = 1000 // 任务队列缓冲大小 cacheWriteTimeout = 2 * time.Second // 单个写入操作超时 cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔 + balanceLoadTimeout = 3 * time.Second ) // cacheWriteTask 缓存写入任务 @@ -82,6 +85,9 @@ type BillingCacheService struct { cacheWriteChan chan cacheWriteTask cacheWriteWg sync.WaitGroup cacheWriteStopOnce sync.Once + cacheWriteMu sync.RWMutex + stopped atomic.Bool + balanceLoadSF singleflight.Group // 丢弃日志节流计数器(减少高负载下日志噪音) cacheWriteDropFullCount uint64 cacheWriteDropFullLastLog int64 @@ -105,35 +111,52 @@ func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo // Stop 关闭缓存写入工作池 func (s *BillingCacheService) Stop() { s.cacheWriteStopOnce.Do(func() { - if s.cacheWriteChan == nil { + s.stopped.Store(true) + + s.cacheWriteMu.Lock() + ch := s.cacheWriteChan + if ch != nil { + close(ch) + } + s.cacheWriteMu.Unlock() + + if ch == nil { return } - close(s.cacheWriteChan) s.cacheWriteWg.Wait() - s.cacheWriteChan = nil + + s.cacheWriteMu.Lock() + if s.cacheWriteChan == ch { + s.cacheWriteChan = nil + } + s.cacheWriteMu.Unlock() }) } func (s *BillingCacheService) startCacheWriteWorkers() { - s.cacheWriteChan = make(chan cacheWriteTask, cacheWriteBufferSize) + ch := make(chan cacheWriteTask, cacheWriteBufferSize) + s.cacheWriteChan = ch for i := 0; i < cacheWriteWorkerCount; i++ { s.cacheWriteWg.Add(1) - go 
s.cacheWriteWorker() + go s.cacheWriteWorker(ch) } } // enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。 func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) { - if s.cacheWriteChan == nil { + if s.stopped.Load() { + s.logCacheWriteDrop(task, "closed") return false } - defer func() { - if recovered := recover(); recovered != nil { - // 队列已关闭时可能触发 panic,记录后静默失败。 - s.logCacheWriteDrop(task, "closed") - enqueued = false - } - }() + + s.cacheWriteMu.RLock() + defer s.cacheWriteMu.RUnlock() + + if s.cacheWriteChan == nil { + s.logCacheWriteDrop(task, "closed") + return false + } + select { case s.cacheWriteChan <- task: return true @@ -144,9 +167,9 @@ func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued b } } -func (s *BillingCacheService) cacheWriteWorker() { +func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) { defer s.cacheWriteWg.Done() - for task := range s.cacheWriteChan { + for task := range ch { ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout) switch task.kind { case cacheWriteSetBalance: @@ -243,19 +266,31 @@ func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) return balance, nil } - // 缓存未命中,从数据库读取 - balance, err = s.getUserBalanceFromDB(ctx, userID) + // 缓存未命中:singleflight 合并同一 userID 的并发回源请求。 + value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) { + loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout) + defer cancel() + + balance, err := s.getUserBalanceFromDB(loadCtx, userID) + if err != nil { + return nil, err + } + + // 异步建立缓存 + _ = s.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteSetBalance, + userID: userID, + balance: balance, + }) + return balance, nil + }) if err != nil { return 0, err } - - // 异步建立缓存 - _ = s.enqueueCacheWrite(cacheWriteTask{ - kind: cacheWriteSetBalance, - userID: userID, - balance: balance, - }) - + balance, ok := value.(float64) + if 
!ok { + return 0, fmt.Errorf("unexpected balance type: %T", value) + } return balance, nil } diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go new file mode 100644 index 00000000..1b12c402 --- /dev/null +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -0,0 +1,115 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +type billingCacheMissStub struct { + setBalanceCalls atomic.Int64 +} + +func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) { + return 0, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error { + s.setBalanceCalls.Add(1) + return nil +} + +func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error { + return nil +} + +func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) { + return nil, errors.New("cache miss") +} + +func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error { + return nil +} + +func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error { + return nil +} + +func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error { + return nil +} + +type balanceLoadUserRepoStub struct { + mockUserRepo + calls atomic.Int64 + delay time.Duration + balance float64 +} + +func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, 
id int64) (*User, error) { + s.calls.Add(1) + if s.delay > 0 { + select { + case <-time.After(s.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return &User{ID: id, Balance: s.balance}, nil +} + +func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { + cache := &billingCacheMissStub{} + userRepo := &balanceLoadUserRepoStub{ + delay: 80 * time.Millisecond, + balance: 12.34, + } + svc := NewBillingCacheService(cache, userRepo, nil, &config.Config{}) + t.Cleanup(svc.Stop) + + const goroutines = 16 + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, goroutines) + balCh := make(chan float64, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + bal, err := svc.GetUserBalance(context.Background(), 99) + errCh <- err + balCh <- bal + }() + } + + close(start) + wg.Wait() + close(errCh) + close(balCh) + + for err := range errCh { + require.NoError(t, err) + } + for bal := range balCh { + require.Equal(t, 12.34, bal) + } + + require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并") + require.Eventually(t, func() bool { + return cache.setBalanceCalls.Load() >= 1 + }, time.Second, 10*time.Millisecond) +} diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 445d5319..4e5f50e2 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -73,3 +73,16 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { return atomic.LoadInt64(&cache.subscriptionUpdates) > 0 }, 2*time.Second, 10*time.Millisecond) } + +func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { + cache := &billingCacheWorkerStub{} + svc := NewBillingCacheService(cache, nil, nil, &config.Config{}) + svc.Stop() + + enqueued := svc.enqueueCacheWrite(cacheWriteTask{ + kind: cacheWriteDeductBalance, + userID: 1, + amount: 
1, + }) + require.False(t, enqueued) +} diff --git a/backend/internal/service/billing_service_image_test.go b/backend/internal/service/billing_service_image_test.go index 59125814..fa90f6bb 100644 --- a/backend/internal/service/billing_service_image_test.go +++ b/backend/internal/service/billing_service_image_test.go @@ -63,7 +63,7 @@ func TestCalculateImageCost_RateMultiplier(t *testing.T) { // 费率倍数 1.5x cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5) - require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 + require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5 require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5 // 费率倍数 2.0x diff --git a/backend/internal/service/claude_code_validator.go b/backend/internal/service/claude_code_validator.go index 6d06c83e..d3a4d119 100644 --- a/backend/internal/service/claude_code_validator.go +++ b/backend/internal/service/claude_code_validator.go @@ -78,7 +78,7 @@ func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) boo // Step 3: 检查 max_tokens=1 + haiku 探测请求绕过 // 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt - if isMaxTokensOneHaiku, ok := r.Context().Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok && isMaxTokensOneHaiku { + if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku { return true // 绕过 system prompt 检查,UA 已在 Step 1 验证 } diff --git a/backend/internal/service/concurrency_service.go b/backend/internal/service/concurrency_service.go index 32b6d97c..4dcf84e0 100644 --- a/backend/internal/service/concurrency_service.go +++ b/backend/internal/service/concurrency_service.go @@ -3,8 +3,10 @@ package service import ( "context" "crypto/rand" - "encoding/hex" - "fmt" + "encoding/binary" + "os" + "strconv" + "sync/atomic" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" @@ -18,6 +20,7 @@ type ConcurrencyCache interface { AcquireAccountSlot(ctx 
context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) + GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) // 账号等待队列(账号级) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) @@ -42,15 +45,25 @@ type ConcurrencyCache interface { CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error } -// generateRequestID generates a unique request ID for concurrency slot tracking -// Uses 8 random bytes (16 hex chars) for uniqueness -func generateRequestID() string { +var ( + requestIDPrefix = initRequestIDPrefix() + requestIDCounter atomic.Uint64 +) + +func initRequestIDPrefix() string { b := make([]byte, 8) - if _, err := rand.Read(b); err != nil { - // Fallback to nanosecond timestamp (extremely rare case) - return fmt.Sprintf("%x", time.Now().UnixNano()) + if _, err := rand.Read(b); err == nil { + return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36) } - return hex.EncodeToString(b) + fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16) + return "r" + strconv.FormatUint(fallback, 36) +} + +// generateRequestID generates a unique request ID for concurrency slot tracking. 
+// Format: {process_random_prefix}-{base36_counter} +func generateRequestID() string { + seq := requestIDCounter.Add(1) + return requestIDPrefix + "-" + strconv.FormatUint(seq, 36) } const ( @@ -321,16 +334,15 @@ func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepositor // GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts // Returns a map of accountID -> current concurrency count func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { - result := make(map[int64]int) - - for _, accountID := range accountIDs { - count, err := s.cache.GetAccountConcurrency(ctx, accountID) - if err != nil { - // If key doesn't exist in Redis, count is 0 - count = 0 - } - result[accountID] = count + if len(accountIDs) == 0 { + return map[int64]int{}, nil } - - return result, nil + if s.cache == nil { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil + } + return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs) } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 33ce4cb9..9ba43d93 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -5,6 +5,8 @@ package service import ( "context" "errors" + "strconv" + "strings" "testing" "github.com/stretchr/testify/require" @@ -12,20 +14,20 @@ import ( // stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩 type stubConcurrencyCacheForTest struct { - acquireResult bool - acquireErr error - releaseErr error - concurrency int + acquireResult bool + acquireErr error + releaseErr error + concurrency int concurrencyErr error - waitAllowed bool - waitErr error - waitCount int - waitCountErr error - loadBatch map[int64]*AccountLoadInfo - loadBatchErr error + waitAllowed bool + waitErr error + waitCount int + 
waitCountErr error + loadBatch map[int64]*AccountLoadInfo + loadBatchErr error usersLoadBatch map[int64]*UserLoadInfo usersLoadErr error - cleanupErr error + cleanupErr error // 记录调用 releasedAccountIDs []int64 @@ -45,6 +47,16 @@ func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, acco func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) { return c.concurrency, c.concurrencyErr } +func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + if c.concurrencyErr != nil { + return nil, c.concurrencyErr + } + result[accountID] = c.concurrency + } + return result, nil +} func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) { return c.waitAllowed, c.waitErr } @@ -155,6 +167,25 @@ func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) { require.True(t, result.Acquired) } +func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) { + id1 := generateRequestID() + id2 := generateRequestID() + require.NotEmpty(t, id1) + require.NotEmpty(t, id2) + + p1 := strings.Split(id1, "-") + p2 := strings.Split(id2, "-") + require.Len(t, p1, 2) + require.Len(t, p2, 2) + require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致") + + n1, err := strconv.ParseUint(p1[1], 36, 64) + require.NoError(t, err) + n2, err := strconv.ParseUint(p2[1], 36, 64) + require.NoError(t, err) + require.Equal(t, n1+1, n2, "计数器应单调递增") +} + func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) { expected := map[int64]*AccountLoadInfo{ 1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60}, diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index 9aab10d2..4528def3 100644 --- a/backend/internal/service/dashboard_service.go +++ 
b/backend/internal/service/dashboard_service.go @@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return stats, nil } -func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { - trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) +func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get usage trend with filters: %w", err) } return trend, nil } -func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) +func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType) if err != nil { return nil, fmt.Errorf("get model stats with filters: %w", err) } diff --git 
a/backend/internal/service/data_management_grpc.go b/backend/internal/service/data_management_grpc.go new file mode 100644 index 00000000..aeb3d529 --- /dev/null +++ b/backend/internal/service/data_management_grpc.go @@ -0,0 +1,252 @@ +package service + +import "context" + +type DataManagementPostgresConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + Database string `json:"database"` + SSLMode string `json:"ssl_mode"` + ContainerName string `json:"container_name"` +} + +type DataManagementRedisConfig struct { + Addr string `json:"addr"` + Username string `json:"username"` + Password string `json:"password,omitempty"` + PasswordConfigured bool `json:"password_configured"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementS3Config struct { + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key,omitempty"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + UseSSL bool `json:"use_ssl"` +} + +type DataManagementConfig struct { + SourceMode string `json:"source_mode"` + BackupRoot string `json:"backup_root"` + SQLitePath string `json:"sqlite_path,omitempty"` + RetentionDays int32 `json:"retention_days"` + KeepLast int32 `json:"keep_last"` + ActivePostgresID string `json:"active_postgres_profile_id"` + ActiveRedisID string `json:"active_redis_profile_id"` + Postgres DataManagementPostgresConfig `json:"postgres"` + Redis DataManagementRedisConfig `json:"redis"` + S3 DataManagementS3Config `json:"s3"` + ActiveS3ProfileID string `json:"active_s3_profile_id"` +} + +type DataManagementTestS3Result struct { + OK bool 
`json:"ok"` + Message string `json:"message"` +} + +type DataManagementCreateBackupJobInput struct { + BackupType string + UploadToS3 bool + TriggeredBy string + IdempotencyKey string + S3ProfileID string + PostgresID string + RedisID string +} + +type DataManagementListBackupJobsInput struct { + PageSize int32 + PageToken string + Status string + BackupType string +} + +type DataManagementArtifactInfo struct { + LocalPath string `json:"local_path"` + SizeBytes int64 `json:"size_bytes"` + SHA256 string `json:"sha256"` +} + +type DataManagementS3ObjectInfo struct { + Bucket string `json:"bucket"` + Key string `json:"key"` + ETag string `json:"etag"` +} + +type DataManagementBackupJob struct { + JobID string `json:"job_id"` + BackupType string `json:"backup_type"` + Status string `json:"status"` + TriggeredBy string `json:"triggered_by"` + IdempotencyKey string `json:"idempotency_key,omitempty"` + UploadToS3 bool `json:"upload_to_s3"` + S3ProfileID string `json:"s3_profile_id,omitempty"` + PostgresID string `json:"postgres_profile_id,omitempty"` + RedisID string `json:"redis_profile_id,omitempty"` + StartedAt string `json:"started_at,omitempty"` + FinishedAt string `json:"finished_at,omitempty"` + ErrorMessage string `json:"error_message,omitempty"` + Artifact DataManagementArtifactInfo `json:"artifact"` + S3Object DataManagementS3ObjectInfo `json:"s3"` +} + +type DataManagementSourceProfile struct { + SourceType string `json:"source_type"` + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Config DataManagementSourceConfig `json:"config"` + PasswordConfigured bool `json:"password_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementSourceConfig struct { + Host string `json:"host"` + Port int32 `json:"port"` + User string `json:"user"` + Password string `json:"password,omitempty"` + Database string `json:"database"` + SSLMode string 
`json:"ssl_mode"` + Addr string `json:"addr"` + Username string `json:"username"` + DB int32 `json:"db"` + ContainerName string `json:"container_name"` +} + +type DataManagementCreateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig + SetActive bool +} + +type DataManagementUpdateSourceProfileInput struct { + SourceType string + ProfileID string + Name string + Config DataManagementSourceConfig +} + +type DataManagementS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + S3 DataManagementS3Config `json:"s3"` + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` + CreatedAt string `json:"created_at,omitempty"` + UpdatedAt string `json:"updated_at,omitempty"` +} + +type DataManagementCreateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config + SetActive bool +} + +type DataManagementUpdateS3ProfileInput struct { + ProfileID string + Name string + S3 DataManagementS3Config +} + +type DataManagementListBackupJobsResult struct { + Items []DataManagementBackupJob `json:"items"` + NextPageToken string `json:"next_page_token,omitempty"` +} + +func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) { + _ = ctx + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) { + _, _ = ctx, cfg + return DataManagementConfig{}, s.deprecatedError() +} + +func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) { + _, _ = ctx, sourceType + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return 
DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) { + _, _ = ctx, input + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error { + _, _, _ = ctx, sourceType, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) { + _, _, _ = ctx, sourceType, profileID + return DataManagementSourceProfile{}, s.deprecatedError() +} + +func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) { + _, _ = ctx, cfg + return DataManagementTestS3Result{}, s.deprecatedError() +} + +func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) { + _ = ctx + return nil, s.deprecatedError() +} + +func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) { + _, _ = ctx, input + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error { + _, _ = ctx, profileID + return s.deprecatedError() +} + +func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) { + _, _ = ctx, profileID + return DataManagementS3Profile{}, s.deprecatedError() +} + +func (s *DataManagementService) CreateBackupJob(ctx context.Context, input 
DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) { + _, _ = ctx, input + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) { + _, _ = ctx, input + return DataManagementListBackupJobsResult{}, s.deprecatedError() +} + +func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) { + _, _ = ctx, jobID + return DataManagementBackupJob{}, s.deprecatedError() +} + +func (s *DataManagementService) deprecatedError() error { + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_grpc_test.go b/backend/internal/service/data_management_grpc_test.go new file mode 100644 index 00000000..286eb58d --- /dev/null +++ b/backend/internal/service/data_management_grpc_test.go @@ -0,0 +1,36 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_DeprecatedRPCMethods(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "datamanagement.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + + _, err := svc.GetConfig(context.Background()) + assertDeprecatedDataManagementError(t, err, socketPath) + + _, err = svc.CreateBackupJob(context.Background(), DataManagementCreateBackupJobInput{BackupType: "full"}) + assertDeprecatedDataManagementError(t, err, socketPath) + + err = svc.DeleteS3Profile(context.Background(), "s3-default") + assertDeprecatedDataManagementError(t, err, socketPath) +} + +func assertDeprecatedDataManagementError(t *testing.T, err error, socketPath string) { + t.Helper() + + require.Error(t, err) + statusCode, status := infraerrors.ToHTTP(err) + 
require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/data_management_service.go b/backend/internal/service/data_management_service.go new file mode 100644 index 00000000..83e939f4 --- /dev/null +++ b/backend/internal/service/data_management_service.go @@ -0,0 +1,99 @@ +package service + +import ( + "context" + "strings" + "time" + + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" +) + +const ( + DefaultDataManagementAgentSocketPath = "/tmp/sub2api-datamanagement.sock" + LegacyBackupAgentSocketPath = "/tmp/sub2api-backup.sock" + + DataManagementDeprecatedReason = "DATA_MANAGEMENT_DEPRECATED" + DataManagementAgentSocketMissingReason = "DATA_MANAGEMENT_AGENT_SOCKET_MISSING" + DataManagementAgentUnavailableReason = "DATA_MANAGEMENT_AGENT_UNAVAILABLE" + + // Deprecated: keep old names for compatibility. + DefaultBackupAgentSocketPath = DefaultDataManagementAgentSocketPath + BackupAgentSocketMissingReason = DataManagementAgentSocketMissingReason + BackupAgentUnavailableReason = DataManagementAgentUnavailableReason +) + +var ( + ErrDataManagementDeprecated = infraerrors.ServiceUnavailable( + DataManagementDeprecatedReason, + "data management feature is deprecated", + ) + ErrDataManagementAgentSocketMissing = infraerrors.ServiceUnavailable( + DataManagementAgentSocketMissingReason, + "data management agent socket is missing", + ) + ErrDataManagementAgentUnavailable = infraerrors.ServiceUnavailable( + DataManagementAgentUnavailableReason, + "data management agent is unavailable", + ) + + // Deprecated: keep old names for compatibility. 
+ ErrBackupAgentSocketMissing = ErrDataManagementAgentSocketMissing + ErrBackupAgentUnavailable = ErrDataManagementAgentUnavailable +) + +type DataManagementAgentHealth struct { + Enabled bool + Reason string + SocketPath string + Agent *DataManagementAgentInfo +} + +type DataManagementAgentInfo struct { + Status string + Version string + UptimeSeconds int64 +} + +type DataManagementService struct { + socketPath string + dialTimeout time.Duration +} + +func NewDataManagementService() *DataManagementService { + return NewDataManagementServiceWithOptions(DefaultDataManagementAgentSocketPath, 500*time.Millisecond) +} + +func NewDataManagementServiceWithOptions(socketPath string, dialTimeout time.Duration) *DataManagementService { + path := strings.TrimSpace(socketPath) + if path == "" { + path = DefaultDataManagementAgentSocketPath + } + if dialTimeout <= 0 { + dialTimeout = 500 * time.Millisecond + } + return &DataManagementService{ + socketPath: path, + dialTimeout: dialTimeout, + } +} + +func (s *DataManagementService) SocketPath() string { + if s == nil || strings.TrimSpace(s.socketPath) == "" { + return DefaultDataManagementAgentSocketPath + } + return s.socketPath +} + +func (s *DataManagementService) GetAgentHealth(ctx context.Context) DataManagementAgentHealth { + _ = ctx + return DataManagementAgentHealth{ + Enabled: false, + Reason: DataManagementDeprecatedReason, + SocketPath: s.SocketPath(), + } +} + +func (s *DataManagementService) EnsureAgentEnabled(ctx context.Context) error { + _ = ctx + return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()}) +} diff --git a/backend/internal/service/data_management_service_test.go b/backend/internal/service/data_management_service_test.go new file mode 100644 index 00000000..65489d2e --- /dev/null +++ b/backend/internal/service/data_management_service_test.go @@ -0,0 +1,37 @@ +package service + +import ( + "context" + "path/filepath" + "testing" + + infraerrors 
"github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/stretchr/testify/require" +) + +func TestDataManagementService_GetAgentHealth_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 0) + health := svc.GetAgentHealth(context.Background()) + + require.False(t, health.Enabled) + require.Equal(t, DataManagementDeprecatedReason, health.Reason) + require.Equal(t, socketPath, health.SocketPath) + require.Nil(t, health.Agent) +} + +func TestDataManagementService_EnsureAgentEnabled_Deprecated(t *testing.T) { + t.Parallel() + + socketPath := filepath.Join(t.TempDir(), "unused.sock") + svc := NewDataManagementServiceWithOptions(socketPath, 100) + err := svc.EnsureAgentEnabled(context.Background()) + require.Error(t, err) + + statusCode, status := infraerrors.ToHTTP(err) + require.Equal(t, 503, statusCode) + require.Equal(t, DataManagementDeprecatedReason, status.Reason) + require.Equal(t, socketPath, status.Metadata["socket_path"]) +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index ceae443f..cc1a2721 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -104,6 +104,7 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 + SettingKeySoraClientEnabled = "sora_client_enabled" // 是否启用 Sora 客户端(管理员手动控制) SettingKeySiteName = "site_name" // 网站名称 SettingKeySiteLogo = "site_logo" // 网站Logo (base64) SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 @@ -170,6 +171,27 @@ const ( // SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling. 
SettingKeyStreamTimeoutSettings = "stream_timeout_settings" + + // ========================= + // Sora S3 存储配置 + // ========================= + + SettingKeySoraS3Enabled = "sora_s3_enabled" // 是否启用 Sora S3 存储 + SettingKeySoraS3Endpoint = "sora_s3_endpoint" // S3 端点地址 + SettingKeySoraS3Region = "sora_s3_region" // S3 区域 + SettingKeySoraS3Bucket = "sora_s3_bucket" // S3 存储桶名称 + SettingKeySoraS3AccessKeyID = "sora_s3_access_key_id" // S3 Access Key ID + SettingKeySoraS3SecretAccessKey = "sora_s3_secret_access_key" // S3 Secret Access Key(加密存储) + SettingKeySoraS3Prefix = "sora_s3_prefix" // S3 对象键前缀 + SettingKeySoraS3ForcePathStyle = "sora_s3_force_path_style" // 是否强制 Path Style(兼容 MinIO 等) + SettingKeySoraS3CDNURL = "sora_s3_cdn_url" // CDN 加速 URL(可选) + SettingKeySoraS3Profiles = "sora_s3_profiles" // Sora S3 多配置(JSON) + + // ========================= + // Sora 用户存储配额 + // ========================= + + SettingKeySoraDefaultStorageQuotaBytes = "sora_default_storage_quota_bytes" // 新用户默认 Sora 存储配额(字节) ) // AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys). 
diff --git a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go index e3dff6b8..f8c0ecda 100644 --- a/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go +++ b/backend/internal/service/gateway_anthropic_apikey_passthrough_test.go @@ -279,10 +279,10 @@ func TestGatewayService_AnthropicAPIKeyPassthrough_CountTokens404PassthroughNotE wantPassthrough: true, }, { - name: "404 generic not found passes through as 404", + name: "404 generic not found does not passthrough", statusCode: http.StatusNotFound, respBody: `{"error":{"message":"resource not found","type":"not_found_error"}}`, - wantPassthrough: true, + wantPassthrough: false, }, { name: "400 Invalid URL does not passthrough", diff --git a/backend/internal/service/gateway_beta_test.go b/backend/internal/service/gateway_beta_test.go index c682e286..21a1faa4 100644 --- a/backend/internal/service/gateway_beta_test.go +++ b/backend/internal/service/gateway_beta_test.go @@ -136,3 +136,67 @@ func TestDroppedBetaSet(t *testing.T) { require.Contains(t, extended, claude.BetaClaudeCode) require.Len(t, extended, len(claude.DroppedBetas)+1) } + +func TestBuildBetaTokenSet(t *testing.T) { + got := buildBetaTokenSet([]string{"foo", "", "bar", "foo"}) + require.Len(t, got, 2) + require.Contains(t, got, "foo") + require.Contains(t, got, "bar") + require.NotContains(t, got, "") + + empty := buildBetaTokenSet(nil) + require.Empty(t, empty) +} + +func TestStripBetaTokensWithSet_EmptyDropSet(t *testing.T) { + header := "oauth-2025-04-20,interleaved-thinking-2025-05-14" + got := stripBetaTokensWithSet(header, map[string]struct{}{}) + require.Equal(t, header, got) +} + +func TestIsCountTokensUnsupported404(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + want bool + }{ + { + name: "exact endpoint not found", + statusCode: 404, + body: `{"error":{"message":"Not found: 
/v1/messages/count_tokens","type":"not_found_error"}}`, + want: true, + }, + { + name: "contains count_tokens and not found", + statusCode: 404, + body: `{"error":{"message":"count_tokens route not found","type":"not_found_error"}}`, + want: true, + }, + { + name: "generic 404", + statusCode: 404, + body: `{"error":{"message":"resource not found","type":"not_found_error"}}`, + want: false, + }, + { + name: "404 with empty error message", + statusCode: 404, + body: `{"error":{"message":"","type":"not_found_error"}}`, + want: false, + }, + { + name: "non-404 status", + statusCode: 400, + body: `{"error":{"message":"Not found: /v1/messages/count_tokens","type":"invalid_request_error"}}`, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isCountTokensUnsupported404(tt.statusCode, []byte(tt.body)) + require.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 5055eec0..067a0e08 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -1892,6 +1892,14 @@ func (m *mockConcurrencyCache) GetAccountConcurrency(ctx context.Context, accoun return 0, nil } +func (m *mockConcurrencyCache) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, accountID := range accountIDs { + result[accountID] = 0 + } + return result, nil +} + func (m *mockConcurrencyCache) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { return true, nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3fabead0..0ba9e093 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -127,13 +127,26 @@ func WithForceCacheBilling(ctx 
context.Context) context.Context { } func (s *GatewayService) debugModelRoutingEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugModelRouting.Load() } func (s *GatewayService) debugClaudeMimicEnabled() bool { - v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) - return v == "1" || v == "true" || v == "yes" || v == "on" + if s == nil { + return false + } + return s.debugClaudeMimic.Load() +} + +func parseDebugEnvBool(raw string) bool { + switch strings.ToLower(strings.TrimSpace(raw)) { + case "1", "true", "yes", "on": + return true + default: + return false + } } func shortSessionHash(sessionHash string) string { @@ -374,37 +387,16 @@ func modelsListCacheKey(groupID *int64, platform string) string { } func prefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { - if ctx == nil { - return 0, false - } - v := ctx.Value(ctxkey.PrefetchedStickyGroupID) - switch t := v.(type) { - case int64: - return t, true - case int: - return int64(t), true - } - return 0, false + return PrefetchedStickyGroupIDFromContext(ctx) } func prefetchedStickyAccountIDFromContext(ctx context.Context, groupID *int64) int64 { - if ctx == nil { - return 0 - } prefetchedGroupID, ok := prefetchedStickyGroupIDFromContext(ctx) if !ok || prefetchedGroupID != derefGroupID(groupID) { return 0 } - v := ctx.Value(ctxkey.PrefetchedStickyAccountID) - switch t := v.(type) { - case int64: - if t > 0 { - return t - } - case int: - if t > 0 { - return int64(t) - } + if accountID, ok := PrefetchedStickyAccountIDFromContext(ctx); ok && accountID > 0 { + return accountID } return 0 } @@ -509,29 +501,32 @@ func (s *GatewayService) TempUnscheduleRetryableError(ctx context.Context, accou // GatewayService handles API gateway operations type GatewayService struct { - accountRepo AccountRepository - groupRepo 
GroupRepository - usageLogRepo UsageLogRepository - userRepo UserRepository - userSubRepo UserSubscriptionRepository - userGroupRateRepo UserGroupRateRepository - cache GatewayCache - digestStore *DigestSessionStore - cfg *config.Config - schedulerSnapshot *SchedulerSnapshotService - billingService *BillingService - rateLimitService *RateLimitService - billingCacheService *BillingCacheService - identityService *IdentityService - httpUpstream HTTPUpstream - deferredService *DeferredService - concurrencyService *ConcurrencyService - claudeTokenProvider *ClaudeTokenProvider - sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) - userGroupRateCache *gocache.Cache - userGroupRateSF singleflight.Group - modelsListCache *gocache.Cache - modelsListCacheTTL time.Duration + accountRepo AccountRepository + groupRepo GroupRepository + usageLogRepo UsageLogRepository + userRepo UserRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache GatewayCache + digestStore *DigestSessionStore + cfg *config.Config + schedulerSnapshot *SchedulerSnapshotService + billingService *BillingService + rateLimitService *RateLimitService + billingCacheService *BillingCacheService + identityService *IdentityService + httpUpstream HTTPUpstream + deferredService *DeferredService + concurrencyService *ConcurrencyService + claudeTokenProvider *ClaudeTokenProvider + sessionLimitCache SessionLimitCache // 会话数量限制缓存(仅 Anthropic OAuth/SetupToken) + userGroupRateCache *gocache.Cache + userGroupRateSF singleflight.Group + modelsListCache *gocache.Cache + modelsListCacheTTL time.Duration + responseHeaderFilter *responseheaders.CompiledHeaderFilter + debugModelRouting atomic.Bool + debugClaudeMimic atomic.Bool } // NewGatewayService creates a new GatewayService @@ -559,30 +554,34 @@ func NewGatewayService( userGroupRateTTL := resolveUserGroupRateCacheTTL(cfg) modelsListTTL := resolveModelsListCacheTTL(cfg) - return &GatewayService{ - 
accountRepo: accountRepo, - groupRepo: groupRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - userGroupRateRepo: userGroupRateRepo, - cache: cache, - digestStore: digestStore, - cfg: cfg, - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - identityService: identityService, - httpUpstream: httpUpstream, - deferredService: deferredService, - claudeTokenProvider: claudeTokenProvider, - sessionLimitCache: sessionLimitCache, - userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), - modelsListCache: gocache.New(modelsListTTL, time.Minute), - modelsListCacheTTL: modelsListTTL, + svc := &GatewayService{ + accountRepo: accountRepo, + groupRepo: groupRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + digestStore: digestStore, + cfg: cfg, + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + identityService: identityService, + httpUpstream: httpUpstream, + deferredService: deferredService, + claudeTokenProvider: claudeTokenProvider, + sessionLimitCache: sessionLimitCache, + userGroupRateCache: gocache.New(userGroupRateTTL, time.Minute), + modelsListCache: gocache.New(modelsListTTL, time.Minute), + modelsListCacheTTL: modelsListTTL, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } + svc.debugModelRouting.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING"))) + svc.debugClaudeMimic.Store(parseDebugEnvBool(os.Getenv("SUB2API_DEBUG_CLAUDE_MIMIC"))) + return svc } // GenerateSessionHash 从预解析请求计算粘性会话 hash @@ -1204,7 +1203,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro continue } account, ok := 
accountByID[routingAccountID] - if !ok || !account.IsSchedulable() { + if !ok || !s.isAccountSchedulableForSelection(account) { if !ok { filteredMissing++ } else { @@ -1220,7 +1219,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro filteredModelMapping++ continue } - if !account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { filteredModelScope++ modelScopeSkippedIDs = append(modelScopeSkippedIDs, account.ID) continue @@ -1249,10 +1248,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) { // 粘性账号在路由列表中,优先使用 if stickyAccount, ok := accountByID[stickyAccountID]; ok { - if stickyAccount.IsSchedulable() && + if s.isAccountSchedulableForSelection(stickyAccount) && s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) && - stickyAccount.IsSchedulableForModelWithContext(ctx, requestedModel) && + s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) { // 粘性会话窗口费用检查 result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { @@ -1406,7 +1405,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if !clearSticky && s.isAccountInGroup(account, groupID) && s.isAccountAllowedForPlatform(account, platform, useMixed) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && - account.IsSchedulableForModelWithContext(ctx, requestedModel) && + s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForWindowCost(ctx, account, true) { // 粘性会话窗口费用检查 result, err := 
s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { @@ -1457,7 +1456,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro // Scheduler snapshots can be temporarily stale (bucket rebuild is throttled); // re-check schedulability here so recently rate-limited/overloaded accounts // are not selected again before the bucket is rebuilt. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if !s.isAccountAllowedForPlatform(acc, platform, useMixed) { @@ -1466,7 +1465,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } // 窗口费用检查(非粘性会话路径) @@ -1737,6 +1736,9 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr } func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { + if platform == PlatformSora { + return s.listSoraSchedulableAccounts(ctx, groupID) + } if s.schedulerSnapshot != nil { accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) if err == nil { @@ -1831,6 +1833,53 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i return accounts, useMixed, nil } +func (s *GatewayService) listSoraSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, bool, error) { + const useMixed = false + + var accounts []Account + var err error + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { + accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } else if groupID != nil { + accounts, err = s.accountRepo.ListByGroup(ctx, *groupID) + } else { + 
accounts, err = s.accountRepo.ListByPlatform(ctx, PlatformSora) + } + if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "error", err) + return nil, useMixed, err + } + + filtered := make([]Account, 0, len(accounts)) + for _, acc := range accounts { + if acc.Platform != PlatformSora { + continue + } + if !s.isSoraAccountSchedulable(&acc) { + continue + } + filtered = append(filtered, acc) + } + slog.Debug("account_scheduling_list_sora", + "group_id", derefGroupID(groupID), + "platform", PlatformSora, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + return filtered, useMixed, nil +} + // IsSingleAntigravityAccountGroup 检查指定分组是否只有一个 antigravity 平台的可调度账号。 // 用于 Handler 层在首次请求时提前设置 SingleAccountRetry context, // 避免单账号分组收到 503 时错误地设置模型限流标记导致后续请求连续快速失败。 @@ -1855,6 +1904,49 @@ func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform return account.Platform == platform } +func (s *GatewayService) isSoraAccountSchedulable(account *Account) bool { + return s.soraUnschedulableReason(account) == "" +} + +func (s *GatewayService) soraUnschedulableReason(account *Account) string { + if account == nil { + return "account_nil" + } + if account.Status != StatusActive { + return fmt.Sprintf("status=%s", account.Status) + } + if !account.Schedulable { + return "schedulable=false" + } + if account.TempUnschedulableUntil != nil && time.Now().Before(*account.TempUnschedulableUntil) { + return fmt.Sprintf("temp_unschedulable_until=%s", account.TempUnschedulableUntil.UTC().Format(time.RFC3339)) + } + return "" +} + +func (s *GatewayService) isAccountSchedulableForSelection(account *Account) bool { + if account == nil { + 
return false + } + if account.Platform == PlatformSora { + return s.isSoraAccountSchedulable(account) + } + return account.IsSchedulable() +} + +func (s *GatewayService) isAccountSchedulableForModelSelection(ctx context.Context, account *Account, requestedModel string) bool { + if account == nil { + return false + } + if account.Platform == PlatformSora { + if !s.isSoraAccountSchedulable(account) { + return false + } + return account.GetRateLimitRemainingTimeWithContext(ctx, requestedModel) <= 0 + } + return account.IsSchedulableForModelWithContext(ctx, requestedModel) +} + // isAccountInGroup checks if the account belongs to the specified group. // Returns true if groupID is nil (no group restriction) or account belongs to the group. func (s *GatewayService) isAccountInGroup(account *Account, groupID *int64) bool { @@ -2397,7 +2489,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) } @@ -2438,13 +2530,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting 
accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2497,7 +2589,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { return account, nil } } @@ -2527,13 +2619,13 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2561,8 +2653,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, platform, accounts, excludedIDs, false) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2604,7 +2697,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { if s.debugModelRoutingEnabled() { logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID) @@ -2643,7 +2736,7 @@ func (s *GatewayService) 
selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. - if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2653,7 +2746,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2706,7 +2799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if clearSticky { _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash) } - if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && account.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) { if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) { return account, nil } @@ -2734,7 +2827,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } // Scheduler snapshots can be temporarily stale; re-check schedulability here to // avoid selecting accounts that were recently rate-limited/overloaded. 
- if !acc.IsSchedulable() { + if !s.isAccountSchedulableForSelection(acc) { continue } // 过滤:原生平台直接通过,antigravity 需要启用混合调度 @@ -2744,7 +2837,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { continue } - if !acc.IsSchedulableForModelWithContext(ctx, requestedModel) { + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { continue } if selected == nil { @@ -2772,8 +2865,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g } if selected == nil { + stats := s.logDetailedSelectionFailure(ctx, groupID, sessionHash, requestedModel, nativePlatform, accounts, excludedIDs, true) if requestedModel != "" { - return nil, fmt.Errorf("no available accounts supporting model: %s", requestedModel) + return nil, fmt.Errorf("no available accounts supporting model: %s (%s)", requestedModel, summarizeSelectionFailureStats(stats)) } return nil, errors.New("no available accounts") } @@ -2788,6 +2882,236 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g return selected, nil } +type selectionFailureStats struct { + Total int + Eligible int + Excluded int + Unschedulable int + PlatformFiltered int + ModelUnsupported int + ModelRateLimited int + SamplePlatformIDs []int64 + SampleMappingIDs []int64 + SampleRateLimitIDs []string +} + +type selectionFailureDiagnosis struct { + Category string + Detail string +} + +func (s *GatewayService) logDetailedSelectionFailure( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + platform string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := s.collectSelectionFailureStats(ctx, accounts, requestedModel, platform, excludedIDs, allowMixedScheduling) + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed] group_id=%v 
model=%s platform=%s session=%s total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d sample_platform_filtered=%v sample_model_unsupported=%v sample_model_rate_limited=%v", + derefGroupID(groupID), + requestedModel, + platform, + shortSessionHash(sessionHash), + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + stats.SamplePlatformIDs, + stats.SampleMappingIDs, + stats.SampleRateLimitIDs, + ) + if platform == PlatformSora { + s.logSoraSelectionFailureDetails(ctx, groupID, sessionHash, requestedModel, accounts, excludedIDs, allowMixedScheduling) + } + return stats +} + +func (s *GatewayService) collectSelectionFailureStats( + ctx context.Context, + accounts []Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureStats { + stats := selectionFailureStats{ + Total: len(accounts), + } + + for i := range accounts { + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, platform, excludedIDs, allowMixedScheduling) + switch diagnosis.Category { + case "excluded": + stats.Excluded++ + case "unschedulable": + stats.Unschedulable++ + case "platform_filtered": + stats.PlatformFiltered++ + stats.SamplePlatformIDs = appendSelectionFailureSampleID(stats.SamplePlatformIDs, acc.ID) + case "model_unsupported": + stats.ModelUnsupported++ + stats.SampleMappingIDs = appendSelectionFailureSampleID(stats.SampleMappingIDs, acc.ID) + case "model_rate_limited": + stats.ModelRateLimited++ + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + stats.SampleRateLimitIDs = appendSelectionFailureRateSample(stats.SampleRateLimitIDs, acc.ID, remaining) + default: + stats.Eligible++ + } + } + + return stats +} + +func (s *GatewayService) diagnoseSelectionFailure( + ctx 
context.Context, + acc *Account, + requestedModel string, + platform string, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) selectionFailureDiagnosis { + if acc == nil { + return selectionFailureDiagnosis{Category: "unschedulable", Detail: "account_nil"} + } + if _, excluded := excludedIDs[acc.ID]; excluded { + return selectionFailureDiagnosis{Category: "excluded"} + } + if !s.isAccountSchedulableForSelection(acc) { + detail := "generic_unschedulable" + if acc.Platform == PlatformSora { + detail = s.soraUnschedulableReason(acc) + } + return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail} + } + if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) { + return selectionFailureDiagnosis{ + Category: "platform_filtered", + Detail: fmt.Sprintf("account_platform=%s requested_platform=%s", acc.Platform, strings.TrimSpace(platform)), + } + } + if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) { + return selectionFailureDiagnosis{ + Category: "model_unsupported", + Detail: fmt.Sprintf("model=%s", requestedModel), + } + } + if !s.isAccountSchedulableForModelSelection(ctx, acc, requestedModel) { + remaining := acc.GetRateLimitRemainingTimeWithContext(ctx, requestedModel).Truncate(time.Second) + return selectionFailureDiagnosis{ + Category: "model_rate_limited", + Detail: fmt.Sprintf("remaining=%s", remaining), + } + } + return selectionFailureDiagnosis{Category: "eligible"} +} + +func (s *GatewayService) logSoraSelectionFailureDetails( + ctx context.Context, + groupID *int64, + sessionHash string, + requestedModel string, + accounts []Account, + excludedIDs map[int64]struct{}, + allowMixedScheduling bool, +) { + const maxLines = 30 + logged := 0 + + for i := range accounts { + if logged >= maxLines { + break + } + acc := &accounts[i] + diagnosis := s.diagnoseSelectionFailure(ctx, acc, requestedModel, PlatformSora, excludedIDs, allowMixedScheduling) + if diagnosis.Category == 
"eligible" { + continue + } + detail := diagnosis.Detail + if detail == "" { + detail = "-" + } + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s account_id=%d account_platform=%s category=%s detail=%s", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + acc.ID, + acc.Platform, + diagnosis.Category, + detail, + ) + logged++ + } + if len(accounts) > maxLines { + logger.LegacyPrintf( + "service.gateway", + "[SelectAccountDetailed:Sora] group_id=%v model=%s session=%s truncated=true total=%d logged=%d", + derefGroupID(groupID), + requestedModel, + shortSessionHash(sessionHash), + len(accounts), + logged, + ) + } +} + +func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool { + if acc == nil { + return true + } + if allowMixedScheduling { + if acc.Platform == PlatformAntigravity { + return !acc.IsMixedSchedulingEnabled() + } + return acc.Platform != platform + } + if strings.TrimSpace(platform) == "" { + return false + } + return acc.Platform != platform +} + +func appendSelectionFailureSampleID(samples []int64, id int64) []int64 { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, id) +} + +func appendSelectionFailureRateSample(samples []string, accountID int64, remaining time.Duration) []string { + const limit = 5 + if len(samples) >= limit { + return samples + } + return append(samples, fmt.Sprintf("%d(%s)", accountID, remaining)) +} + +func summarizeSelectionFailureStats(stats selectionFailureStats) string { + return fmt.Sprintf( + "total=%d eligible=%d excluded=%d unschedulable=%d platform_filtered=%d model_unsupported=%d model_rate_limited=%d", + stats.Total, + stats.Eligible, + stats.Excluded, + stats.Unschedulable, + stats.PlatformFiltered, + stats.ModelUnsupported, + stats.ModelRateLimited, + ) +} + // isModelSupportedByAccountWithContext 根据账户平台检查模型支持(带 context) // 对于 Antigravity 
平台,会先获取映射后的最终模型名(包括 thinking 后缀)再检查支持 func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Context, account *Account, requestedModel string) bool { @@ -2801,7 +3125,7 @@ func (s *GatewayService) isModelSupportedByAccountWithContext(ctx context.Contex return false } // 应用 thinking 后缀后检查最终模型是否在账号映射中 - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { finalModel := applyThinkingModelSuffix(mapped, enabled) if finalModel == mapped { return true // thinking 后缀未改变模型名,映射已通过 @@ -2821,6 +3145,9 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo } return mapAntigravityModel(account, requestedModel) != "" } + if account.Platform == PlatformSora { + return s.isSoraModelSupportedByAccount(account, requestedModel) + } // OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID) if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey { requestedModel = claude.NormalizeModelID(requestedModel) @@ -2829,6 +3156,143 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo return account.IsModelSupported(requestedModel) } +func (s *GatewayService) isSoraModelSupportedByAccount(account *Account, requestedModel string) bool { + if account == nil { + return false + } + if strings.TrimSpace(requestedModel) == "" { + return true + } + + // 先走原始精确/通配符匹配。 + mapping := account.GetModelMapping() + if len(mapping) == 0 || account.IsModelSupported(requestedModel) { + return true + } + + aliases := buildSoraModelAliases(requestedModel) + if len(aliases) == 0 { + return false + } + + hasSoraSelector := false + for pattern := range mapping { + if !isSoraModelSelector(pattern) { + continue + } + hasSoraSelector = true + if matchPatternAnyAlias(pattern, aliases) { + return true + } + } + + // 兼容旧账号:mapping 存在但未配置任何 Sora 选择器(例如只含 gpt-*), + // 此时不应误拦截 Sora 模型请求。 + if !hasSoraSelector { + return true + } + + return false +} + +func 
matchPatternAnyAlias(pattern string, aliases []string) bool { + normalizedPattern := strings.ToLower(strings.TrimSpace(pattern)) + if normalizedPattern == "" { + return false + } + for _, alias := range aliases { + if matchWildcard(normalizedPattern, alias) { + return true + } + } + return false +} + +func isSoraModelSelector(pattern string) bool { + p := strings.ToLower(strings.TrimSpace(pattern)) + if p == "" { + return false + } + + switch { + case strings.HasPrefix(p, "sora"), + strings.HasPrefix(p, "gpt-image"), + strings.HasPrefix(p, "prompt-enhance"), + strings.HasPrefix(p, "sy_"): + return true + } + + return p == "video" || p == "image" +} + +func buildSoraModelAliases(requestedModel string) []string { + modelID := strings.ToLower(strings.TrimSpace(requestedModel)) + if modelID == "" { + return nil + } + + aliases := make([]string, 0, 8) + addAlias := func(value string) { + v := strings.ToLower(strings.TrimSpace(value)) + if v == "" { + return + } + for _, existing := range aliases { + if existing == v { + return + } + } + aliases = append(aliases, v) + } + + addAlias(modelID) + cfg, ok := GetSoraModelConfig(modelID) + if ok { + addAlias(cfg.Model) + switch cfg.Type { + case "video": + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case "image": + addAlias("image") + addAlias("gpt-image") + case "prompt_enhance": + addAlias("prompt-enhance") + } + return aliases + } + + switch { + case strings.HasPrefix(modelID, "sora"): + addAlias("video") + addAlias("sora") + addAlias(soraVideoFamilyAlias(modelID)) + case strings.HasPrefix(modelID, "gpt-image"): + addAlias("image") + addAlias("gpt-image") + case strings.HasPrefix(modelID, "prompt-enhance"): + addAlias("prompt-enhance") + default: + return nil + } + + return aliases +} + +func soraVideoFamilyAlias(modelID string) string { + switch { + case strings.HasPrefix(modelID, "sora2pro-hd"): + return "sora2pro-hd" + case strings.HasPrefix(modelID, "sora2pro"): + return "sora2pro" + 
case strings.HasPrefix(modelID, "sora2"): + return "sora2" + default: + return "" + } +} + // GetAccessToken 获取账号凭证 func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) { switch account.Type { @@ -4012,7 +4476,7 @@ func (s *GatewayService) handleStreamingResponseAnthropicAPIKeyPassthrough( s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) } - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { @@ -4308,7 +4772,7 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( usage := parseClaudeUsageFromResponseBody(body) - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { contentType = "application/json" @@ -4317,12 +4781,12 @@ func (s *GatewayService) handleNonStreamingResponseAnthropicAPIKeyPassthrough( return usage, nil } -func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) { +func writeAnthropicPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { if dst == nil || src == nil { return } - if cfg != nil { - responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders) + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) return } if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { @@ -4425,12 +4889,11 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex // messages requests typically use only oauth + interleaved-thinking. 
// Also drop claude-code beta if a downstream client added it. requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking} - drop := droppedBetaSet(claude.BetaClaudeCode) - req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, drop)) + req.Header.Set("anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, droppedBetasWithClaudeCodeSet)) } else { // Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta clientBetaHeader := req.Header.Get("anthropic-beta") - req.Header.Set("anthropic-beta", stripBetaTokens(s.getBetaHeader(modelID, clientBetaHeader), claude.DroppedBetas)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(s.getBetaHeader(modelID, clientBetaHeader), defaultDroppedBetasSet)) } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { // API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭) @@ -4589,9 +5052,12 @@ func stripBetaTokens(header string, tokens []string) string { if header == "" || len(tokens) == 0 { return header } - drop := make(map[string]struct{}, len(tokens)) - for _, t := range tokens { - drop[t] = struct{}{} + return stripBetaTokensWithSet(header, buildBetaTokenSet(tokens)) +} + +func stripBetaTokensWithSet(header string, drop map[string]struct{}) string { + if header == "" || len(drop) == 0 { + return header } parts := strings.Split(header, ",") out := make([]string, 0, len(parts)) @@ -4613,8 +5079,8 @@ func stripBetaTokens(header string, tokens []string) string { // droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens. 
func droppedBetaSet(extra ...string) map[string]struct{} { - m := make(map[string]struct{}, len(claude.DroppedBetas)+len(extra)) - for _, t := range claude.DroppedBetas { + m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra)) + for t := range defaultDroppedBetasSet { m[t] = struct{}{} } for _, t := range extra { @@ -4623,6 +5089,22 @@ func droppedBetaSet(extra ...string) map[string]struct{} { return m } +func buildBetaTokenSet(tokens []string) map[string]struct{} { + m := make(map[string]struct{}, len(tokens)) + for _, t := range tokens { + if t == "" { + continue + } + m[t] = struct{}{} + } + return m +} + +var ( + defaultDroppedBetasSet = buildBetaTokenSet(claude.DroppedBetas) + droppedBetasWithClaudeCodeSet = droppedBetaSet(claude.BetaClaudeCode) +) + // applyClaudeCodeMimicHeaders forces "Claude Code-like" request headers. // This mirrors opencode-anthropic-auth behavior: do not trust downstream // headers when using Claude Code-scoped OAuth credentials. @@ -4752,6 +5234,20 @@ func extractUpstreamErrorMessage(body []byte) string { return gjson.GetBytes(body, "message").String() } +func isCountTokensUnsupported404(statusCode int, body []byte) bool { + if statusCode != http.StatusNotFound { + return false + } + msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(body))) + if msg == "" { + return false + } + if strings.Contains(msg, "/v1/messages/count_tokens") { + return true + } + return strings.Contains(msg, "count_tokens") && strings.Contains(msg, "not found") +} + func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) @@ -5029,8 +5525,8 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http // 更新5h窗口状态 s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header) - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), 
resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // 设置SSE响应头 @@ -5124,9 +5620,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http pendingEventLines := make([]string, 0, 4) - processSSEEvent := func(lines []string) ([]string, string, error) { + processSSEEvent := func(lines []string) ([]string, string, *sseUsagePatch, error) { if len(lines) == 0 { - return nil, "", nil + return nil, "", nil, nil } eventName := "" @@ -5143,11 +5639,11 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } if eventName == "error" { - return nil, dataLine, errors.New("have error in stream") + return nil, dataLine, nil, errors.New("have error in stream") } if dataLine == "" { - return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil + return []string{strings.Join(lines, "\n") + "\n\n"}, "", nil, nil } if dataLine == "[DONE]" { @@ -5156,7 +5652,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } var event map[string]any @@ -5167,25 +5663,26 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, nil, nil } eventType, _ := event["type"].(string) if eventName == "" { eventName = eventType } + eventChanged := false // 兼容 Kimi cached_tokens → cache_read_input_tokens if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged } } } if eventType == "message_delta" { 
if u, ok := event["usage"].(map[string]any); ok { - reconcileCachedTokens(u) + eventChanged = reconcileCachedTokens(u) || eventChanged } } @@ -5195,13 +5692,13 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if eventType == "message_start" { if msg, ok := event["message"].(map[string]any); ok { if u, ok := msg["usage"].(map[string]any); ok { - rewriteCacheCreationJSON(u, overrideTarget) + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged } } } if eventType == "message_delta" { if u, ok := event["usage"].(map[string]any); ok { - rewriteCacheCreationJSON(u, overrideTarget) + eventChanged = rewriteCacheCreationJSON(u, overrideTarget) || eventChanged } } } @@ -5210,10 +5707,21 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http if msg, ok := event["message"].(map[string]any); ok { if model, ok := msg["model"].(string); ok && model == mappedModel { msg["model"] = originalModel + eventChanged = true } } } + usagePatch := s.extractSSEUsagePatch(event) + if !eventChanged { + block := "" + if eventName != "" { + block = "event: " + eventName + "\n" + } + block += "data: " + dataLine + "\n\n" + return []string{block}, dataLine, usagePatch, nil + } + newData, err := json.Marshal(event) if err != nil { // 序列化失败,直接透传原始数据 @@ -5222,7 +5730,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + dataLine + "\n\n" - return []string{block}, dataLine, nil + return []string{block}, dataLine, usagePatch, nil } block := "" @@ -5230,7 +5738,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http block = "event: " + eventName + "\n" } block += "data: " + string(newData) + "\n\n" - return []string{block}, string(newData), nil + return []string{block}, string(newData), usagePatch, nil } for { @@ -5268,7 +5776,7 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, 
resp *http continue } - outputBlocks, data, err := processSSEEvent(pendingEventLines) + outputBlocks, data, usagePatch, err := processSSEEvent(pendingEventLines) pendingEventLines = pendingEventLines[:0] if err != nil { if clientDisconnected { @@ -5291,7 +5799,9 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsage(data, usage) + if usagePatch != nil { + mergeSSEUsagePatch(usage, usagePatch) + } } } continue @@ -5322,64 +5832,163 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http } func (s *GatewayService) parseSSEUsage(data string, usage *ClaudeUsage) { - // 解析message_start获取input tokens(标准Claude API格式) - var msgStart struct { - Type string `json:"type"` - Message struct { - Usage ClaudeUsage `json:"usage"` - } `json:"message"` - } - if json.Unmarshal([]byte(data), &msgStart) == nil && msgStart.Type == "message_start" { - usage.InputTokens = msgStart.Message.Usage.InputTokens - usage.CacheCreationInputTokens = msgStart.Message.Usage.CacheCreationInputTokens - usage.CacheReadInputTokens = msgStart.Message.Usage.CacheReadInputTokens - - // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 - cc5m := gjson.Get(data, "message.usage.cache_creation.ephemeral_5m_input_tokens") - cc1h := gjson.Get(data, "message.usage.cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() || cc1h.Exists() { - usage.CacheCreation5mTokens = int(cc5m.Int()) - usage.CacheCreation1hTokens = int(cc1h.Int()) - } + if usage == nil { + return } - // 解析message_delta获取tokens(兼容GLM等把所有usage放在delta中的API) - var msgDelta struct { - Type string `json:"type"` - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - CacheCreationInputTokens int `json:"cache_creation_input_tokens"` - CacheReadInputTokens int `json:"cache_read_input_tokens"` - } `json:"usage"` + var event map[string]any + if err := json.Unmarshal([]byte(data), 
&event); err != nil { + return } - if json.Unmarshal([]byte(data), &msgDelta) == nil && msgDelta.Type == "message_delta" { - // message_delta 仅覆盖存在且非0的字段 - // 避免覆盖 message_start 中已有的值(如 input_tokens) - // Claude API 的 message_delta 通常只包含 output_tokens - if msgDelta.Usage.InputTokens > 0 { - usage.InputTokens = msgDelta.Usage.InputTokens - } - if msgDelta.Usage.OutputTokens > 0 { - usage.OutputTokens = msgDelta.Usage.OutputTokens - } - if msgDelta.Usage.CacheCreationInputTokens > 0 { - usage.CacheCreationInputTokens = msgDelta.Usage.CacheCreationInputTokens - } - if msgDelta.Usage.CacheReadInputTokens > 0 { - usage.CacheReadInputTokens = msgDelta.Usage.CacheReadInputTokens + + if patch := s.extractSSEUsagePatch(event); patch != nil { + mergeSSEUsagePatch(usage, patch) + } +} + +type sseUsagePatch struct { + inputTokens int + hasInputTokens bool + outputTokens int + hasOutputTokens bool + cacheCreationInputTokens int + hasCacheCreationInput bool + cacheReadInputTokens int + hasCacheReadInput bool + cacheCreation5mTokens int + hasCacheCreation5m bool + cacheCreation1hTokens int + hasCacheCreation1h bool +} + +func (s *GatewayService) extractSSEUsagePatch(event map[string]any) *sseUsagePatch { + if len(event) == 0 { + return nil + } + + eventType, _ := event["type"].(string) + switch eventType { + case "message_start": + msg, _ := event["message"].(map[string]any) + usageObj, _ := msg["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil } - // 解析嵌套的 cache_creation 对象中的 5m/1h 明细 - cc5m := gjson.Get(data, "usage.cache_creation.ephemeral_5m_input_tokens") - cc1h := gjson.Get(data, "usage.cache_creation.ephemeral_1h_input_tokens") - if cc5m.Exists() && cc5m.Int() > 0 { - usage.CacheCreation5mTokens = int(cc5m.Int()) + patch := &sseUsagePatch{} + patch.hasInputTokens = true + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok { + patch.inputTokens = v } - if cc1h.Exists() && cc1h.Int() > 0 { - usage.CacheCreation1hTokens = int(cc1h.Int()) + 
patch.hasCacheCreationInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok { + patch.cacheCreationInputTokens = v + } + patch.hasCacheReadInput = true + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok { + patch.cacheReadInputTokens = v + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + + case "message_delta": + usageObj, _ := event["usage"].(map[string]any) + if len(usageObj) == 0 { + return nil + } + + patch := &sseUsagePatch{} + if v, ok := parseSSEUsageInt(usageObj["input_tokens"]); ok && v > 0 { + patch.inputTokens = v + patch.hasInputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["output_tokens"]); ok && v > 0 { + patch.outputTokens = v + patch.hasOutputTokens = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_creation_input_tokens"]); ok && v > 0 { + patch.cacheCreationInputTokens = v + patch.hasCacheCreationInput = true + } + if v, ok := parseSSEUsageInt(usageObj["cache_read_input_tokens"]); ok && v > 0 { + patch.cacheReadInputTokens = v + patch.hasCacheReadInput = true + } + if cc, ok := usageObj["cache_creation"].(map[string]any); ok { + if v, exists := parseSSEUsageInt(cc["ephemeral_5m_input_tokens"]); exists && v > 0 { + patch.cacheCreation5mTokens = v + patch.hasCacheCreation5m = true + } + if v, exists := parseSSEUsageInt(cc["ephemeral_1h_input_tokens"]); exists && v > 0 { + patch.cacheCreation1hTokens = v + patch.hasCacheCreation1h = true + } + } + return patch + } + + return nil +} + +func mergeSSEUsagePatch(usage *ClaudeUsage, patch *sseUsagePatch) { + if usage == nil || patch == nil { + return + } + + if patch.hasInputTokens { + 
usage.InputTokens = patch.inputTokens + } + if patch.hasCacheCreationInput { + usage.CacheCreationInputTokens = patch.cacheCreationInputTokens + } + if patch.hasCacheReadInput { + usage.CacheReadInputTokens = patch.cacheReadInputTokens + } + if patch.hasOutputTokens { + usage.OutputTokens = patch.outputTokens + } + if patch.hasCacheCreation5m { + usage.CacheCreation5mTokens = patch.cacheCreation5mTokens + } + if patch.hasCacheCreation1h { + usage.CacheCreation1hTokens = patch.cacheCreation1hTokens + } +} + +func parseSSEUsageInt(value any) (int, bool) { + switch v := value.(type) { + case float64: + return int(v), true + case float32: + return int(v), true + case int: + return v, true + case int64: + return int(v), true + case int32: + return int(v), true + case json.Number: + if i, err := v.Int64(); err == nil { + return int(i), true + } + if f, err := v.Float64(); err == nil { + return int(f), true + } + case string: + if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil { + return parsed, true } } + return 0, false } // applyCacheTTLOverride 将所有 cache creation tokens 归入指定的 TTL 类型。 @@ -5413,25 +6022,32 @@ func applyCacheTTLOverride(usage *ClaudeUsage, target string) bool { // rewriteCacheCreationJSON 在 JSON usage 对象中重写 cache_creation 嵌套对象的 TTL 分类。 // usageObj 是 usage JSON 对象(map[string]any)。 -func rewriteCacheCreationJSON(usageObj map[string]any, target string) { +func rewriteCacheCreationJSON(usageObj map[string]any, target string) bool { ccObj, ok := usageObj["cache_creation"].(map[string]any) if !ok { - return + return false } - v5m, _ := ccObj["ephemeral_5m_input_tokens"].(float64) - v1h, _ := ccObj["ephemeral_1h_input_tokens"].(float64) + v5m, _ := parseSSEUsageInt(ccObj["ephemeral_5m_input_tokens"]) + v1h, _ := parseSSEUsageInt(ccObj["ephemeral_1h_input_tokens"]) total := v5m + v1h if total == 0 { - return + return false } switch target { case "1h": - ccObj["ephemeral_1h_input_tokens"] = total + if v1h == total { + return false + } + 
ccObj["ephemeral_1h_input_tokens"] = float64(total) ccObj["ephemeral_5m_input_tokens"] = float64(0) default: // "5m" - ccObj["ephemeral_5m_input_tokens"] = total + if v5m == total { + return false + } + ccObj["ephemeral_5m_input_tokens"] = float64(total) ccObj["ephemeral_1h_input_tokens"] = float64(0) } + return true } func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*ClaudeUsage, error) { @@ -5500,7 +6116,7 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -6224,8 +6840,9 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) // 中转站不支持 count_tokens 端点时(404),返回 404 让客户端 fallback 到本地估算。 + // 仅在错误消息明确指向 count_tokens endpoint 不存在时生效,避免误吞其他 404(如错误 base_url)。 // 返回 nil 避免 handler 层记录为错误,也不设置 ops 上游错误上下文。 - if resp.StatusCode == http.StatusNotFound { + if isCountTokensUnsupported404(resp.StatusCode, respBody) { logger.LegacyPrintf("service.gateway", "[count_tokens] Upstream does not support count_tokens (404), returning 404: account=%d name=%s msg=%s", account.ID, account.Name, truncateString(upstreamMsg, 512)) @@ -6268,7 +6885,7 @@ func (s *GatewayService) forwardCountTokensAnthropicAPIKeyPassthrough(ctx contex return fmt.Errorf("upstream error: %d message=%s", resp.StatusCode, upstreamMsg) } - writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeAnthropicPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) 
contentType := strings.TrimSpace(resp.Header.Get("Content-Type")) if contentType == "" { contentType = "application/json" @@ -6420,7 +7037,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con if !strings.Contains(beta, claude.BetaTokenCounting) { beta = beta + "," + claude.BetaTokenCounting } - req.Header.Set("anthropic-beta", stripBetaTokens(beta, claude.DroppedBetas)) + req.Header.Set("anthropic-beta", stripBetaTokensWithSet(beta, defaultDroppedBetasSet)) } } } else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" { diff --git a/backend/internal/service/gateway_service_selection_failure_stats_test.go b/backend/internal/service/gateway_service_selection_failure_stats_test.go new file mode 100644 index 00000000..743d70bb --- /dev/null +++ b/backend/internal/service/gateway_service_selection_failure_stats_test.go @@ -0,0 +1,141 @@ +package service + +import ( + "context" + "strings" + "testing" + "time" +) + +func TestCollectSelectionFailureStats(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).Format(time.RFC3339) + + accounts := []Account{ + // excluded + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + // unschedulable + { + ID: 2, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + }, + // platform filtered + { + ID: 3, + Platform: PlatformOpenAI, + Status: StatusActive, + Schedulable: true, + }, + // model unsupported + { + ID: 4, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + }, + // model rate limited + { + ID: 5, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + 
}, + // eligible + { + ID: 6, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + }, + } + + excluded := map[int64]struct{}{1: {}} + stats := svc.collectSelectionFailureStats(context.Background(), accounts, model, PlatformSora, excluded, false) + + if stats.Total != 6 { + t.Fatalf("total=%d want=6", stats.Total) + } + if stats.Excluded != 1 { + t.Fatalf("excluded=%d want=1", stats.Excluded) + } + if stats.Unschedulable != 1 { + t.Fatalf("unschedulable=%d want=1", stats.Unschedulable) + } + if stats.PlatformFiltered != 1 { + t.Fatalf("platform_filtered=%d want=1", stats.PlatformFiltered) + } + if stats.ModelUnsupported != 1 { + t.Fatalf("model_unsupported=%d want=1", stats.ModelUnsupported) + } + if stats.ModelRateLimited != 1 { + t.Fatalf("model_rate_limited=%d want=1", stats.ModelRateLimited) + } + if stats.Eligible != 1 { + t.Fatalf("eligible=%d want=1", stats.Eligible) + } +} + +func TestDiagnoseSelectionFailure_SoraUnschedulableDetail(t *testing.T) { + svc := &GatewayService{} + acc := &Account{ + ID: 7, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: false, + } + + diagnosis := svc.diagnoseSelectionFailure(context.Background(), acc, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "unschedulable" { + t.Fatalf("category=%s want=unschedulable", diagnosis.Category) + } + if diagnosis.Detail != "schedulable=false" { + t.Fatalf("detail=%s want=schedulable=false", diagnosis.Detail) + } +} + +func TestDiagnoseSelectionFailure_SoraModelRateLimitedDetail(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + acc := &Account{ + ID: 8, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + diagnosis := 
svc.diagnoseSelectionFailure(context.Background(), acc, model, PlatformSora, map[int64]struct{}{}, false) + if diagnosis.Category != "model_rate_limited" { + t.Fatalf("category=%s want=model_rate_limited", diagnosis.Category) + } + if !strings.Contains(diagnosis.Detail, "remaining=") { + t.Fatalf("detail=%s want contains remaining=", diagnosis.Detail) + } +} diff --git a/backend/internal/service/gateway_service_sora_model_support_test.go b/backend/internal/service/gateway_service_sora_model_support_test.go new file mode 100644 index 00000000..8ee2a960 --- /dev/null +++ b/backend/internal/service/gateway_service_sora_model_support_test.go @@ -0,0 +1,79 @@ +package service + +import "testing" + +func TestGatewayServiceIsModelSupportedByAccount_SoraNoMappingAllowsAll(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{}, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when model_mapping is empty") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraLegacyNonSoraMappingDoesNotBlock(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-4o": "gpt-4o", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected sora model to be supported when mapping has no sora selectors") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraFamilyAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sora2": "sora2", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-15s") { + t.Fatalf("expected family selector sora2 to support sora2-landscape-15s") + } +} + +func 
TestGatewayServiceIsModelSupportedByAccount_SoraUnderlyingModelAlias(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "sy_8": "sy_8", + }, + }, + } + + if !svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected underlying model selector sy_8 to support sora2-landscape-10s") + } +} + +func TestGatewayServiceIsModelSupportedByAccount_SoraExplicitImageSelectorBlocksVideo(t *testing.T) { + svc := &GatewayService{} + account := &Account{ + Platform: PlatformSora, + Credentials: map[string]any{ + "model_mapping": map[string]any{ + "gpt-image": "gpt-image", + }, + }, + } + + if svc.isModelSupportedByAccount(account, "sora2-landscape-10s") { + t.Fatalf("expected video model to be blocked when mapping explicitly only allows gpt-image") + } +} diff --git a/backend/internal/service/gateway_service_sora_scheduling_test.go b/backend/internal/service/gateway_service_sora_scheduling_test.go new file mode 100644 index 00000000..5178e68e --- /dev/null +++ b/backend/internal/service/gateway_service_sora_scheduling_test.go @@ -0,0 +1,89 @@ +package service + +import ( + "context" + "testing" + "time" +) + +func TestGatewayServiceIsAccountSchedulableForSelectionSoraIgnoresGenericWindows(t *testing.T) { + svc := &GatewayService{} + now := time.Now() + past := now.Add(-1 * time.Minute) + future := now.Add(5 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + AutoPauseOnExpired: true, + ExpiresAt: &past, + OverloadUntil: &future, + RateLimitResetAt: &future, + } + + if !svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected sora account to ignore generic expiry/overload/rate-limit windows") + } +} + +func TestGatewayServiceIsAccountSchedulableForSelectionNonSoraKeepsGenericLogic(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(5 * time.Minute) + + acc := 
&Account{ + Platform: PlatformAnthropic, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + } + + if svc.isAccountSchedulableForSelection(acc) { + t.Fatalf("expected non-sora account to keep generic schedulable checks") + } +} + +func TestGatewayServiceIsAccountSchedulableForModelSelectionSoraChecksModelScopeOnly(t *testing.T) { + svc := &GatewayService{} + model := "sora2-landscape-10s" + resetAt := time.Now().Add(2 * time.Minute).UTC().Format(time.RFC3339) + globalResetAt := time.Now().Add(2 * time.Minute) + + acc := &Account{ + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &globalResetAt, + Extra: map[string]any{ + "model_rate_limits": map[string]any{ + model: map[string]any{ + "rate_limit_reset_at": resetAt, + }, + }, + }, + } + + if svc.isAccountSchedulableForModelSelection(context.Background(), acc, model) { + t.Fatalf("expected sora account to be blocked by model scope rate limit") + } +} + +func TestCollectSelectionFailureStatsSoraIgnoresGenericUnschedulableWindows(t *testing.T) { + svc := &GatewayService{} + future := time.Now().Add(3 * time.Minute) + + accounts := []Account{ + { + ID: 1, + Platform: PlatformSora, + Status: StatusActive, + Schedulable: true, + RateLimitResetAt: &future, + }, + } + + stats := svc.collectSelectionFailureStats(context.Background(), accounts, "sora2-landscape-10s", PlatformSora, map[int64]struct{}{}, false) + if stats.Unschedulable != 0 || stats.Eligible != 1 { + t.Fatalf("unexpected stats: unschedulable=%d eligible=%d", stats.Unschedulable, stats.Eligible) + } +} diff --git a/backend/internal/service/gateway_waiting_queue_test.go b/backend/internal/service/gateway_waiting_queue_test.go index 0ed95c87..0c53323e 100644 --- a/backend/internal/service/gateway_waiting_queue_test.go +++ b/backend/internal/service/gateway_waiting_queue_test.go @@ -105,12 +105,12 @@ func TestCalculateMaxWait_Scenarios(t *testing.T) { concurrency int expected int }{ - {5, 25}, // 5 + 
20 - {10, 30}, // 10 + 20 - {1, 21}, // 1 + 20 - {0, 21}, // min(1) + 20 - {-1, 21}, // min(1) + 20 - {-10, 21}, // min(1) + 20 + {5, 25}, // 5 + 20 + {10, 30}, // 10 + 20 + {1, 21}, // 1 + 20 + {0, 21}, // min(1) + 20 + {-1, 21}, // min(1) + 20 + {-10, 21}, // min(1) + 20 {100, 120}, // 100 + 20 } for _, tt := range tests { diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 8670f99a..1c38b6c2 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -53,6 +53,7 @@ type GeminiMessagesCompatService struct { httpUpstream HTTPUpstream antigravityGatewayService *AntigravityGatewayService cfg *config.Config + responseHeaderFilter *responseheaders.CompiledHeaderFilter } func NewGeminiMessagesCompatService( @@ -76,6 +77,7 @@ func NewGeminiMessagesCompatService( httpUpstream: httpUpstream, antigravityGatewayService: antigravityGatewayService, cfg: cfg, + responseHeaderFilter: compileResponseHeaderFilter(cfg), } } @@ -229,6 +231,16 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( account *Account, requestedModel, platform string, useMixedScheduling bool, +) bool { + return s.isAccountUsableForRequestWithPrecheck(ctx, account, requestedModel, platform, useMixedScheduling, nil) +} + +func (s *GeminiMessagesCompatService) isAccountUsableForRequestWithPrecheck( + ctx context.Context, + account *Account, + requestedModel, platform string, + useMixedScheduling bool, + precheckResult map[int64]bool, ) bool { // 检查模型调度能力 // Check model scheduling capability @@ -250,7 +262,7 @@ func (s *GeminiMessagesCompatService) isAccountUsableForRequest( // 速率限制预检 // Rate limit precheck - if !s.passesRateLimitPreCheck(ctx, account, requestedModel) { + if !s.passesRateLimitPreCheckWithCache(ctx, account, requestedModel, precheckResult) { return false } @@ -272,15 +284,17 @@ func (s *GeminiMessagesCompatService) 
isAccountValidForPlatform(account *Account return false } -// passesRateLimitPreCheck 执行速率限制预检。 -// 返回 true 表示通过预检或无需预检。 -// -// passesRateLimitPreCheck performs rate limit precheck. -// Returns true if passed or precheck not required. -func (s *GeminiMessagesCompatService) passesRateLimitPreCheck(ctx context.Context, account *Account, requestedModel string) bool { +func (s *GeminiMessagesCompatService) passesRateLimitPreCheckWithCache(ctx context.Context, account *Account, requestedModel string, precheckResult map[int64]bool) bool { if s.rateLimitService == nil || requestedModel == "" { return true } + + if precheckResult != nil { + if ok, exists := precheckResult[account.ID]; exists { + return ok + } + } + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) if err != nil { logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) @@ -302,6 +316,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( useMixedScheduling bool, ) *Account { var selected *Account + precheckResult := s.buildPreCheckUsageResultMap(ctx, accounts, requestedModel) for i := range accounts { acc := &accounts[i] @@ -312,7 +327,7 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( } // 检查账号是否可用于当前请求 - if !s.isAccountUsableForRequest(ctx, acc, requestedModel, platform, useMixedScheduling) { + if !s.isAccountUsableForRequestWithPrecheck(ctx, acc, requestedModel, platform, useMixedScheduling, precheckResult) { continue } @@ -330,6 +345,23 @@ func (s *GeminiMessagesCompatService) selectBestGeminiAccount( return selected } +func (s *GeminiMessagesCompatService) buildPreCheckUsageResultMap(ctx context.Context, accounts []Account, requestedModel string) map[int64]bool { + if s.rateLimitService == nil || requestedModel == "" || len(accounts) == 0 { + return nil + } + + candidates := make([]*Account, 0, len(accounts)) + for i := range accounts { + candidates = append(candidates, &accounts[i]) 
+ } + + result, err := s.rateLimitService.PreCheckUsageBatch(ctx, candidates, requestedModel) + if err != nil { + logger.LegacyPrintf("service.gemini_messages_compat", "[Gemini PreCheckBatch] failed: %v", err) + } + return result +} + // isBetterGeminiAccount 判断 candidate 是否比 current 更优。 // 规则:优先级更高(数值更小)优先;同优先级时,未使用过的优先(OAuth > 非 OAuth),其次是最久未使用的。 // @@ -2390,7 +2422,7 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co } } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -2415,8 +2447,8 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================") } - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } c.Status(resp.StatusCode) @@ -2557,7 +2589,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac body, _ := io.ReadAll(io.LimitReader(resp.Body, 8<<20)) wwwAuthenticate := resp.Header.Get("Www-Authenticate") - filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.cfg.Security.ResponseHeaders) + filteredHeaders := responseheaders.FilterHeaders(resp.Header, s.responseHeaderFilter) if wwwAuthenticate != "" { filteredHeaders.Set("Www-Authenticate", wwwAuthenticate) } diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 86ece03f..6990caca 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -32,6 +32,9 @@ type Group struct { SoraVideoPricePerRequest *float64 
SoraVideoPricePerRequestHD *float64 + // Sora 存储配额 + SoraStorageQuotaBytes int64 + // Claude Code 客户端限制 ClaudeCodeOnly bool FallbackGroupID *int64 diff --git a/backend/internal/service/model_rate_limit.go b/backend/internal/service/model_rate_limit.go index ff4b5977..c45615cc 100644 --- a/backend/internal/service/model_rate_limit.go +++ b/backend/internal/service/model_rate_limit.go @@ -4,8 +4,6 @@ import ( "context" "strings" "time" - - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" ) const modelRateLimitsKey = "model_rate_limits" @@ -73,7 +71,7 @@ func resolveFinalAntigravityModelKey(ctx context.Context, account *Account, requ return "" } // thinking 会影响 Antigravity 最终模型名(例如 claude-sonnet-4-5 -> claude-sonnet-4-5-thinking) - if enabled, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + if enabled, ok := ThinkingEnabledFromContext(ctx); ok { modelKey = applyThinkingModelSuffix(modelKey, enabled) } return modelKey diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 6f6261d8..0931f9ce 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -12,7 +12,7 @@ import ( // OpenAIOAuthClient interface for OpenAI OAuth operations type OpenAIOAuthClient interface { - ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) + ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) } diff --git a/backend/internal/service/oauth_service_test.go b/backend/internal/service/oauth_service_test.go index 72de4b8c..78f39dc5 100644 --- a/backend/internal/service/oauth_service_test.go +++ b/backend/internal/service/oauth_service_test.go 
@@ -14,10 +14,10 @@ import ( // --- mock: ClaudeOAuthClient --- type mockClaudeOAuthClient struct { - getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) - getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) - exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) - refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) + getOrgUUIDFunc func(ctx context.Context, sessionKey, proxyURL string) (string, error) + getAuthCodeFunc func(ctx context.Context, sessionKey, orgUUID, scope, codeChallenge, state, proxyURL string) (string, error) + exchangeCodeFunc func(ctx context.Context, code, codeVerifier, state, proxyURL string, isSetupToken bool) (*oauth.TokenResponse, error) + refreshTokenFunc func(ctx context.Context, refreshToken, proxyURL string) (*oauth.TokenResponse, error) } func (m *mockClaudeOAuthClient) GetOrganizationUUID(ctx context.Context, sessionKey, proxyURL string) (string, error) { @@ -437,9 +437,9 @@ func TestOAuthService_RefreshAccountToken_NoRefreshToken(t *testing.T) { // 无 refresh_token 的账号 account := &Account{ - ID: 1, - Platform: PlatformAnthropic, - Type: AccountTypeOAuth, + ID: 1, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "some-token", }, @@ -460,9 +460,9 @@ func TestOAuthService_RefreshAccountToken_EmptyRefreshToken(t *testing.T) { defer svc.Stop() account := &Account{ - ID: 2, - Platform: PlatformAnthropic, - Type: AccountTypeOAuth, + ID: 2, + Platform: PlatformAnthropic, + Type: AccountTypeOAuth, Credentials: map[string]any{ "access_token": "some-token", "refresh_token": "", diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go new file mode 100644 index 00000000..99013ce5 --- /dev/null 
+++ b/backend/internal/service/openai_account_scheduler.go @@ -0,0 +1,909 @@ +package service + +import ( + "container/heap" + "context" + "errors" + "hash/fnv" + "math" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIAccountScheduleLayerPreviousResponse = "previous_response_id" + openAIAccountScheduleLayerSessionSticky = "session_hash" + openAIAccountScheduleLayerLoadBalance = "load_balance" +) + +type OpenAIAccountScheduleRequest struct { + GroupID *int64 + SessionHash string + StickyAccountID int64 + PreviousResponseID string + RequestedModel string + RequiredTransport OpenAIUpstreamTransport + ExcludedIDs map[int64]struct{} +} + +type OpenAIAccountScheduleDecision struct { + Layer string + StickyPreviousHit bool + StickySessionHit bool + CandidateCount int + TopK int + LatencyMs int64 + LoadSkew float64 + SelectedAccountID int64 + SelectedAccountType string +} + +type OpenAIAccountSchedulerMetricsSnapshot struct { + SelectTotal int64 + StickyPreviousHitTotal int64 + StickySessionHitTotal int64 + LoadBalanceSelectTotal int64 + AccountSwitchTotal int64 + SchedulerLatencyMsTotal int64 + SchedulerLatencyMsAvg float64 + StickyHitRatio float64 + AccountSwitchRate float64 + LoadSkewAvg float64 + RuntimeStatsAccountCount int +} + +type OpenAIAccountScheduler interface { + Select(ctx context.Context, req OpenAIAccountScheduleRequest) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) + ReportResult(accountID int64, success bool, firstTokenMs *int) + ReportSwitch() + SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot +} + +type openAIAccountSchedulerMetrics struct { + selectTotal atomic.Int64 + stickyPreviousHitTotal atomic.Int64 + stickySessionHitTotal atomic.Int64 + loadBalanceSelectTotal atomic.Int64 + accountSwitchTotal atomic.Int64 + latencyMsTotal atomic.Int64 + loadSkewMilliTotal atomic.Int64 +} + +func (m *openAIAccountSchedulerMetrics) recordSelect(decision OpenAIAccountScheduleDecision) { + if m == nil 
{ + return + } + m.selectTotal.Add(1) + m.latencyMsTotal.Add(decision.LatencyMs) + m.loadSkewMilliTotal.Add(int64(math.Round(decision.LoadSkew * 1000))) + if decision.StickyPreviousHit { + m.stickyPreviousHitTotal.Add(1) + } + if decision.StickySessionHit { + m.stickySessionHitTotal.Add(1) + } + if decision.Layer == openAIAccountScheduleLayerLoadBalance { + m.loadBalanceSelectTotal.Add(1) + } +} + +func (m *openAIAccountSchedulerMetrics) recordSwitch() { + if m == nil { + return + } + m.accountSwitchTotal.Add(1) +} + +type openAIAccountRuntimeStats struct { + accounts sync.Map + accountCount atomic.Int64 +} + +type openAIAccountRuntimeStat struct { + errorRateEWMABits atomic.Uint64 + ttftEWMABits atomic.Uint64 +} + +func newOpenAIAccountRuntimeStats() *openAIAccountRuntimeStats { + return &openAIAccountRuntimeStats{} +} + +func (s *openAIAccountRuntimeStats) loadOrCreate(accountID int64) *openAIAccountRuntimeStat { + if value, ok := s.accounts.Load(accountID); ok { + stat, _ := value.(*openAIAccountRuntimeStat) + if stat != nil { + return stat + } + } + + stat := &openAIAccountRuntimeStat{} + stat.ttftEWMABits.Store(math.Float64bits(math.NaN())) + actual, loaded := s.accounts.LoadOrStore(accountID, stat) + if !loaded { + s.accountCount.Add(1) + return stat + } + existing, _ := actual.(*openAIAccountRuntimeStat) + if existing != nil { + return existing + } + return stat +} + +func updateEWMAAtomic(target *atomic.Uint64, sample float64, alpha float64) { + for { + oldBits := target.Load() + oldValue := math.Float64frombits(oldBits) + newValue := alpha*sample + (1-alpha)*oldValue + if target.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + return + } + } +} + +func (s *openAIAccountRuntimeStats) report(accountID int64, success bool, firstTokenMs *int) { + if s == nil || accountID <= 0 { + return + } + const alpha = 0.2 + stat := s.loadOrCreate(accountID) + + errorSample := 1.0 + if success { + errorSample = 0.0 + } + updateEWMAAtomic(&stat.errorRateEWMABits, 
errorSample, alpha) + + if firstTokenMs != nil && *firstTokenMs > 0 { + ttft := float64(*firstTokenMs) + ttftBits := math.Float64bits(ttft) + for { + oldBits := stat.ttftEWMABits.Load() + oldValue := math.Float64frombits(oldBits) + if math.IsNaN(oldValue) { + if stat.ttftEWMABits.CompareAndSwap(oldBits, ttftBits) { + break + } + continue + } + newValue := alpha*ttft + (1-alpha)*oldValue + if stat.ttftEWMABits.CompareAndSwap(oldBits, math.Float64bits(newValue)) { + break + } + } + } +} + +func (s *openAIAccountRuntimeStats) snapshot(accountID int64) (errorRate float64, ttft float64, hasTTFT bool) { + if s == nil || accountID <= 0 { + return 0, 0, false + } + value, ok := s.accounts.Load(accountID) + if !ok { + return 0, 0, false + } + stat, _ := value.(*openAIAccountRuntimeStat) + if stat == nil { + return 0, 0, false + } + errorRate = clamp01(math.Float64frombits(stat.errorRateEWMABits.Load())) + ttftValue := math.Float64frombits(stat.ttftEWMABits.Load()) + if math.IsNaN(ttftValue) { + return errorRate, 0, false + } + return errorRate, ttftValue, true +} + +func (s *openAIAccountRuntimeStats) size() int { + if s == nil { + return 0 + } + return int(s.accountCount.Load()) +} + +type defaultOpenAIAccountScheduler struct { + service *OpenAIGatewayService + metrics openAIAccountSchedulerMetrics + stats *openAIAccountRuntimeStats +} + +func newDefaultOpenAIAccountScheduler(service *OpenAIGatewayService, stats *openAIAccountRuntimeStats) OpenAIAccountScheduler { + if stats == nil { + stats = newOpenAIAccountRuntimeStats() + } + return &defaultOpenAIAccountScheduler{ + service: service, + stats: stats, + } +} + +func (s *defaultOpenAIAccountScheduler) Select( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + start := time.Now() + defer func() { + decision.LatencyMs = time.Since(start).Milliseconds() + s.metrics.recordSelect(decision) + }() + + 
previousResponseID := strings.TrimSpace(req.PreviousResponseID) + if previousResponseID != "" { + selection, err := s.service.SelectAccountByPreviousResponseID( + ctx, + req.GroupID, + previousResponseID, + req.RequestedModel, + req.ExcludedIDs, + ) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + if !s.isAccountTransportCompatible(selection.Account, req.RequiredTransport) { + selection = nil + } + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerPreviousResponse + decision.StickyPreviousHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, selection.Account.ID) + } + return selection, decision, nil + } + } + + selection, err := s.selectBySessionHash(ctx, req) + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.Layer = openAIAccountScheduleLayerSessionSticky + decision.StickySessionHit = true + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + return selection, decision, nil + } + + selection, candidateCount, topK, loadSkew, err := s.selectByLoadBalance(ctx, req) + decision.Layer = openAIAccountScheduleLayerLoadBalance + decision.CandidateCount = candidateCount + decision.TopK = topK + decision.LoadSkew = loadSkew + if err != nil { + return nil, decision, err + } + if selection != nil && selection.Account != nil { + decision.SelectedAccountID = selection.Account.ID + decision.SelectedAccountType = selection.Account.Type + } + return selection, decision, nil +} + +func (s *defaultOpenAIAccountScheduler) selectBySessionHash( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, error) { + sessionHash := strings.TrimSpace(req.SessionHash) + if sessionHash 
== "" || s == nil || s.service == nil || s.service.cache == nil { + return nil, nil + } + + accountID := req.StickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.service.getStickySessionAccountID(ctx, req.GroupID, sessionHash) + if err != nil || accountID <= 0 { + return nil, nil + } + } + if accountID <= 0 { + return nil, nil + } + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.service.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if shouldClearStickySession(account, req.RequestedModel) || !account.IsOpenAI() { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + return nil, nil + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + _ = s.service.deleteStickySessionAccountID(ctx, req.GroupID, sessionHash) + return nil, nil + } + + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + _ = s.service.refreshStickySessionTTL(ctx, req.GroupID, sessionHash, s.service.openAIWSSessionStickyTTL()) + return &AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.service.schedulingConfig() + if s.service.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +type openAIAccountCandidateScore struct { + account *Account + loadInfo *AccountLoadInfo + score float64 + errorRate float64 + ttft float64 + hasTTFT 
bool +} + +type openAIAccountCandidateHeap []openAIAccountCandidateScore + +func (h openAIAccountCandidateHeap) Len() int { + return len(h) +} + +func (h openAIAccountCandidateHeap) Less(i, j int) bool { + // 最小堆根节点保存“最差”候选,便于 O(log k) 维护 topK。 + return isOpenAIAccountCandidateBetter(h[j], h[i]) +} + +func (h openAIAccountCandidateHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *openAIAccountCandidateHeap) Push(x any) { + candidate, ok := x.(openAIAccountCandidateScore) + if !ok { + panic("openAIAccountCandidateHeap: invalid element type") + } + *h = append(*h, candidate) +} + +func (h *openAIAccountCandidateHeap) Pop() any { + old := *h + n := len(old) + last := old[n-1] + *h = old[:n-1] + return last +} + +func isOpenAIAccountCandidateBetter(left openAIAccountCandidateScore, right openAIAccountCandidateScore) bool { + if left.score != right.score { + return left.score > right.score + } + if left.account.Priority != right.account.Priority { + return left.account.Priority < right.account.Priority + } + if left.loadInfo.LoadRate != right.loadInfo.LoadRate { + return left.loadInfo.LoadRate < right.loadInfo.LoadRate + } + if left.loadInfo.WaitingCount != right.loadInfo.WaitingCount { + return left.loadInfo.WaitingCount < right.loadInfo.WaitingCount + } + return left.account.ID < right.account.ID +} + +func selectTopKOpenAICandidates(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + if topK >= len(candidates) { + ranked := append([]openAIAccountCandidateScore(nil), candidates...) 
+ sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked + } + + best := make(openAIAccountCandidateHeap, 0, topK) + for _, candidate := range candidates { + if len(best) < topK { + heap.Push(&best, candidate) + continue + } + if isOpenAIAccountCandidateBetter(candidate, best[0]) { + best[0] = candidate + heap.Fix(&best, 0) + } + } + + ranked := make([]openAIAccountCandidateScore, len(best)) + copy(ranked, best) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + return ranked +} + +type openAISelectionRNG struct { + state uint64 +} + +func newOpenAISelectionRNG(seed uint64) openAISelectionRNG { + if seed == 0 { + seed = 0x9e3779b97f4a7c15 + } + return openAISelectionRNG{state: seed} +} + +func (r *openAISelectionRNG) nextUint64() uint64 { + // xorshift64* + x := r.state + x ^= x >> 12 + x ^= x << 25 + x ^= x >> 27 + r.state = x + return x * 2685821657736338717 +} + +func (r *openAISelectionRNG) nextFloat64() float64 { + // [0,1) + return float64(r.nextUint64()>>11) / (1 << 53) +} + +func deriveOpenAISelectionSeed(req OpenAIAccountScheduleRequest) uint64 { + hasher := fnv.New64a() + writeValue := func(value string) { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return + } + _, _ = hasher.Write([]byte(trimmed)) + _, _ = hasher.Write([]byte{0}) + } + + writeValue(req.SessionHash) + writeValue(req.PreviousResponseID) + writeValue(req.RequestedModel) + if req.GroupID != nil { + _, _ = hasher.Write([]byte(strconv.FormatInt(*req.GroupID, 10))) + } + + seed := hasher.Sum64() + // 对“无会话锚点”的纯负载均衡请求引入时间熵,避免固定命中同一账号。 + if strings.TrimSpace(req.SessionHash) == "" && strings.TrimSpace(req.PreviousResponseID) == "" { + seed ^= uint64(time.Now().UnixNano()) + } + if seed == 0 { + seed = uint64(time.Now().UnixNano()) ^ 0x9e3779b97f4a7c15 + } + return seed +} + +func buildOpenAIWeightedSelectionOrder( + candidates 
[]openAIAccountCandidateScore, + req OpenAIAccountScheduleRequest, +) []openAIAccountCandidateScore { + if len(candidates) <= 1 { + return append([]openAIAccountCandidateScore(nil), candidates...) + } + + pool := append([]openAIAccountCandidateScore(nil), candidates...) + weights := make([]float64, len(pool)) + minScore := pool[0].score + for i := 1; i < len(pool); i++ { + if pool[i].score < minScore { + minScore = pool[i].score + } + } + for i := range pool { + // 将 top-K 分值平移到正区间,避免“单一最高分账号”长期垄断。 + weight := (pool[i].score - minScore) + 1.0 + if math.IsNaN(weight) || math.IsInf(weight, 0) || weight <= 0 { + weight = 1.0 + } + weights[i] = weight + } + + order := make([]openAIAccountCandidateScore, 0, len(pool)) + rng := newOpenAISelectionRNG(deriveOpenAISelectionSeed(req)) + for len(pool) > 0 { + total := 0.0 + for _, w := range weights { + total += w + } + + selectedIdx := 0 + if total > 0 { + r := rng.nextFloat64() * total + acc := 0.0 + for i, w := range weights { + acc += w + if r <= acc { + selectedIdx = i + break + } + } + } else { + selectedIdx = int(rng.nextUint64() % uint64(len(pool))) + } + + order = append(order, pool[selectedIdx]) + pool = append(pool[:selectedIdx], pool[selectedIdx+1:]...) + weights = append(weights[:selectedIdx], weights[selectedIdx+1:]...) 
+ } + return order +} + +func (s *defaultOpenAIAccountScheduler) selectByLoadBalance( + ctx context.Context, + req OpenAIAccountScheduleRequest, +) (*AccountSelectionResult, int, int, float64, error) { + accounts, err := s.service.listSchedulableAccounts(ctx, req.GroupID) + if err != nil { + return nil, 0, 0, 0, err + } + if len(accounts) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + filtered := make([]*Account, 0, len(accounts)) + loadReq := make([]AccountWithConcurrency, 0, len(accounts)) + for i := range accounts { + account := &accounts[i] + if req.ExcludedIDs != nil { + if _, excluded := req.ExcludedIDs[account.ID]; excluded { + continue + } + } + if !account.IsSchedulable() || !account.IsOpenAI() { + continue + } + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { + continue + } + if !s.isAccountTransportCompatible(account, req.RequiredTransport) { + continue + } + filtered = append(filtered, account) + loadReq = append(loadReq, AccountWithConcurrency{ + ID: account.ID, + MaxConcurrency: account.Concurrency, + }) + } + if len(filtered) == 0 { + return nil, 0, 0, 0, errors.New("no available OpenAI accounts") + } + + loadMap := map[int64]*AccountLoadInfo{} + if s.service.concurrencyService != nil { + if batchLoad, loadErr := s.service.concurrencyService.GetAccountsLoadBatch(ctx, loadReq); loadErr == nil { + loadMap = batchLoad + } + } + + minPriority, maxPriority := filtered[0].Priority, filtered[0].Priority + maxWaiting := 1 + loadRateSum := 0.0 + loadRateSumSquares := 0.0 + minTTFT, maxTTFT := 0.0, 0.0 + hasTTFTSample := false + candidates := make([]openAIAccountCandidateScore, 0, len(filtered)) + for _, account := range filtered { + loadInfo := loadMap[account.ID] + if loadInfo == nil { + loadInfo = &AccountLoadInfo{AccountID: account.ID} + } + if account.Priority < minPriority { + minPriority = account.Priority + } + if account.Priority > maxPriority { + maxPriority = account.Priority + } + if 
loadInfo.WaitingCount > maxWaiting { + maxWaiting = loadInfo.WaitingCount + } + errorRate, ttft, hasTTFT := s.stats.snapshot(account.ID) + if hasTTFT && ttft > 0 { + if !hasTTFTSample { + minTTFT, maxTTFT = ttft, ttft + hasTTFTSample = true + } else { + if ttft < minTTFT { + minTTFT = ttft + } + if ttft > maxTTFT { + maxTTFT = ttft + } + } + } + loadRate := float64(loadInfo.LoadRate) + loadRateSum += loadRate + loadRateSumSquares += loadRate * loadRate + candidates = append(candidates, openAIAccountCandidateScore{ + account: account, + loadInfo: loadInfo, + errorRate: errorRate, + ttft: ttft, + hasTTFT: hasTTFT, + }) + } + loadSkew := calcLoadSkewByMoments(loadRateSum, loadRateSumSquares, len(candidates)) + + weights := s.service.openAIWSSchedulerWeights() + for i := range candidates { + item := &candidates[i] + priorityFactor := 1.0 + if maxPriority > minPriority { + priorityFactor = 1 - float64(item.account.Priority-minPriority)/float64(maxPriority-minPriority) + } + loadFactor := 1 - clamp01(float64(item.loadInfo.LoadRate)/100.0) + queueFactor := 1 - clamp01(float64(item.loadInfo.WaitingCount)/float64(maxWaiting)) + errorFactor := 1 - clamp01(item.errorRate) + ttftFactor := 0.5 + if item.hasTTFT && hasTTFTSample && maxTTFT > minTTFT { + ttftFactor = 1 - clamp01((item.ttft-minTTFT)/(maxTTFT-minTTFT)) + } + + item.score = weights.Priority*priorityFactor + + weights.Load*loadFactor + + weights.Queue*queueFactor + + weights.ErrorRate*errorFactor + + weights.TTFT*ttftFactor + } + + topK := s.service.openAIWSLBTopK() + if topK > len(candidates) { + topK = len(candidates) + } + if topK <= 0 { + topK = 1 + } + rankedCandidates := selectTopKOpenAICandidates(candidates, topK) + selectionOrder := buildOpenAIWeightedSelectionOrder(rankedCandidates, req) + + for i := 0; i < len(selectionOrder); i++ { + candidate := selectionOrder[i] + result, acquireErr := s.service.tryAcquireAccountSlot(ctx, candidate.account.ID, candidate.account.Concurrency) + if acquireErr != nil { + 
return nil, len(candidates), topK, loadSkew, acquireErr + } + if result != nil && result.Acquired { + if req.SessionHash != "" { + _ = s.service.BindStickySession(ctx, req.GroupID, req.SessionHash, candidate.account.ID) + } + return &AccountSelectionResult{ + Account: candidate.account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, len(candidates), topK, loadSkew, nil + } + } + + cfg := s.service.schedulingConfig() + candidate := selectionOrder[0] + return &AccountSelectionResult{ + Account: candidate.account, + WaitPlan: &AccountWaitPlan{ + AccountID: candidate.account.ID, + MaxConcurrency: candidate.account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, len(candidates), topK, loadSkew, nil +} + +func (s *defaultOpenAIAccountScheduler) isAccountTransportCompatible(account *Account, requiredTransport OpenAIUpstreamTransport) bool { + // HTTP 入站可回退到 HTTP 线路,不需要在账号选择阶段做传输协议强过滤。 + if requiredTransport == OpenAIUpstreamTransportAny || requiredTransport == OpenAIUpstreamTransportHTTPSSE { + return true + } + if s == nil || s.service == nil || account == nil { + return false + } + return s.service.getOpenAIWSProtocolResolver().Resolve(account).Transport == requiredTransport +} + +func (s *defaultOpenAIAccountScheduler) ReportResult(accountID int64, success bool, firstTokenMs *int) { + if s == nil || s.stats == nil { + return + } + s.stats.report(accountID, success, firstTokenMs) +} + +func (s *defaultOpenAIAccountScheduler) ReportSwitch() { + if s == nil { + return + } + s.metrics.recordSwitch() +} + +func (s *defaultOpenAIAccountScheduler) SnapshotMetrics() OpenAIAccountSchedulerMetricsSnapshot { + if s == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + + selectTotal := s.metrics.selectTotal.Load() + prevHit := s.metrics.stickyPreviousHitTotal.Load() + sessionHit := s.metrics.stickySessionHitTotal.Load() + switchTotal := s.metrics.accountSwitchTotal.Load() + latencyTotal := 
s.metrics.latencyMsTotal.Load() + loadSkewTotal := s.metrics.loadSkewMilliTotal.Load() + + snapshot := OpenAIAccountSchedulerMetricsSnapshot{ + SelectTotal: selectTotal, + StickyPreviousHitTotal: prevHit, + StickySessionHitTotal: sessionHit, + LoadBalanceSelectTotal: s.metrics.loadBalanceSelectTotal.Load(), + AccountSwitchTotal: switchTotal, + SchedulerLatencyMsTotal: latencyTotal, + RuntimeStatsAccountCount: s.stats.size(), + } + if selectTotal > 0 { + snapshot.SchedulerLatencyMsAvg = float64(latencyTotal) / float64(selectTotal) + snapshot.StickyHitRatio = float64(prevHit+sessionHit) / float64(selectTotal) + snapshot.AccountSwitchRate = float64(switchTotal) / float64(selectTotal) + snapshot.LoadSkewAvg = float64(loadSkewTotal) / 1000 / float64(selectTotal) + } + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIAccountScheduler() OpenAIAccountScheduler { + if s == nil { + return nil + } + s.openaiSchedulerOnce.Do(func() { + if s.openaiAccountStats == nil { + s.openaiAccountStats = newOpenAIAccountRuntimeStats() + } + if s.openaiScheduler == nil { + s.openaiScheduler = newDefaultOpenAIAccountScheduler(s, s.openaiAccountStats) + } + }) + return s.openaiScheduler +} + +func (s *OpenAIGatewayService) SelectAccountWithScheduler( + ctx context.Context, + groupID *int64, + previousResponseID string, + sessionHash string, + requestedModel string, + excludedIDs map[int64]struct{}, + requiredTransport OpenAIUpstreamTransport, +) (*AccountSelectionResult, OpenAIAccountScheduleDecision, error) { + decision := OpenAIAccountScheduleDecision{} + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + selection, err := s.SelectAccountWithLoadAwareness(ctx, groupID, sessionHash, requestedModel, excludedIDs) + decision.Layer = openAIAccountScheduleLayerLoadBalance + return selection, decision, err + } + + var stickyAccountID int64 + if sessionHash != "" && s.cache != nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == 
nil && accountID > 0 { + stickyAccountID = accountID + } + } + + return scheduler.Select(ctx, OpenAIAccountScheduleRequest{ + GroupID: groupID, + SessionHash: sessionHash, + StickyAccountID: stickyAccountID, + PreviousResponseID: previousResponseID, + RequestedModel: requestedModel, + RequiredTransport: requiredTransport, + ExcludedIDs: excludedIDs, + }) +} + +func (s *OpenAIGatewayService) ReportOpenAIAccountScheduleResult(accountID int64, success bool, firstTokenMs *int) { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportResult(accountID, success, firstTokenMs) +} + +func (s *OpenAIGatewayService) RecordOpenAIAccountSwitch() { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return + } + scheduler.ReportSwitch() +} + +func (s *OpenAIGatewayService) SnapshotOpenAIAccountSchedulerMetrics() OpenAIAccountSchedulerMetricsSnapshot { + scheduler := s.getOpenAIAccountScheduler() + if scheduler == nil { + return OpenAIAccountSchedulerMetricsSnapshot{} + } + return scheduler.SnapshotMetrics() +} + +func (s *OpenAIGatewayService) openAIWSSessionStickyTTL() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return openaiStickySessionTTL +} + +func (s *OpenAIGatewayService) openAIWSLBTopK() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.LBTopK > 0 { + return s.cfg.Gateway.OpenAIWS.LBTopK + } + return 7 +} + +func (s *OpenAIGatewayService) openAIWSSchedulerWeights() GatewayOpenAIWSSchedulerScoreWeightsView { + if s != nil && s.cfg != nil { + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority, + Load: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load, + Queue: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue, + ErrorRate: s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate, + TTFT: 
s.cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT, + } + } + return GatewayOpenAIWSSchedulerScoreWeightsView{ + Priority: 1.0, + Load: 1.0, + Queue: 0.7, + ErrorRate: 0.8, + TTFT: 0.5, + } +} + +type GatewayOpenAIWSSchedulerScoreWeightsView struct { + Priority float64 + Load float64 + Queue float64 + ErrorRate float64 + TTFT float64 +} + +func clamp01(value float64) float64 { + switch { + case value < 0: + return 0 + case value > 1: + return 1 + default: + return value + } +} + +func calcLoadSkewByMoments(sum float64, sumSquares float64, count int) float64 { + if count <= 1 { + return 0 + } + mean := sum / float64(count) + variance := sumSquares/float64(count) - mean*mean + if variance < 0 { + variance = 0 + } + return math.Sqrt(variance) +} diff --git a/backend/internal/service/openai_account_scheduler_benchmark_test.go b/backend/internal/service/openai_account_scheduler_benchmark_test.go new file mode 100644 index 00000000..897be5b0 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_benchmark_test.go @@ -0,0 +1,83 @@ +package service + +import ( + "sort" + "testing" +) + +func buildOpenAISchedulerBenchmarkCandidates(size int) []openAIAccountCandidateScore { + if size <= 0 { + return nil + } + candidates := make([]openAIAccountCandidateScore, 0, size) + for i := 0; i < size; i++ { + accountID := int64(10_000 + i) + candidates = append(candidates, openAIAccountCandidateScore{ + account: &Account{ + ID: accountID, + Priority: i % 7, + }, + loadInfo: &AccountLoadInfo{ + AccountID: accountID, + LoadRate: (i * 17) % 100, + WaitingCount: (i * 11) % 13, + }, + score: float64((i*29)%1000) / 100, + errorRate: float64((i * 5) % 100 / 100), + ttft: float64(30 + (i*3)%500), + hasTTFT: i%3 != 0, + }) + } + return candidates +} + +func selectTopKOpenAICandidatesBySortBenchmark(candidates []openAIAccountCandidateScore, topK int) []openAIAccountCandidateScore { + if len(candidates) == 0 { + return nil + } + if topK <= 0 { + topK = 1 + } + ranked := 
append([]openAIAccountCandidateScore(nil), candidates...) + sort.Slice(ranked, func(i, j int) bool { + return isOpenAIAccountCandidateBetter(ranked[i], ranked[j]) + }) + if topK > len(ranked) { + topK = len(ranked) + } + return ranked[:topK] +} + +func BenchmarkOpenAIAccountSchedulerSelectTopK(b *testing.B) { + cases := []struct { + name string + size int + topK int + }{ + {name: "n_16_k_3", size: 16, topK: 3}, + {name: "n_64_k_3", size: 64, topK: 3}, + {name: "n_256_k_5", size: 256, topK: 5}, + } + + for _, tc := range cases { + candidates := buildOpenAISchedulerBenchmarkCandidates(tc.size) + b.Run(tc.name+"/heap_topk", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidates(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + b.Run(tc.name+"/full_sort", func(b *testing.B) { + b.ReportAllocs() + for i := 0; i < b.N; i++ { + result := selectTopKOpenAICandidatesBySortBenchmark(candidates, tc.topK) + if len(result) == 0 { + b.Fatal("unexpected empty result") + } + } + }) + } +} diff --git a/backend/internal/service/openai_account_scheduler_test.go b/backend/internal/service/openai_account_scheduler_test.go new file mode 100644 index 00000000..7f6f1b66 --- /dev/null +++ b/backend/internal/service/openai_account_scheduler_test.go @@ -0,0 +1,841 @@ +package service + +import ( + "context" + "fmt" + "math" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountWithScheduler_PreviousResponseSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(9) + account := Account{ + ID: 1001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + cfg := &config.Config{} + 
cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 1800 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + store := svc.getOpenAIWSStateStore() + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_001", account.ID, time.Hour)) + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "resp_prev_001", + "session_hash_001", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerPreviousResponse, decision.Layer) + require.True(t, decision.StickyPreviousHit) + require.Equal(t, account.ID, cache.sessionBindings["openai:session_hash_001"]) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10) + account := Account{ + ID: 2001, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_abc": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_abc", + "gpt-5.1", + nil, + 
OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionStickyBusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(10100) + accounts := []Account{ + { + ID: 21001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 21002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_sticky_busy": 21001, + }, + } + cfg := &config.Config{} + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 45 * time.Second + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21001: false, // sticky 账号已满 + 21002: true, // 若回退负载均衡会命中该账号(本测试要求不能切换) + }, + waitCounts: map[int64]int{ + 21001: 999, + }, + loadMap: map[int64]*AccountLoadInfo{ + 21001: {AccountID: 21001, LoadRate: 90, WaitingCount: 9}, + 21002: {AccountID: 21002, LoadRate: 1, WaitingCount: 0}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_sticky_busy", + 
"gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21001), selection.Account.ID, "busy sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21001), selection.WaitPlan.AccountID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_SessionSticky_ForceHTTP(t *testing.T) { + ctx := context.Background() + groupID := int64(1010) + account := Account{ + ID: 2101, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_force_http": account.ID, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_force_http", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerSessionSticky, decision.Layer) + require.True(t, decision.StickySessionHit) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_SkipsStickyHTTPAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1011) + accounts := []Account{ + { + ID: 2201, + Platform: PlatformOpenAI, + Type: 
AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 2202, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 5, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_ws_only": 2201, + }, + } + cfg := newOpenAIWSV2TestConfig() + + // 构造“HTTP-only 账号负载更低”的场景,验证 required transport 会强制过滤。 + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 2201: {AccountID: 2201, LoadRate: 0, WaitingCount: 0}, + 2202: {AccountID: 2202, LoadRate: 90, WaitingCount: 5}, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "session_hash_ws_only", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(2202), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.False(t, decision.StickySessionHit) + require.Equal(t, 1, decision.CandidateCount) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_RequiredWSV2_NoAvailableAccount(t *testing.T) { + ctx := context.Background() + groupID := int64(1012) + accounts := []Account{ + { + ID: 2301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: 
newOpenAIWSV2TestConfig(), + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportResponsesWebsocketV2, + ) + require.Error(t, err) + require.Nil(t, selection) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 0, decision.CandidateCount) +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceTopKFallback(t *testing.T) { + ctx := context.Background() + groupID := int64(11) + accounts := []Account{ + { + ID: 3001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3002, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + { + ID: 3003, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + }, + } + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1.0 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 3001: {AccountID: 3001, LoadRate: 95, WaitingCount: 8}, + 3002: {AccountID: 3002, LoadRate: 20, WaitingCount: 1}, + 3003: {AccountID: 3003, LoadRate: 10, WaitingCount: 0}, + }, + acquireResults: map[int64]bool{ + 3003: false, // top1 失败,必须回退到 top-K 的下一候选 + 3002: true, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), 
+ } + + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + "", + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(3002), selection.Account.ID) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + require.Equal(t, 3, decision.CandidateCount) + require.Equal(t, 2, decision.TopK) + require.Greater(t, decision.LoadSkew, 0.0) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_OpenAIAccountSchedulerMetrics(t *testing.T) { + ctx := context.Background() + groupID := int64(12) + account := Account{ + ID: 4001, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + } + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:session_hash_metrics": account.ID, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: &config.Config{}, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + } + + selection, _, err := svc.SelectAccountWithScheduler(ctx, &groupID, "", "session_hash_metrics", "gpt-5.1", nil, OpenAIUpstreamTransportAny) + require.NoError(t, err) + require.NotNil(t, selection) + svc.ReportOpenAIAccountScheduleResult(account.ID, true, intPtrForTest(120)) + svc.RecordOpenAIAccountSwitch() + + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.SelectTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.StickySessionHitTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.GreaterOrEqual(t, snapshot.SchedulerLatencyMsAvg, float64(0)) + require.GreaterOrEqual(t, snapshot.StickyHitRatio, 0.0) + require.GreaterOrEqual(t, snapshot.RuntimeStatsAccountCount, 1) +} + +func intPtrForTest(v int) *int 
{ + return &v +} + +func TestOpenAIAccountRuntimeStats_ReportAndSnapshot(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + stats.report(1001, true, nil) + firstTTFT := 100 + stats.report(1001, false, &firstTTFT) + secondTTFT := 200 + stats.report(1001, false, &secondTTFT) + + errorRate, ttft, hasTTFT := stats.snapshot(1001) + require.True(t, hasTTFT) + require.InDelta(t, 0.36, errorRate, 1e-9) + require.InDelta(t, 120.0, ttft, 1e-9) + require.Equal(t, 1, stats.size()) +} + +func TestOpenAIAccountRuntimeStats_ReportConcurrent(t *testing.T) { + stats := newOpenAIAccountRuntimeStats() + + const ( + accountCount = 4 + workers = 16 + iterations = 800 + ) + var wg sync.WaitGroup + wg.Add(workers) + for worker := 0; worker < workers; worker++ { + worker := worker + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + accountID := int64(i%accountCount + 1) + success := (i+worker)%3 != 0 + ttft := 80 + (i+worker)%40 + stats.report(accountID, success, &ttft) + } + }() + } + wg.Wait() + + require.Equal(t, accountCount, stats.size()) + for accountID := int64(1); accountID <= accountCount; accountID++ { + errorRate, ttft, hasTTFT := stats.snapshot(accountID) + require.GreaterOrEqual(t, errorRate, 0.0) + require.LessOrEqual(t, errorRate, 1.0) + require.True(t, hasTTFT) + require.Greater(t, ttft, 0.0) + } +} + +func TestSelectTopKOpenAICandidates(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 11, Priority: 2}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 1}, + score: 10.0, + }, + { + account: &Account{ID: 12, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 20, WaitingCount: 1}, + score: 9.5, + }, + { + account: &Account{ID: 13, Priority: 1}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 0}, + score: 10.0, + }, + { + account: &Account{ID: 14, Priority: 0}, + loadInfo: &AccountLoadInfo{LoadRate: 40, WaitingCount: 0}, + score: 8.0, + }, + } + + top2 := 
selectTopKOpenAICandidates(candidates, 2) + require.Len(t, top2, 2) + require.Equal(t, int64(13), top2[0].account.ID) + require.Equal(t, int64(11), top2[1].account.ID) + + topAll := selectTopKOpenAICandidates(candidates, 8) + require.Len(t, topAll, len(candidates)) + require.Equal(t, int64(13), topAll[0].account.ID) + require.Equal(t, int64(11), topAll[1].account.ID) + require.Equal(t, int64(12), topAll[2].account.ID) + require.Equal(t, int64(14), topAll[3].account.ID) +} + +func TestBuildOpenAIWeightedSelectionOrder_DeterministicBySessionSeed(t *testing.T) { + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 101}, + loadInfo: &AccountLoadInfo{LoadRate: 10, WaitingCount: 0}, + score: 4.2, + }, + { + account: &Account{ID: 102}, + loadInfo: &AccountLoadInfo{LoadRate: 30, WaitingCount: 1}, + score: 3.5, + }, + { + account: &Account{ID: 103}, + loadInfo: &AccountLoadInfo{LoadRate: 50, WaitingCount: 2}, + score: 2.1, + }, + } + req := OpenAIAccountScheduleRequest{ + GroupID: int64PtrForTest(99), + SessionHash: "session_seed_fixed", + RequestedModel: "gpt-5.1", + } + + first := buildOpenAIWeightedSelectionOrder(candidates, req) + second := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, first, len(candidates)) + require.Len(t, second, len(candidates)) + for i := range first { + require.Equal(t, first[i].account.ID, second[i].account.ID) + } +} + +func TestOpenAIGatewayService_SelectAccountWithScheduler_LoadBalanceDistributesAcrossSessions(t *testing.T) { + ctx := context.Background() + groupID := int64(15) + accounts := []Account{ + { + ID: 5101, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5102, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + { + ID: 5103, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + 
Schedulable: true, + Concurrency: 3, + Priority: 0, + }, + } + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 1 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 1 + + concurrencyCache := stubConcurrencyCache{ + loadMap: map[int64]*AccountLoadInfo{ + 5101: {AccountID: 5101, LoadRate: 20, WaitingCount: 1}, + 5102: {AccountID: 5102, LoadRate: 20, WaitingCount: 1}, + 5103: {AccountID: 5103, LoadRate: 20, WaitingCount: 1}, + }, + } + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: &stubGatewayCache{sessionBindings: map[string]int64{}}, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + } + + selected := make(map[int64]int, len(accounts)) + for i := 0; i < 60; i++ { + sessionHash := fmt.Sprintf("session_hash_lb_%d", i) + selection, decision, err := svc.SelectAccountWithScheduler( + ctx, + &groupID, + "", + sessionHash, + "gpt-5.1", + nil, + OpenAIUpstreamTransportAny, + ) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, openAIAccountScheduleLayerLoadBalance, decision.Layer) + selected[selection.Account.ID]++ + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } + } + + // 多 session 应该能打散到多个账号,避免“恒定单账号命中”。 + require.GreaterOrEqual(t, len(selected), 2) +} + +func TestDeriveOpenAISelectionSeed_NoAffinityAddsEntropy(t *testing.T) { + req := OpenAIAccountScheduleRequest{ + RequestedModel: "gpt-5.1", + } + seed1 := deriveOpenAISelectionSeed(req) + time.Sleep(1 * time.Millisecond) + seed2 := deriveOpenAISelectionSeed(req) + require.NotZero(t, seed1) + require.NotZero(t, seed2) + require.NotEqual(t, seed1, seed2) +} + +func TestBuildOpenAIWeightedSelectionOrder_HandlesInvalidScores(t *testing.T) 
{ + candidates := []openAIAccountCandidateScore{ + { + account: &Account{ID: 901}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.NaN(), + }, + { + account: &Account{ID: 902}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: math.Inf(1), + }, + { + account: &Account{ID: 903}, + loadInfo: &AccountLoadInfo{LoadRate: 5, WaitingCount: 0}, + score: -1, + }, + } + req := OpenAIAccountScheduleRequest{ + SessionHash: "seed_invalid_scores", + } + + order := buildOpenAIWeightedSelectionOrder(candidates, req) + require.Len(t, order, len(candidates)) + seen := map[int64]struct{}{} + for _, item := range order { + seen[item.account.ID] = struct{}{} + } + require.Len(t, seen, len(candidates)) +} + +func TestOpenAISelectionRNG_SeedZeroStillWorks(t *testing.T) { + rng := newOpenAISelectionRNG(0) + v1 := rng.nextUint64() + v2 := rng.nextUint64() + require.NotEqual(t, v1, v2) + require.GreaterOrEqual(t, rng.nextFloat64(), 0.0) + require.Less(t, rng.nextFloat64(), 1.0) +} + +func TestOpenAIAccountCandidateHeap_PushPopAndInvalidType(t *testing.T) { + h := openAIAccountCandidateHeap{} + h.Push(openAIAccountCandidateScore{ + account: &Account{ID: 7001}, + loadInfo: &AccountLoadInfo{LoadRate: 0, WaitingCount: 0}, + score: 1.0, + }) + require.Equal(t, 1, h.Len()) + popped, ok := h.Pop().(openAIAccountCandidateScore) + require.True(t, ok) + require.Equal(t, int64(7001), popped.account.ID) + require.Equal(t, 0, h.Len()) + + require.Panics(t, func() { + h.Push("bad_element_type") + }) +} + +func TestClamp01_AllBranches(t *testing.T) { + require.Equal(t, 0.0, clamp01(-0.2)) + require.Equal(t, 1.0, clamp01(1.3)) + require.Equal(t, 0.5, clamp01(0.5)) +} + +func TestCalcLoadSkewByMoments_Branches(t *testing.T) { + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 1, 1)) + // variance < 0 分支:sumSquares/count - mean^2 为负值时应钳制为 0。 + require.Equal(t, 0.0, calcLoadSkewByMoments(1, 0, 2)) + require.GreaterOrEqual(t, calcLoadSkewByMoments(6, 20, 3), 0.0) +} + 
+func TestDefaultOpenAIAccountScheduler_ReportSwitchAndSnapshot(t *testing.T) { + schedulerAny := newDefaultOpenAIAccountScheduler(&OpenAIGatewayService{}, nil) + scheduler, ok := schedulerAny.(*defaultOpenAIAccountScheduler) + require.True(t, ok) + + ttft := 100 + scheduler.ReportResult(1001, true, &ttft) + scheduler.ReportSwitch() + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerLoadBalance, + LatencyMs: 8, + LoadSkew: 0.5, + StickyPreviousHit: true, + }) + scheduler.metrics.recordSelect(OpenAIAccountScheduleDecision{ + Layer: openAIAccountScheduleLayerSessionSticky, + LatencyMs: 6, + LoadSkew: 0.2, + StickySessionHit: true, + }) + + snapshot := scheduler.SnapshotMetrics() + require.Equal(t, int64(2), snapshot.SelectTotal) + require.Equal(t, int64(1), snapshot.StickyPreviousHitTotal) + require.Equal(t, int64(1), snapshot.StickySessionHitTotal) + require.Equal(t, int64(1), snapshot.LoadBalanceSelectTotal) + require.Equal(t, int64(1), snapshot.AccountSwitchTotal) + require.Greater(t, snapshot.SchedulerLatencyMsAvg, 0.0) + require.Greater(t, snapshot.StickyHitRatio, 0.0) + require.Greater(t, snapshot.LoadSkewAvg, 0.0) +} + +func TestOpenAIGatewayService_SchedulerWrappersAndDefaults(t *testing.T) { + svc := &OpenAIGatewayService{} + ttft := 120 + svc.ReportOpenAIAccountScheduleResult(10, true, &ttft) + svc.RecordOpenAIAccountSwitch() + snapshot := svc.SnapshotOpenAIAccountSchedulerMetrics() + require.GreaterOrEqual(t, snapshot.AccountSwitchTotal, int64(1)) + require.Equal(t, 7, svc.openAIWSLBTopK()) + require.Equal(t, openaiStickySessionTTL, svc.openAIWSSessionStickyTTL()) + + defaultWeights := svc.openAIWSSchedulerWeights() + require.Equal(t, 1.0, defaultWeights.Priority) + require.Equal(t, 1.0, defaultWeights.Load) + require.Equal(t, 0.7, defaultWeights.Queue) + require.Equal(t, 0.8, defaultWeights.ErrorRate) + require.Equal(t, 0.5, defaultWeights.TTFT) + + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.LBTopK = 9 + 
cfg.Gateway.OpenAIWS.StickySessionTTLSeconds = 180 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Priority = 0.2 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Load = 0.3 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.Queue = 0.4 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate = 0.5 + cfg.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT = 0.6 + svcWithCfg := &OpenAIGatewayService{cfg: cfg} + + require.Equal(t, 9, svcWithCfg.openAIWSLBTopK()) + require.Equal(t, 180*time.Second, svcWithCfg.openAIWSSessionStickyTTL()) + customWeights := svcWithCfg.openAIWSSchedulerWeights() + require.Equal(t, 0.2, customWeights.Priority) + require.Equal(t, 0.3, customWeights.Load) + require.Equal(t, 0.4, customWeights.Queue) + require.Equal(t, 0.5, customWeights.ErrorRate) + require.Equal(t, 0.6, customWeights.TTFT) +} + +func TestDefaultOpenAIAccountScheduler_IsAccountTransportCompatible_Branches(t *testing.T) { + scheduler := &defaultOpenAIAccountScheduler{} + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportAny)) + require.True(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportHTTPSSE)) + require.False(t, scheduler.isAccountTransportCompatible(nil, OpenAIUpstreamTransportResponsesWebsocketV2)) + + cfg := newOpenAIWSV2TestConfig() + scheduler.service = &OpenAIGatewayService{cfg: cfg} + account := &Account{ + ID: 8801, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + require.True(t, scheduler.isAccountTransportCompatible(account, OpenAIUpstreamTransportResponsesWebsocketV2)) +} + +func int64PtrForTest(v int64) *int64 { + return &v +} diff --git a/backend/internal/service/openai_client_transport.go b/backend/internal/service/openai_client_transport.go new file mode 100644 index 00000000..c9cf3246 --- /dev/null +++ b/backend/internal/service/openai_client_transport.go @@ -0,0 
+1,71 @@ +package service + +import ( + "strings" + + "github.com/gin-gonic/gin" +) + +// OpenAIClientTransport 表示客户端入站协议类型。 +type OpenAIClientTransport string + +const ( + OpenAIClientTransportUnknown OpenAIClientTransport = "" + OpenAIClientTransportHTTP OpenAIClientTransport = "http" + OpenAIClientTransportWS OpenAIClientTransport = "ws" +) + +const openAIClientTransportContextKey = "openai_client_transport" + +// SetOpenAIClientTransport 标记当前请求的客户端入站协议。 +func SetOpenAIClientTransport(c *gin.Context, transport OpenAIClientTransport) { + if c == nil { + return + } + normalized := normalizeOpenAIClientTransport(transport) + if normalized == OpenAIClientTransportUnknown { + return + } + c.Set(openAIClientTransportContextKey, string(normalized)) +} + +// GetOpenAIClientTransport 读取当前请求的客户端入站协议。 +func GetOpenAIClientTransport(c *gin.Context) OpenAIClientTransport { + if c == nil { + return OpenAIClientTransportUnknown + } + raw, ok := c.Get(openAIClientTransportContextKey) + if !ok || raw == nil { + return OpenAIClientTransportUnknown + } + + switch v := raw.(type) { + case OpenAIClientTransport: + return normalizeOpenAIClientTransport(v) + case string: + return normalizeOpenAIClientTransport(OpenAIClientTransport(v)) + default: + return OpenAIClientTransportUnknown + } +} + +func normalizeOpenAIClientTransport(transport OpenAIClientTransport) OpenAIClientTransport { + switch strings.ToLower(strings.TrimSpace(string(transport))) { + case string(OpenAIClientTransportHTTP), "http_sse", "sse": + return OpenAIClientTransportHTTP + case string(OpenAIClientTransportWS), "websocket": + return OpenAIClientTransportWS + default: + return OpenAIClientTransportUnknown + } +} + +func resolveOpenAIWSDecisionByClientTransport( + decision OpenAIWSProtocolDecision, + clientTransport OpenAIClientTransport, +) OpenAIWSProtocolDecision { + if clientTransport == OpenAIClientTransportHTTP { + return openAIWSHTTPDecision("client_protocol_http") + } + return decision +} diff --git 
a/backend/internal/service/openai_client_transport_test.go b/backend/internal/service/openai_client_transport_test.go new file mode 100644 index 00000000..ef90e614 --- /dev/null +++ b/backend/internal/service/openai_client_transport_test.go @@ -0,0 +1,107 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestOpenAIClientTransport_SetAndGet(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) + + SetOpenAIClientTransport(c, OpenAIClientTransportWS) + require.Equal(t, OpenAIClientTransportWS, GetOpenAIClientTransport(c)) +} + +func TestOpenAIClientTransport_GetNormalizesRawContextValue(t *testing.T) { + gin.SetMode(gin.TestMode) + + tests := []struct { + name string + rawValue any + want OpenAIClientTransport + }{ + { + name: "type_value_ws", + rawValue: OpenAIClientTransportWS, + want: OpenAIClientTransportWS, + }, + { + name: "http_sse_alias", + rawValue: "http_sse", + want: OpenAIClientTransportHTTP, + }, + { + name: "sse_alias", + rawValue: "sSe", + want: OpenAIClientTransportHTTP, + }, + { + name: "websocket_alias", + rawValue: "WebSocket", + want: OpenAIClientTransportWS, + }, + { + name: "invalid_string", + rawValue: "tcp", + want: OpenAIClientTransportUnknown, + }, + { + name: "invalid_type", + rawValue: 123, + want: OpenAIClientTransportUnknown, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Set(openAIClientTransportContextKey, tt.rawValue) + require.Equal(t, tt.want, GetOpenAIClientTransport(c)) + }) + } +} + +func TestOpenAIClientTransport_NilAndUnknownInput(t *testing.T) { + 
SetOpenAIClientTransport(nil, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIClientTransportUnknown, GetOpenAIClientTransport(nil)) + + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + SetOpenAIClientTransport(c, OpenAIClientTransportUnknown) + _, exists := c.Get(openAIClientTransportContextKey) + require.False(t, exists) + + SetOpenAIClientTransport(c, OpenAIClientTransport(" ")) + _, exists = c.Get(openAIClientTransportContextKey) + require.False(t, exists) +} + +func TestResolveOpenAIWSDecisionByClientTransport(t *testing.T) { + base := OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + + httpDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportHTTP) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, httpDecision.Transport) + require.Equal(t, "client_protocol_http", httpDecision.Reason) + + wsDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportWS) + require.Equal(t, base, wsDecision) + + unknownDecision := resolveOpenAIWSDecisionByClientTransport(base, OpenAIClientTransportUnknown) + require.Equal(t, base, unknownDecision) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f26ce03f..f624d92a 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -10,10 +10,12 @@ import ( "errors" "fmt" "io" + "math/rand" "net/http" "sort" "strconv" "strings" + "sync" "sync/atomic" "time" @@ -34,35 +36,46 @@ const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL - codexCLIUserAgent = "codex_cli_rs/0.98.0" + codexCLIUserAgent = "codex_cli_rs/0.104.0" // codex_cli_only 拒绝时单个请求头日志长度上限(字符) codexCLIOnlyHeaderValueMaxBytes = 256 // OpenAIParsedRequestBodyKey 缓存 
handler 侧已解析的请求体,避免重复解析。 OpenAIParsedRequestBodyKey = "openai_parsed_request_body" + // OpenAI WS Mode 失败后的重连次数上限(不含首次尝试)。 + // 与 Codex 客户端保持一致:失败后最多重连 5 次。 + openAIWSReconnectRetryLimit = 5 + // OpenAI WS Mode 重连退避默认值(可由配置覆盖)。 + openAIWSRetryBackoffInitialDefault = 120 * time.Millisecond + openAIWSRetryBackoffMaxDefault = 2 * time.Second + openAIWSRetryJitterRatioDefault = 0.2 ) // OpenAI allowed headers whitelist (for non-passthrough). var openaiAllowedHeaders = map[string]bool{ - "accept-language": true, - "content-type": true, - "conversation_id": true, - "user-agent": true, - "originator": true, - "session_id": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, } // OpenAI passthrough allowed headers whitelist. // 透传模式下仅放行这些低风险请求头,避免将非标准/环境噪声头传给上游触发风控。 var openaiPassthroughAllowedHeaders = map[string]bool{ - "accept": true, - "accept-language": true, - "content-type": true, - "conversation_id": true, - "openai-beta": true, - "user-agent": true, - "originator": true, - "session_id": true, + "accept": true, + "accept-language": true, + "content-type": true, + "conversation_id": true, + "openai-beta": true, + "user-agent": true, + "originator": true, + "session_id": true, + "x-codex-turn-state": true, + "x-codex-turn-metadata": true, } // codex_cli_only 拒绝时记录的请求头白名单(仅用于诊断日志,不参与上游透传) @@ -196,10 +209,40 @@ type OpenAIForwardResult struct { // Stored for usage records display; nil means not provided / not applicable. 
ReasoningEffort *string Stream bool + OpenAIWSMode bool Duration time.Duration FirstTokenMs *int } +type OpenAIWSRetryMetricsSnapshot struct { + RetryAttemptsTotal int64 `json:"retry_attempts_total"` + RetryBackoffMsTotal int64 `json:"retry_backoff_ms_total"` + RetryExhaustedTotal int64 `json:"retry_exhausted_total"` + NonRetryableFastFallbackTotal int64 `json:"non_retryable_fast_fallback_total"` +} + +type OpenAICompatibilityFallbackMetricsSnapshot struct { + SessionHashLegacyReadFallbackTotal int64 `json:"session_hash_legacy_read_fallback_total"` + SessionHashLegacyReadFallbackHit int64 `json:"session_hash_legacy_read_fallback_hit"` + SessionHashLegacyDualWriteTotal int64 `json:"session_hash_legacy_dual_write_total"` + SessionHashLegacyReadHitRate float64 `json:"session_hash_legacy_read_hit_rate"` + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal int64 `json:"metadata_legacy_fallback_is_max_tokens_one_haiku_total"` + MetadataLegacyFallbackThinkingEnabledTotal int64 `json:"metadata_legacy_fallback_thinking_enabled_total"` + MetadataLegacyFallbackPrefetchedStickyAccount int64 `json:"metadata_legacy_fallback_prefetched_sticky_account_total"` + MetadataLegacyFallbackPrefetchedStickyGroup int64 `json:"metadata_legacy_fallback_prefetched_sticky_group_total"` + MetadataLegacyFallbackSingleAccountRetryTotal int64 `json:"metadata_legacy_fallback_single_account_retry_total"` + MetadataLegacyFallbackAccountSwitchCountTotal int64 `json:"metadata_legacy_fallback_account_switch_count_total"` + MetadataLegacyFallbackTotal int64 `json:"metadata_legacy_fallback_total"` +} + +type openAIWSRetryMetrics struct { + retryAttempts atomic.Int64 + retryBackoffMs atomic.Int64 + retryExhausted atomic.Int64 + nonRetryableFastFallback atomic.Int64 +} + // OpenAIGatewayService handles OpenAI API gateway operations type OpenAIGatewayService struct { accountRepo AccountRepository @@ -218,6 +261,19 @@ type OpenAIGatewayService struct { deferredService *DeferredService openAITokenProvider 
*OpenAITokenProvider toolCorrector *CodexToolCorrector + openaiWSResolver OpenAIWSProtocolResolver + + openaiWSPoolOnce sync.Once + openaiWSStateStoreOnce sync.Once + openaiSchedulerOnce sync.Once + openaiWSPool *openAIWSConnPool + openaiWSStateStore OpenAIWSStateStore + openaiScheduler OpenAIAccountScheduler + openaiAccountStats *openAIAccountRuntimeStats + + openaiWSFallbackUntil sync.Map // key: int64(accountID), value: time.Time + openaiWSRetryMetrics openAIWSRetryMetrics + responseHeaderFilter *responseheaders.CompiledHeaderFilter } // NewOpenAIGatewayService creates a new OpenAIGatewayService @@ -237,24 +293,61 @@ func NewOpenAIGatewayService( deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, ) *OpenAIGatewayService { - return &OpenAIGatewayService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - userRepo: userRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, - codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), - schedulerSnapshot: schedulerSnapshot, - concurrencyService: concurrencyService, - billingService: billingService, - rateLimitService: rateLimitService, - billingCacheService: billingCacheService, - httpUpstream: httpUpstream, - deferredService: deferredService, - openAITokenProvider: openAITokenProvider, - toolCorrector: NewCodexToolCorrector(), + svc := &OpenAIGatewayService{ + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + userRepo: userRepo, + userSubRepo: userSubRepo, + cache: cache, + cfg: cfg, + codexDetector: NewOpenAICodexClientRestrictionDetector(cfg), + schedulerSnapshot: schedulerSnapshot, + concurrencyService: concurrencyService, + billingService: billingService, + rateLimitService: rateLimitService, + billingCacheService: billingCacheService, + httpUpstream: httpUpstream, + deferredService: deferredService, + openAITokenProvider: openAITokenProvider, + toolCorrector: NewCodexToolCorrector(), + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + responseHeaderFilter: 
compileResponseHeaderFilter(cfg), } + svc.logOpenAIWSModeBootstrap() + return svc +} + +// CloseOpenAIWSPool 关闭 OpenAI WebSocket 连接池的后台 worker 和空闲连接。 +// 应在应用优雅关闭时调用。 +func (s *OpenAIGatewayService) CloseOpenAIWSPool() { + if s != nil && s.openaiWSPool != nil { + s.openaiWSPool.Close() + } +} + +func (s *OpenAIGatewayService) logOpenAIWSModeBootstrap() { + if s == nil || s.cfg == nil { + return + } + wsCfg := s.cfg.Gateway.OpenAIWS + logOpenAIWSModeInfo( + "bootstrap enabled=%v oauth_enabled=%v apikey_enabled=%v force_http=%v responses_websockets_v2=%v responses_websockets=%v payload_log_sample_rate=%.3f event_flush_batch_size=%d event_flush_interval_ms=%d prewarm_cooldown_ms=%d retry_backoff_initial_ms=%d retry_backoff_max_ms=%d retry_jitter_ratio=%.3f retry_total_budget_ms=%d ws_read_limit_bytes=%d", + wsCfg.Enabled, + wsCfg.OAuthEnabled, + wsCfg.APIKeyEnabled, + wsCfg.ForceHTTP, + wsCfg.ResponsesWebsocketsV2, + wsCfg.ResponsesWebsockets, + wsCfg.PayloadLogSampleRate, + wsCfg.EventFlushBatchSize, + wsCfg.EventFlushIntervalMS, + wsCfg.PrewarmCooldownMS, + wsCfg.RetryBackoffInitialMS, + wsCfg.RetryBackoffMaxMS, + wsCfg.RetryJitterRatio, + wsCfg.RetryTotalBudgetMS, + openAIWSMessageReadLimitBytes, + ) } func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRestrictionDetector { @@ -268,6 +361,317 @@ func (s *OpenAIGatewayService) getCodexClientRestrictionDetector() CodexClientRe return NewOpenAICodexClientRestrictionDetector(cfg) } +func (s *OpenAIGatewayService) getOpenAIWSProtocolResolver() OpenAIWSProtocolResolver { + if s != nil && s.openaiWSResolver != nil { + return s.openaiWSResolver + } + var cfg *config.Config + if s != nil { + cfg = s.cfg + } + return NewOpenAIWSProtocolResolver(cfg) +} + +func classifyOpenAIWSReconnectReason(err error) (string, bool) { + if err == nil { + return "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return "", false + } + reason := 
strings.TrimSpace(fallbackErr.Reason) + if reason == "" { + return "", false + } + + baseReason := strings.TrimPrefix(reason, "prewarm_") + + switch baseReason { + case "policy_violation", + "message_too_big", + "upgrade_required", + "ws_unsupported", + "auth_failed", + "previous_response_not_found": + return reason, false + } + + switch baseReason { + case "read_event", + "write_request", + "write", + "acquire_timeout", + "acquire_conn", + "conn_queue_full", + "dial_failed", + "upstream_5xx", + "event_error", + "error_event", + "upstream_error_event", + "ws_connection_limit_reached", + "missing_final_response": + return reason, true + default: + return reason, false + } +} + +func resolveOpenAIWSFallbackErrorResponse(err error) (statusCode int, errType string, clientMessage string, upstreamMessage string, ok bool) { + if err == nil { + return 0, "", "", "", false + } + var fallbackErr *openAIWSFallbackError + if !errors.As(err, &fallbackErr) || fallbackErr == nil { + return 0, "", "", "", false + } + + reason := strings.TrimSpace(fallbackErr.Reason) + reason = strings.TrimPrefix(reason, "prewarm_") + if reason == "" { + return 0, "", "", "", false + } + + var dialErr *openAIWSDialError + if fallbackErr.Err != nil && errors.As(fallbackErr.Err, &dialErr) && dialErr != nil { + if dialErr.StatusCode > 0 { + statusCode = dialErr.StatusCode + } + if dialErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(dialErr.Err.Error())) + } + } + + switch reason { + case "previous_response_not_found": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + errType = "invalid_request_error" + if upstreamMessage == "" { + upstreamMessage = "previous response not found" + } + case "upgrade_required": + if statusCode == 0 { + statusCode = http.StatusUpgradeRequired + } + case "ws_unsupported": + if statusCode == 0 { + statusCode = http.StatusBadRequest + } + case "auth_failed": + if statusCode == 0 { + statusCode = http.StatusUnauthorized + } 
+ case "upstream_rate_limited": + if statusCode == 0 { + statusCode = http.StatusTooManyRequests + } + default: + if statusCode == 0 { + return 0, "", "", "", false + } + } + + if upstreamMessage == "" && fallbackErr.Err != nil { + upstreamMessage = sanitizeUpstreamErrorMessage(strings.TrimSpace(fallbackErr.Err.Error())) + } + if upstreamMessage == "" { + switch reason { + case "upgrade_required": + upstreamMessage = "upstream websocket upgrade required" + case "ws_unsupported": + upstreamMessage = "upstream websocket not supported" + case "auth_failed": + upstreamMessage = "upstream authentication failed" + case "upstream_rate_limited": + upstreamMessage = "upstream rate limit exceeded, please retry later" + default: + upstreamMessage = "Upstream request failed" + } + } + + if errType == "" { + if statusCode == http.StatusTooManyRequests { + errType = "rate_limit_error" + } else { + errType = "upstream_error" + } + } + clientMessage = upstreamMessage + return statusCode, errType, clientMessage, upstreamMessage, true +} + +func (s *OpenAIGatewayService) writeOpenAIWSFallbackErrorResponse(c *gin.Context, account *Account, wsErr error) bool { + if c == nil || c.Writer == nil || c.Writer.Written() { + return false + } + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(wsErr) + if !ok { + return false + } + if strings.TrimSpace(clientMessage) == "" { + clientMessage = "Upstream request failed" + } + if strings.TrimSpace(upstreamMessage) == "" { + upstreamMessage = clientMessage + } + + setOpsUpstreamError(c, statusCode, upstreamMessage, "") + if account != nil { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: statusCode, + Kind: "ws_error", + Message: upstreamMessage, + }) + } + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": errType, + "message": clientMessage, + }, + }) + return true +} + +func (s 
*OpenAIGatewayService) openAIWSRetryBackoff(attempt int) time.Duration { + if attempt <= 0 { + return 0 + } + + initial := openAIWSRetryBackoffInitialDefault + maxBackoff := openAIWSRetryBackoffMaxDefault + jitterRatio := openAIWSRetryJitterRatioDefault + if s != nil && s.cfg != nil { + wsCfg := s.cfg.Gateway.OpenAIWS + if wsCfg.RetryBackoffInitialMS > 0 { + initial = time.Duration(wsCfg.RetryBackoffInitialMS) * time.Millisecond + } + if wsCfg.RetryBackoffMaxMS > 0 { + maxBackoff = time.Duration(wsCfg.RetryBackoffMaxMS) * time.Millisecond + } + if wsCfg.RetryJitterRatio >= 0 { + jitterRatio = wsCfg.RetryJitterRatio + } + } + if initial <= 0 { + return 0 + } + if maxBackoff <= 0 { + maxBackoff = initial + } + if maxBackoff < initial { + maxBackoff = initial + } + if jitterRatio < 0 { + jitterRatio = 0 + } + if jitterRatio > 1 { + jitterRatio = 1 + } + + shift := attempt - 1 + if shift < 0 { + shift = 0 + } + backoff := initial + if shift > 0 { + backoff = initial * time.Duration(1< maxBackoff { + backoff = maxBackoff + } + if jitterRatio <= 0 { + return backoff + } + jitter := time.Duration(float64(backoff) * jitterRatio) + if jitter <= 0 { + return backoff + } + delta := time.Duration(rand.Int63n(int64(jitter)*2+1)) - jitter + withJitter := backoff + delta + if withJitter < 0 { + return 0 + } + return withJitter +} + +func (s *OpenAIGatewayService) openAIWSRetryTotalBudget() time.Duration { + if s != nil && s.cfg != nil { + ms := s.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS + if ms <= 0 { + return 0 + } + return time.Duration(ms) * time.Millisecond + } + return 0 +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryAttempt(backoff time.Duration) { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryAttempts.Add(1) + if backoff > 0 { + s.openaiWSRetryMetrics.retryBackoffMs.Add(backoff.Milliseconds()) + } +} + +func (s *OpenAIGatewayService) recordOpenAIWSRetryExhausted() { + if s == nil { + return + } + s.openaiWSRetryMetrics.retryExhausted.Add(1) +} + +func 
(s *OpenAIGatewayService) recordOpenAIWSNonRetryableFastFallback() { + if s == nil { + return + } + s.openaiWSRetryMetrics.nonRetryableFastFallback.Add(1) +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSRetryMetrics() OpenAIWSRetryMetricsSnapshot { + if s == nil { + return OpenAIWSRetryMetricsSnapshot{} + } + return OpenAIWSRetryMetricsSnapshot{ + RetryAttemptsTotal: s.openaiWSRetryMetrics.retryAttempts.Load(), + RetryBackoffMsTotal: s.openaiWSRetryMetrics.retryBackoffMs.Load(), + RetryExhaustedTotal: s.openaiWSRetryMetrics.retryExhausted.Load(), + NonRetryableFastFallbackTotal: s.openaiWSRetryMetrics.nonRetryableFastFallback.Load(), + } +} + +func SnapshotOpenAICompatibilityFallbackMetrics() OpenAICompatibilityFallbackMetricsSnapshot { + legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal := openAIStickyCompatStats() + isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount := RequestMetadataFallbackStats() + + readHitRate := float64(0) + if legacyReadFallbackTotal > 0 { + readHitRate = float64(legacyReadFallbackHit) / float64(legacyReadFallbackTotal) + } + metadataFallbackTotal := isMaxTokensOneHaiku + thinkingEnabled + prefetchedStickyAccount + prefetchedStickyGroup + singleAccountRetry + accountSwitchCount + + return OpenAICompatibilityFallbackMetricsSnapshot{ + SessionHashLegacyReadFallbackTotal: legacyReadFallbackTotal, + SessionHashLegacyReadFallbackHit: legacyReadFallbackHit, + SessionHashLegacyDualWriteTotal: legacyDualWriteTotal, + SessionHashLegacyReadHitRate: readHitRate, + + MetadataLegacyFallbackIsMaxTokensOneHaikuTotal: isMaxTokensOneHaiku, + MetadataLegacyFallbackThinkingEnabledTotal: thinkingEnabled, + MetadataLegacyFallbackPrefetchedStickyAccount: prefetchedStickyAccount, + MetadataLegacyFallbackPrefetchedStickyGroup: prefetchedStickyGroup, + MetadataLegacyFallbackSingleAccountRetryTotal: singleAccountRetry, + MetadataLegacyFallbackAccountSwitchCountTotal: 
accountSwitchCount, + MetadataLegacyFallbackTotal: metadataFallbackTotal, + } +} + func (s *OpenAIGatewayService) detectCodexClientRestriction(c *gin.Context, account *Account) CodexClientRestrictionDetectionResult { return s.getCodexClientRestrictionDetector().Detect(c, account) } @@ -494,8 +898,28 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) return "" } - hash := sha256.Sum256([]byte(sessionID)) - return hex.EncodeToString(hash[:]) + currentHash, legacyHash := deriveOpenAISessionHashes(sessionID) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash +} + +// GenerateSessionHashWithFallback 先按常规信号生成会话哈希; +// 当未携带 session_id/conversation_id/prompt_cache_key 时,使用 fallbackSeed 生成稳定哈希。 +// 该方法用于 WS ingress,避免会话信号缺失时发生跨账号漂移。 +func (s *OpenAIGatewayService) GenerateSessionHashWithFallback(c *gin.Context, body []byte, fallbackSeed string) string { + sessionHash := s.GenerateSessionHash(c, body) + if sessionHash != "" { + return sessionHash + } + + seed := strings.TrimSpace(fallbackSeed) + if seed == "" { + return "" + } + + currentHash, legacyHash := deriveOpenAISessionHashes(seed) + attachOpenAILegacySessionHashToGin(c, legacyHash) + return currentHash } // BindStickySession sets session -> account binding with standard TTL. 
@@ -503,7 +927,11 @@ func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, groupID *i if sessionHash == "" || accountID <= 0 { return nil } - return s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, accountID, openaiStickySessionTTL) + ttl := openaiStickySessionTTL + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds > 0 { + ttl = time.Duration(s.cfg.Gateway.OpenAIWS.StickySessionTTLSeconds) * time.Second + } + return s.setStickySessionAccountID(ctx, groupID, sessionHash, accountID, ttl) } // SelectAccount selects an OpenAI account with sticky session support @@ -519,11 +947,13 @@ func (s *OpenAIGatewayService) SelectAccountForModel(ctx context.Context, groupI // SelectAccountForModelWithExclusions selects an account supporting the requested model while excluding specified accounts. // SelectAccountForModelWithExclusions 选择支持指定模型的账号,同时排除指定的账号。 func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*Account, error) { - cacheKey := "openai:" + sessionHash + return s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, 0) +} +func (s *OpenAIGatewayService) selectAccountForModelWithExclusions(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) (*Account, error) { // 1. 尝试粘性会话命中 // Try sticky session hit - if account := s.tryStickySessionHit(ctx, groupID, sessionHash, cacheKey, requestedModel, excludedIDs); account != nil { + if account := s.tryStickySessionHit(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID); account != nil { return account, nil } @@ -548,7 +978,7 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // 4. 
设置粘性会话绑定 // Set sticky session binding if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), cacheKey, selected.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, selected.ID, openaiStickySessionTTL) } return selected, nil @@ -559,14 +989,18 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C // // tryStickySessionHit attempts to get account from sticky session. // Returns account if hit and usable; clears session and returns nil if account is unavailable. -func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, cacheKey, requestedModel string, excludedIDs map[int64]struct{}) *Account { +func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID *int64, sessionHash, requestedModel string, excludedIDs map[int64]struct{}, stickyAccountID int64) *Account { if sessionHash == "" { return nil } - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), cacheKey) - if err != nil || accountID <= 0 { - return nil + accountID := stickyAccountID + if accountID <= 0 { + var err error + accountID, err = s.getStickySessionAccountID(ctx, groupID, sessionHash) + if err != nil || accountID <= 0 { + return nil + } } if _, excluded := excludedIDs[accountID]; excluded { @@ -581,7 +1015,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 检查账号是否需要清理粘性会话 // Check if sticky session should be cleared if shouldClearStickySession(account, requestedModel) { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), cacheKey) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) return nil } @@ -596,7 +1030,7 @@ func (s *OpenAIGatewayService) tryStickySessionHit(ctx context.Context, groupID // 刷新会话 TTL 并返回账号 // Refresh session TTL and return account - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), cacheKey, openaiStickySessionTTL) + _ = 
s.refreshStickySessionTTL(ctx, groupID, sessionHash, openaiStickySessionTTL) return account } @@ -682,12 +1116,12 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex cfg := s.schedulingConfig() var stickyAccountID int64 if sessionHash != "" && s.cache != nil { - if accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash); err == nil { + if accountID, err := s.getStickySessionAccountID(ctx, groupID, sessionHash); err == nil { stickyAccountID = accountID } } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) + account, err := s.selectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs, stickyAccountID) if err != nil { return nil, err } @@ -742,19 +1176,19 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex // ============ Layer 1: Sticky session ============ if sessionHash != "" { - accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) - if err == nil && accountID > 0 && !isExcluded(accountID) { + accountID := stickyAccountID + if accountID > 0 && !isExcluded(accountID) { account, err := s.getSchedulableAccount(ctx, accountID) if err == nil { clearSticky := shouldClearStickySession(account, requestedModel) if clearSticky { - _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash) + _ = s.deleteStickySessionAccountID(ctx, groupID, sessionHash) } if !clearSticky && account.IsSchedulable() && account.IsOpenAI() && (requestedModel == "" || account.IsModelSupported(requestedModel)) { result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { - _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), "openai:"+sessionHash, openaiStickySessionTTL) + _ = s.refreshStickySessionTTL(ctx, groupID, 
sessionHash, openaiStickySessionTTL) return &AccountSelectionResult{ Account: account, Acquired: true, @@ -818,7 +1252,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, acc.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, acc.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: acc, @@ -868,7 +1302,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { if sessionHash != "" { - _ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), "openai:"+sessionHash, item.account.ID, openaiStickySessionTTL) + _ = s.setStickySessionAccountID(ctx, groupID, sessionHash, item.account.ID, openaiStickySessionTTL) } return &AccountSelectionResult{ Account: item.account, @@ -1010,6 +1444,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco originalModel := reqModel isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + clientTransport := GetOpenAIClientTransport(c) + // 仅允许 WS 入站请求走 WS 上游,避免出现 HTTP -> WS 协议混用。 + wsDecision = resolveOpenAIWSDecisionByClientTransport(wsDecision, clientTransport) + if c != nil { + c.Set("openai_ws_transport_decision", string(wsDecision.Transport)) + c.Set("openai_ws_transport_reason", wsDecision.Reason) + } + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeDebug( + "selected account_id=%d account_type=%s transport=%s reason=%s model=%s stream=%v", + account.ID, + account.Type, + 
normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + normalizeOpenAIWSLogValue(wsDecision.Reason), + reqModel, + reqStream, + ) + } + // 当前仅支持 WSv2;WSv1 命中时直接返回错误,避免出现“配置可开但行为不确定”。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + if c != nil { + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": "OpenAI WSv1 is temporarily unsupported. Please enable responses_websockets_v2.", + }, + }) + } + return nil, errors.New("openai ws v1 is temporarily unsupported; use ws v2") + } passthroughEnabled := account.IsOpenAIPassthroughEnabled() if passthroughEnabled { // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 @@ -1037,12 +1502,61 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco // Track if body needs re-serialization bodyModified := false + // 单字段补丁快速路径:只要整个变更集最终可归约为同一路径的 set/delete,就避免全量 Marshal。 + patchDisabled := false + patchHasOp := false + patchDelete := false + patchPath := "" + var patchValue any + markPatchSet := func(path string, value any) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = false + patchPath = path + patchValue = value + return + } + if patchDelete || patchPath != path { + patchDisabled = true + return + } + patchValue = value + } + markPatchDelete := func(path string) { + if strings.TrimSpace(path) == "" { + patchDisabled = true + return + } + if patchDisabled { + return + } + if !patchHasOp { + patchHasOp = true + patchDelete = true + patchPath = path + return + } + if !patchDelete || patchPath != path { + patchDisabled = true + } + } + disablePatch := func() { + patchDisabled = true + } // 非透传模式下,保持历史行为:非 Codex CLI 请求在 instructions 为空时注入默认指令。 if !isCodexCLI && isInstructionsEmpty(reqBody) { if instructions := strings.TrimSpace(GetOpenCodeInstructions()); instructions != "" { reqBody["instructions"] = instructions bodyModified = 
true + markPatchSet("instructions", instructions) } } @@ -1052,6 +1566,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Model mapping applied: %s -> %s (account: %s, isCodexCLI: %v)", reqModel, mappedModel, account.Name, isCodexCLI) reqBody["model"] = mappedModel bodyModified = true + markPatchSet("model", mappedModel) } // 针对所有 OpenAI 账号执行 Codex 模型名规范化,确保上游识别一致。 @@ -1063,6 +1578,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco reqBody["model"] = normalizedModel mappedModel = normalizedModel bodyModified = true + markPatchSet("model", normalizedModel) } } @@ -1071,6 +1587,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if effort, ok := reasoning["effort"].(string); ok && effort == "minimal" { reasoning["effort"] = "none" bodyModified = true + markPatchSet("reasoning.effort", "none") logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized reasoning.effort: minimal -> none (account: %s)", account.Name) } } @@ -1079,6 +1596,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco codexResult := applyCodexOAuthTransform(reqBody, isCodexCLI) if codexResult.Modified { bodyModified = true + disablePatch() } if codexResult.NormalizedModel != "" { mappedModel = codexResult.NormalizedModel @@ -1098,22 +1616,27 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey { delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } case PlatformAnthropic: // For Anthropic (Claude), convert to max_tokens delete(reqBody, "max_output_tokens") + markPatchDelete("max_output_tokens") if _, hasMaxTokens := reqBody["max_tokens"]; !hasMaxTokens { reqBody["max_tokens"] = maxOutputTokens + disablePatch() } bodyModified = true case PlatformGemini: // For Gemini, remove (will be handled 
by Gemini-specific transform) delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") default: // For unknown platforms, remove to be safe delete(reqBody, "max_output_tokens") bodyModified = true + markPatchDelete("max_output_tokens") } } @@ -1122,24 +1645,51 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco if account.Type == AccountTypeAPIKey || account.Platform != PlatformOpenAI { delete(reqBody, "max_completion_tokens") bodyModified = true + markPatchDelete("max_completion_tokens") } } // Remove unsupported fields (not supported by upstream OpenAI API) - for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + unsupportedFields := []string{"prompt_cache_retention", "safety_identifier"} + for _, unsupportedField := range unsupportedFields { if _, has := reqBody[unsupportedField]; has { delete(reqBody, unsupportedField) bodyModified = true + markPatchDelete(unsupportedField) } } } + // 仅在 WSv2 模式保留 previous_response_id,其他模式(HTTP/WSv1)统一过滤。 + // 注意:该规则同样适用于 Codex CLI 请求,避免 WSv1 向上游透传不支持字段。 + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + if _, has := reqBody["previous_response_id"]; has { + delete(reqBody, "previous_response_id") + bodyModified = true + markPatchDelete("previous_response_id") + } + } + // Re-serialize body only if modified if bodyModified { - var err error - body, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("serialize request body: %w", err) + serializedByPatch := false + if !patchDisabled && patchHasOp { + var patchErr error + if patchDelete { + body, patchErr = sjson.DeleteBytes(body, patchPath) + } else { + body, patchErr = sjson.SetBytes(body, patchPath, patchValue) + } + if patchErr == nil { + serializedByPatch = true + } + } + if !serializedByPatch { + var marshalErr error + body, marshalErr = json.Marshal(reqBody) + if marshalErr != nil { + return nil, 
fmt.Errorf("serialize request body: %w", marshalErr) + } } } @@ -1149,6 +1699,184 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco return nil, err } + // Capture upstream request body for ops retry of this attempt. + setOpsUpstreamRequestBody(c, body) + + // 命中 WS 时仅走 WebSocket Mode;不再自动回退 HTTP。 + if wsDecision.Transport == OpenAIUpstreamTransportResponsesWebsocketV2 { + wsReqBody := reqBody + if len(reqBody) > 0 { + wsReqBody = make(map[string]any, len(reqBody)) + for k, v := range reqBody { + wsReqBody[k] = v + } + } + _, hasPreviousResponseID := wsReqBody["previous_response_id"] + logOpenAIWSModeDebug( + "forward_start account_id=%d account_type=%s model=%s stream=%v has_previous_response_id=%v", + account.ID, + account.Type, + mappedModel, + reqStream, + hasPreviousResponseID, + ) + maxAttempts := openAIWSReconnectRetryLimit + 1 + wsAttempts := 0 + var wsResult *OpenAIForwardResult + var wsErr error + wsLastFailureReason := "" + wsPrevResponseRecoveryTried := false + recoverPrevResponseNotFound := func(attempt int) bool { + if wsPrevResponseRecoveryTried { + return false + } + previousResponseID := openAIWSPayloadString(wsReqBody, "previous_response_id") + if previousResponseID == "" { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=missing_previous_response_id previous_response_id_present=false", + account.ID, + attempt, + ) + return false + } + if HasFunctionCallOutput(wsReqBody) { + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery_skip account_id=%d attempt=%d reason=has_function_call_output previous_response_id_present=true", + account.ID, + attempt, + ) + return false + } + delete(wsReqBody, "previous_response_id") + wsPrevResponseRecoveryTried = true + logOpenAIWSModeInfo( + "reconnect_prev_response_recovery account_id=%d attempt=%d action=drop_previous_response_id retry=1 previous_response_id=%s previous_response_id_kind=%s", + account.ID, + attempt, + 
truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(ClassifyOpenAIPreviousResponseIDKind(previousResponseID)), + ) + return true + } + retryBudget := s.openAIWSRetryTotalBudget() + retryStartedAt := time.Now() + wsRetryLoop: + for attempt := 1; attempt <= maxAttempts; attempt++ { + wsAttempts = attempt + wsResult, wsErr = s.forwardOpenAIWSV2( + ctx, + c, + account, + wsReqBody, + token, + wsDecision, + isCodexCLI, + reqStream, + originalModel, + mappedModel, + startTime, + attempt, + wsLastFailureReason, + ) + if wsErr == nil { + break + } + if c != nil && c.Writer != nil && c.Writer.Written() { + break + } + + reason, retryable := classifyOpenAIWSReconnectReason(wsErr) + if reason != "" { + wsLastFailureReason = reason + } + // previous_response_not_found 说明续链锚点不可用: + // 对非 function_call_output 场景,允许一次“去掉 previous_response_id 后重放”。 + if reason == "previous_response_not_found" && recoverPrevResponseNotFound(attempt) { + continue + } + if retryable && attempt < maxAttempts { + backoff := s.openAIWSRetryBackoff(attempt) + if retryBudget > 0 && time.Since(retryStartedAt)+backoff > retryBudget { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_budget_exhausted account_id=%d attempts=%d max_retries=%d reason=%s elapsed_ms=%d budget_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + time.Since(retryStartedAt).Milliseconds(), + retryBudget.Milliseconds(), + ) + break + } + s.recordOpenAIWSRetryAttempt(backoff) + logOpenAIWSModeInfo( + "reconnect_retry account_id=%d retry=%d max_retries=%d reason=%s backoff_ms=%d", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + backoff.Milliseconds(), + ) + if backoff > 0 { + timer := time.NewTimer(backoff) + select { + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + wsErr = wrapOpenAIWSFallback("retry_backoff_canceled", ctx.Err()) + break wsRetryLoop + case 
<-timer.C: + } + } + continue + } + if retryable { + s.recordOpenAIWSRetryExhausted() + logOpenAIWSModeInfo( + "reconnect_exhausted account_id=%d attempts=%d max_retries=%d reason=%s", + account.ID, + attempt, + openAIWSReconnectRetryLimit, + normalizeOpenAIWSLogValue(reason), + ) + } else if reason != "" { + s.recordOpenAIWSNonRetryableFastFallback() + logOpenAIWSModeInfo( + "reconnect_stop account_id=%d attempt=%d reason=%s", + account.ID, + attempt, + normalizeOpenAIWSLogValue(reason), + ) + } + break + } + if wsErr == nil { + firstTokenMs := int64(0) + hasFirstTokenMs := wsResult != nil && wsResult.FirstTokenMs != nil + if hasFirstTokenMs { + firstTokenMs = int64(*wsResult.FirstTokenMs) + } + requestID := "" + if wsResult != nil { + requestID = strings.TrimSpace(wsResult.RequestID) + } + logOpenAIWSModeDebug( + "forward_succeeded account_id=%d request_id=%s stream=%v has_first_token_ms=%v first_token_ms=%d ws_attempts=%d", + account.ID, + requestID, + reqStream, + hasFirstTokenMs, + firstTokenMs, + wsAttempts, + ) + return wsResult, nil + } + s.writeOpenAIWSFallbackErrorResponse(c, account, wsErr) + return nil, wsErr + } + // Build upstream request upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, reqStream, promptCacheKey, isCodexCLI) if err != nil { @@ -1161,9 +1889,6 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco proxyURL = account.Proxy.URL() } - // Capture upstream request body for ops retry of this attempt. 
- setOpsUpstreamRequestBody(c, body) - // Send request upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) @@ -1260,6 +1985,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco Model: originalModel, ReasoningEffort: reasoningEffort, Stream: reqStream, + OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil @@ -1413,6 +2139,7 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough( Model: reqModel, ReasoningEffort: reasoningEffort, Stream: reqStream, + OpenAIWSMode: false, Duration: time.Since(startTime), FirstTokenMs: firstTokenMs, }, nil @@ -1576,7 +2303,7 @@ func (s *OpenAIGatewayService) handleErrorResponsePassthrough( UpstreamResponseBody: upstreamDetail, }) - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/json" @@ -1643,7 +2370,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( account *Account, startTime time.Time, ) (*openaiStreamingResultPassthrough, error) { - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) // SSE headers c.Header("Content-Type", "text/event-stream") @@ -1678,6 +2405,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( for scanner.Scan() { line := scanner.Text() if data, ok := extractOpenAISSEDataLine(line); ok { + dataBytes := []byte(data) trimmedData := strings.TrimSpace(data) if trimmedData == "[DONE]" { sawDone = true @@ -1686,7 +2414,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms } - s.parseSSEUsage(data, usage) + 
s.parseSSEUsageBytes(dataBytes, usage) } if !clientDisconnected { @@ -1759,19 +2487,8 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( usage := &OpenAIUsage{} usageParsed := false if len(body) > 0 { - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if json.Unmarshal(body, &response) == nil { - usage.InputTokens = response.Usage.InputTokens - usage.OutputTokens = response.Usage.OutputTokens - usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + if parsedUsage, ok := extractOpenAIUsageFromJSONBytes(body); ok { + *usage = parsedUsage usageParsed = true } } @@ -1780,7 +2497,7 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( usage = s.parseSSEUsageFromBody(string(body)) } - writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.cfg) + writeOpenAIPassthroughResponseHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := resp.Header.Get("Content-Type") if contentType == "" { @@ -1790,12 +2507,12 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough( return usage, nil } -func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, cfg *config.Config) { +func writeOpenAIPassthroughResponseHeaders(dst http.Header, src http.Header, filter *responseheaders.CompiledHeaderFilter) { if dst == nil || src == nil { return } - if cfg != nil { - responseheaders.WriteFilteredHeaders(dst, src, cfg.Security.ResponseHeaders) + if filter != nil { + responseheaders.WriteFilteredHeaders(dst, src, filter) } else { // 兜底:尽量保留最基础的 content-type if v := strings.TrimSpace(src.Get("Content-Type")); v != "" { @@ -2074,8 +2791,8 @@ type openaiStreamingResult struct { } func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp 
*http.Response, c *gin.Context, account *Account, startTime time.Time, originalModel, mappedModel string) (*openaiStreamingResult, error) { - if s.cfg != nil { - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) } // Set SSE response headers @@ -2094,6 +2811,14 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if !ok { return nil, errors.New("streaming not supported") } + bufferedWriter := bufio.NewWriterSize(w, 4*1024) + flushBuffered := func() error { + if err := bufferedWriter.Flush(); err != nil { + return err + } + flusher.Flush() + return nil + } usage := &OpenAIUsage{} var firstTokenMs *int @@ -2105,38 +2830,6 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp scanBuf := getSSEScannerBuf64K() scanner.Buffer(scanBuf[:0], maxLineSize) - type scanEvent struct { - line string - err error - } - // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 - events := make(chan scanEvent, 16) - done := make(chan struct{}) - sendEvent := func(ev scanEvent) bool { - select { - case events <- ev: - return true - case <-done: - return false - } - } - var lastReadAt int64 - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - go func(scanBuf *sseScannerBuf64K) { - defer putSSEScannerBuf64K(scanBuf) - defer close(events) - for scanner.Scan() { - atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) - if !sendEvent(scanEvent{line: scanner.Text()}) { - return - } - } - if err := scanner.Err(); err != nil { - _ = sendEvent(scanEvent{err: err}) - } - }(scanBuf) - defer close(done) - streamInterval := time.Duration(0) if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 { streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second @@ -2179,94 +2872,178 @@ func (s *OpenAIGatewayService) 
handleStreamingResponse(ctx context.Context, resp return } errorEventSent = true - payload := map[string]any{ - "type": "error", - "sequence_number": 0, - "error": map[string]any{ - "type": "upstream_error", - "message": reason, - "code": reason, - }, + payload := `{"type":"error","sequence_number":0,"error":{"type":"upstream_error","message":` + strconv.Quote(reason) + `,"code":` + strconv.Quote(reason) + `}}` + if err := flushBuffered(); err != nil { + clientDisconnected = true + return } - if b, err := json.Marshal(payload); err == nil { - _, _ = fmt.Fprintf(w, "data: %s\n\n", b) - flusher.Flush() + if _, err := bufferedWriter.WriteString("data: " + payload + "\n\n"); err != nil { + clientDisconnected = true + return + } + if err := flushBuffered(); err != nil { + clientDisconnected = true } } needModelReplace := originalModel != mappedModel + resultWithUsage := func() *openaiStreamingResult { + return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs} + } + finalizeStream := func() (*openaiStreamingResult, error) { + if !clientDisconnected { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during final flush, returning collected usage") + } + } + return resultWithUsage(), nil + } + handleScanErr := func(scanErr error) (*openaiStreamingResult, error, bool) { + if scanErr == nil { + return nil, nil, false + } + // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 + // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 + if errors.Is(scanErr, context.Canceled) || errors.Is(scanErr, context.DeadlineExceeded) { + logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") + return resultWithUsage(), nil, true + } + // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage + if clientDisconnected { + logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", 
scanErr) + return resultWithUsage(), nil, true + } + if errors.Is(scanErr, bufio.ErrTooLong) { + logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, scanErr) + sendErrorEvent("response_too_large") + return resultWithUsage(), scanErr, true + } + sendErrorEvent("stream_read_error") + return resultWithUsage(), fmt.Errorf("stream read error: %w", scanErr), true + } + processSSELine := func(line string, queueDrained bool) { + lastDataAt = time.Now() + + // Extract data from SSE line (supports both "data: " and "data:" formats) + if data, ok := extractOpenAISSEDataLine(line); ok { + + // Replace model in response if needed. + // Fast path: most events do not contain model field values. + if needModelReplace && mappedModel != "" && strings.Contains(data, mappedModel) { + line = s.replaceModelInSSELine(line, mappedModel, originalModel) + } + + dataBytes := []byte(data) + + // Correct Codex tool calls if needed (apply_patch -> edit, etc.) 
+ if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEBytes(dataBytes); corrected { + dataBytes = correctedData + data = string(correctedData) + line = "data: " + data + } + + // 写入客户端(客户端断开后继续 drain 上游) + if !clientDisconnected { + shouldFlush := queueDrained + if firstTokenMs == nil && data != "" && data != "[DONE]" { + // 保证首个 token 事件尽快出站,避免影响 TTFT。 + shouldFlush = true + } + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if shouldFlush { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming flush, continuing to drain upstream for billing") + } + } + } + + // Record first token time + if firstTokenMs == nil && data != "" && data != "[DONE]" { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + s.parseSSEUsageBytes(dataBytes, usage) + return + } + + // Forward non-data lines as-is + if !clientDisconnected { + if _, err := bufferedWriter.WriteString(line); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if _, err := bufferedWriter.WriteString("\n"); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") + } else if queueDrained { + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during 
streaming flush, continuing to drain upstream for billing") + } + } + } + } + + // 无超时/无 keepalive 的常见路径走同步扫描,减少 goroutine 与 channel 开销。 + if streamInterval <= 0 && keepaliveInterval <= 0 { + defer putSSEScannerBuf64K(scanBuf) + for scanner.Scan() { + processSSELine(scanner.Text(), true) + } + if result, err, done := handleScanErr(scanner.Err()); done { + return result, err + } + return finalizeStream() + } + + type scanEvent struct { + line string + err error + } + // 独立 goroutine 读取上游,避免读取阻塞影响 keepalive/超时处理 + events := make(chan scanEvent, 16) + done := make(chan struct{}) + sendEvent := func(ev scanEvent) bool { + select { + case events <- ev: + return true + case <-done: + return false + } + } + var lastReadAt int64 + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + go func(scanBuf *sseScannerBuf64K) { + defer putSSEScannerBuf64K(scanBuf) + defer close(events) + for scanner.Scan() { + atomic.StoreInt64(&lastReadAt, time.Now().UnixNano()) + if !sendEvent(scanEvent{line: scanner.Text()}) { + return + } + } + if err := scanner.Err(); err != nil { + _ = sendEvent(scanEvent{err: err}) + } + }(scanBuf) + defer close(done) for { select { case ev, ok := <-events: if !ok { - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return finalizeStream() } - if ev.err != nil { - // 客户端断开/取消请求时,上游读取往往会返回 context canceled。 - // /v1/responses 的 SSE 事件必须符合 OpenAI 协议;这里不注入自定义 error event,避免下游 SDK 解析失败。 - if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) { - logger.LegacyPrintf("service.openai_gateway", "Context canceled during streaming, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil - } - // 客户端已断开时,上游出错仅影响体验,不影响计费;返回已收集 usage - if clientDisconnected { - logger.LegacyPrintf("service.openai_gateway", "Upstream read error after client disconnect: %v, returning collected usage", ev.err) - return &openaiStreamingResult{usage: usage, firstTokenMs: 
firstTokenMs}, nil - } - if errors.Is(ev.err, bufio.ErrTooLong) { - logger.LegacyPrintf("service.openai_gateway", "SSE line too long: account=%d max_size=%d error=%v", account.ID, maxLineSize, ev.err) - sendErrorEvent("response_too_large") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, ev.err - } - sendErrorEvent("stream_read_error") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream read error: %w", ev.err) - } - - line := ev.line - lastDataAt = time.Now() - - // Extract data from SSE line (supports both "data: " and "data:" formats) - if data, ok := extractOpenAISSEDataLine(line); ok { - - // Replace model in response if needed - if needModelReplace { - line = s.replaceModelInSSELine(line, mappedModel, originalModel) - } - - // Correct Codex tool calls if needed (apply_patch -> edit, etc.) - if correctedData, corrected := s.toolCorrector.CorrectToolCallsInSSEData(data); corrected { - data = correctedData - line = "data: " + correctedData - } - - // 写入客户端(客户端断开后继续 drain 上游) - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } - - // Record first token time - if firstTokenMs == nil && data != "" && data != "[DONE]" { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - s.parseSSEUsage(data, usage) - } else { - // Forward non-data lines as-is - if !clientDisconnected { - if _, err := fmt.Fprintf(w, "%s\n", line); err != nil { - clientDisconnected = true - logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") - } else { - flusher.Flush() - } - } + if result, err, done := handleScanErr(ev.err); done { + return result, err } + processSSELine(ev.line, len(events) == 0) case 
<-intervalCh: lastRead := time.Unix(0, atomic.LoadInt64(&lastReadAt)) @@ -2275,7 +3052,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } if clientDisconnected { logger.LegacyPrintf("service.openai_gateway", "Upstream timeout after client disconnect, returning collected usage") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, nil + return resultWithUsage(), nil } logger.LegacyPrintf("service.openai_gateway", "Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval) // 处理流超时,可能标记账户为临时不可调度或错误状态 @@ -2283,7 +3060,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel) } sendErrorEvent("stream_timeout") - return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout") + return resultWithUsage(), fmt.Errorf("stream data interval timeout") case <-keepaliveCh: if clientDisconnected { @@ -2292,12 +3069,15 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp if time.Since(lastDataAt) < keepaliveInterval { continue } - if _, err := fmt.Fprint(w, ":\n\n"); err != nil { + if _, err := bufferedWriter.WriteString(":\n\n"); err != nil { clientDisconnected = true logger.LegacyPrintf("service.openai_gateway", "Client disconnected during streaming, continuing to drain upstream for billing") continue } - flusher.Flush() + if err := flushBuffered(); err != nil { + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "Client disconnected during keepalive flush, continuing to drain upstream for billing") + } } } @@ -2355,29 +3135,49 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt return body } - bodyStr := string(body) - corrected, changed := s.toolCorrector.CorrectToolCallsInSSEData(bodyStr) + corrected, changed := 
s.toolCorrector.CorrectToolCallsInSSEBytes(body) if changed { - return []byte(corrected) + return corrected } return body } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { - if usage == nil || data == "" || data == "[DONE]" { + s.parseSSEUsageBytes([]byte(data), usage) +} + +func (s *OpenAIGatewayService) parseSSEUsageBytes(data []byte, usage *OpenAIUsage) { + if usage == nil || len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { return } // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 - if !strings.Contains(data, `"response.completed"`) { + if len(data) < 80 || !bytes.Contains(data, []byte(`"response.completed"`)) { return } - if gjson.Get(data, "type").String() != "response.completed" { + if gjson.GetBytes(data, "type").String() != "response.completed" { return } - usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int()) - usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int()) - usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int()) + usage.InputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.GetBytes(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.GetBytes(data, "response.usage.input_tokens_details.cached_tokens").Int()) +} + +func extractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + if len(body) == 0 || !gjson.ValidBytes(body) { + return OpenAIUsage{}, false + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + return OpenAIUsage{ + InputTokens: int(values[0].Int()), + OutputTokens: int(values[1].Int()), + CacheReadInputTokens: int(values[2].Int()), + }, true } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) 
(*OpenAIUsage, error) { @@ -2403,32 +3203,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, r } } - // Parse usage - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(body, &response); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - usage := &OpenAIUsage{ - InputTokens: response.Usage.InputTokens, - OutputTokens: response.Usage.OutputTokens, - CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + usageValue, usageOK := extractOpenAIUsageFromJSONBytes(body) + if !usageOK { + return nil, fmt.Errorf("parse response: invalid json response") } + usage := &usageValue // Replace model in response if needed if originalModel != mappedModel { body = s.replaceModelInResponseBody(body, mappedModel, originalModel) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json" if s.cfg != nil && !s.cfg.Security.ResponseHeaders.Enabled { @@ -2453,19 +3239,8 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. 
usage := &OpenAIUsage{} if ok { - var response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } - if err := json.Unmarshal(finalResponse, &response); err == nil { - usage.InputTokens = response.Usage.InputTokens - usage.OutputTokens = response.Usage.OutputTokens - usage.CacheReadInputTokens = response.Usage.InputTokenDetails.CachedTokens + if parsedUsage, parsed := extractOpenAIUsageFromJSONBytes(finalResponse); parsed { + *usage = parsedUsage } body = finalResponse if originalModel != mappedModel { @@ -2481,7 +3256,7 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. body = []byte(bodyText) } - responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders) + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) contentType := "application/json; charset=utf-8" if !ok { @@ -2505,16 +3280,10 @@ func extractCodexFinalResponse(body string) ([]byte, bool) { if data == "" || data == "[DONE]" { continue } - var event struct { - Type string `json:"type"` - Response json.RawMessage `json:"response"` - } - if json.Unmarshal([]byte(data), &event) != nil { - continue - } - if event.Type == "response.done" || event.Type == "response.completed" { - if len(event.Response) > 0 { - return event.Response, true + eventType := gjson.Get(data, "type").String() + if eventType == "response.done" || eventType == "response.completed" { + if response := gjson.Get(data, "response"); response.Exists() && response.Type == gjson.JSON && response.Raw != "" { + return []byte(response.Raw), true } } } @@ -2532,7 +3301,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { if data == "" || data == "[DONE]" { continue } - s.parseSSEUsage(data, usage) + 
s.parseSSEUsageBytes([]byte(data), usage) } return usage } @@ -2671,6 +3440,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec AccountRateMultiplier: &accountRateMultiplier, BillingType: billingType, Stream: result.Stream, + OpenAIWSMode: result.OpenAIWSMode, DurationMs: &durationMs, FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), @@ -3047,6 +3817,9 @@ func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error if err := json.Unmarshal(body, &reqBody); err != nil { return nil, fmt.Errorf("parse request: %w", err) } + if c != nil { + c.Set(OpenAIParsedRequestBodyKey, reqBody) + } return reqBody, nil } diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go index 6b11831f..f73c06c5 100644 --- a/backend/internal/service/openai_gateway_service_hotpath_test.go +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -123,3 +123,19 @@ func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "parse request") } + +func TestGetOpenAIRequestBodyMap_WriteBackContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{"model":"gpt-5","stream":true}`)) + require.NoError(t, err) + require.Equal(t, "gpt-5", got["model"]) + + cached, ok := c.Get(OpenAIParsedRequestBodyKey) + require.True(t, ok) + cachedMap, ok := cached.(map[string]any) + require.True(t, ok) + require.Equal(t, got, cachedMap) +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 226648e4..89443b69 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -5,6 +5,7 @@ import ( "bytes" "context" "errors" + 
"fmt" "io" "net/http" "net/http/httptest" @@ -13,6 +14,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/cespare/xxhash/v2" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" ) @@ -166,6 +168,54 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { } } +func TestOpenAIGatewayService_GenerateSessionHash_UsesXXHash64(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-fixed-value") + svc := &OpenAIGatewayService{} + + got := svc.GenerateSessionHash(c, nil) + want := fmt.Sprintf("%016x", xxhash.Sum64String("sess-fixed-value")) + require.Equal(t, want, got) +} + +func TestOpenAIGatewayService_GenerateSessionHash_AttachesLegacyHashToContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + c.Request.Header.Set("session_id", "sess-legacy-check") + svc := &OpenAIGatewayService{} + + sessionHash := svc.GenerateSessionHash(c, nil) + require.NotEmpty(t, sessionHash) + require.NotNil(t, c.Request) + require.NotNil(t, c.Request.Context()) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) +} + +func TestOpenAIGatewayService_GenerateSessionHashWithFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + svc := &OpenAIGatewayService{} + seed := "openai_ws_ingress:9:100:200" + + got := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), seed) + want := fmt.Sprintf("%016x", xxhash.Sum64String(seed)) + require.Equal(t, want, got) + require.NotEmpty(t, openAILegacySessionHashFromContext(c.Request.Context())) + + 
empty := svc.GenerateSessionHashWithFallback(c, []byte(`{}`), " ") + require.Equal(t, "", empty) +} + func (c stubConcurrencyCache) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { if c.waitCounts != nil { if count, ok := c.waitCounts[accountID]; ok { diff --git a/backend/internal/service/openai_json_optimization_benchmark_test.go b/backend/internal/service/openai_json_optimization_benchmark_test.go new file mode 100644 index 00000000..1737804b --- /dev/null +++ b/backend/internal/service/openai_json_optimization_benchmark_test.go @@ -0,0 +1,357 @@ +package service + +import ( + "encoding/json" + "strconv" + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +var ( + benchmarkToolContinuationBoolSink bool + benchmarkWSParseStringSink string + benchmarkWSParseMapSink map[string]any + benchmarkUsageSink OpenAIUsage +) + +func BenchmarkToolContinuationValidationLegacy(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = legacyValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkToolContinuationValidationOptimized(b *testing.B) { + reqBody := benchmarkToolContinuationRequestBody() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkToolContinuationBoolSink = optimizedValidateFunctionCallOutputContext(reqBody) + } +} + +func BenchmarkWSIngressPayloadParseLegacy(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := legacyParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkWSIngressPayloadParseOptimized(b *testing.B) { + raw := benchmarkWSIngressPayloadBytes() + + b.ReportAllocs() + b.ResetTimer() + 
for i := 0; i < b.N; i++ { + eventType, model, promptCacheKey, previousResponseID, payload, err := optimizedParseWSIngressPayload(raw) + if err == nil { + benchmarkWSParseStringSink = eventType + model + promptCacheKey + previousResponseID + benchmarkWSParseMapSink = payload + } + } +} + +func BenchmarkOpenAIUsageExtractLegacy(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := legacyExtractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func BenchmarkOpenAIUsageExtractOptimized(b *testing.B) { + body := benchmarkOpenAIUsageJSONBytes() + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + usage, ok := extractOpenAIUsageFromJSONBytes(body) + if ok { + benchmarkUsageSink = usage + } + } +} + +func benchmarkToolContinuationRequestBody() map[string]any { + input := make([]any, 0, 64) + for i := 0; i < 24; i++ { + input = append(input, map[string]any{ + "type": "text", + "text": "benchmark text", + }) + } + for i := 0; i < 10; i++ { + callID := "call_" + strconv.Itoa(i) + input = append(input, map[string]any{ + "type": "tool_call", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "function_call_output", + "call_id": callID, + }) + input = append(input, map[string]any{ + "type": "item_reference", + "id": callID, + }) + } + return map[string]any{ + "model": "gpt-5.3-codex", + "input": input, + } +} + +func benchmarkWSIngressPayloadBytes() []byte { + return []byte(`{"type":"response.create","model":"gpt-5.3-codex","prompt_cache_key":"cache_bench","previous_response_id":"resp_prev_bench","input":[{"type":"message","role":"user","content":[{"type":"input_text","text":"hello"}]}]}`) +} + +func benchmarkOpenAIUsageJSONBytes() []byte { + return []byte(`{"id":"resp_bench","object":"response","model":"gpt-5.3-codex","usage":{"input_tokens":3210,"output_tokens":987,"input_tokens_details":{"cached_tokens":456}}}`) +} + 
+func legacyValidateFunctionCallOutputContext(reqBody map[string]any) bool { + if !legacyHasFunctionCallOutput(reqBody) { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if legacyHasToolCallContext(reqBody) { + return true + } + if legacyHasFunctionCallOutputMissingCallID(reqBody) { + return false + } + callIDs := legacyFunctionCallOutputCallIDs(reqBody) + return legacyHasItemReferenceForCallIDs(reqBody, callIDs) +} + +func optimizedValidateFunctionCallOutputContext(reqBody map[string]any) bool { + validation := ValidateFunctionCallOutputContext(reqBody) + if !validation.HasFunctionCallOutput { + return true + } + previousResponseID, _ := reqBody["previous_response_id"].(string) + if strings.TrimSpace(previousResponseID) != "" { + return true + } + if validation.HasToolCallContext { + return true + } + if validation.HasFunctionCallOutputMissingCallID { + return false + } + return validation.HasItemReferenceForAllCallIDs +} + +func legacyHasFunctionCallOutput(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "function_call_output" { + return true + } + } + return false +} + +func legacyHasToolCallContext(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "tool_call" && itemType != "function_call" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + return true + } + } + return false +} + +func 
legacyFunctionCallOutputCallIDs(reqBody map[string]any) []string { + if reqBody == nil { + return nil + } + input, ok := reqBody["input"].([]any) + if !ok { + return nil + } + ids := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { + ids[callID] = struct{}{} + } + } + if len(ids) == 0 { + return nil + } + callIDs := make([]string, 0, len(ids)) + for id := range ids { + callIDs = append(callIDs, id) + } + return callIDs +} + +func legacyHasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { + if reqBody == nil { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "function_call_output" { + continue + } + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) == "" { + return true + } + } + return false +} + +func legacyHasItemReferenceForCallIDs(reqBody map[string]any, callIDs []string) bool { + if reqBody == nil || len(callIDs) == 0 { + return false + } + input, ok := reqBody["input"].([]any) + if !ok { + return false + } + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType != "item_reference" { + continue + } + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + if len(referenceIDs) == 0 { + return false + } + for _, callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + return false + } + } + return true +} + +func 
legacyParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + values := gjson.GetManyBytes(raw, "type", "model", "prompt_cache_key", "previous_response_id") + eventType = strings.TrimSpace(values[0].String()) + if eventType == "" { + eventType = "response.create" + } + model = strings.TrimSpace(values[1].String()) + promptCacheKey = strings.TrimSpace(values[2].String()) + previousResponseID = strings.TrimSpace(values[3].String()) + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + if _, exists := payload["type"]; !exists { + payload["type"] = "response.create" + } + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func optimizedParseWSIngressPayload(raw []byte) (eventType, model, promptCacheKey, previousResponseID string, payload map[string]any, err error) { + payload = make(map[string]any) + if err = json.Unmarshal(raw, &payload); err != nil { + return "", "", "", "", nil, err + } + eventType = openAIWSPayloadString(payload, "type") + if eventType == "" { + eventType = "response.create" + payload["type"] = eventType + } + model = openAIWSPayloadString(payload, "model") + promptCacheKey = openAIWSPayloadString(payload, "prompt_cache_key") + previousResponseID = openAIWSPayloadString(payload, "previous_response_id") + return eventType, model, promptCacheKey, previousResponseID, payload, nil +} + +func legacyExtractOpenAIUsageFromJSONBytes(body []byte) (OpenAIUsage, bool) { + var response struct { + Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + InputTokenDetails struct { + CachedTokens int `json:"cached_tokens"` + } `json:"input_tokens_details"` + } `json:"usage"` + } + if err := json.Unmarshal(body, &response); err != nil { + return OpenAIUsage{}, false + } + return OpenAIUsage{ + InputTokens: response.Usage.InputTokens, + 
OutputTokens: response.Usage.OutputTokens, + CacheReadInputTokens: response.Usage.InputTokenDetails.CachedTokens, + }, true +} diff --git a/backend/internal/service/openai_oauth_passthrough_test.go b/backend/internal/service/openai_oauth_passthrough_test.go index 7a996c26..0840d3b1 100644 --- a/backend/internal/service/openai_oauth_passthrough_test.go +++ b/backend/internal/service/openai_oauth_passthrough_test.go @@ -515,7 +515,7 @@ func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *te require.NoError(t, err) require.Equal(t, false, gjson.GetBytes(upstream.lastBody, "store").Bool()) require.Equal(t, true, gjson.GetBytes(upstream.lastBody, "stream").Bool()) - require.Equal(t, "codex_cli_rs/0.98.0", upstream.lastReq.Header.Get("User-Agent")) + require.Equal(t, "codex_cli_rs/0.104.0", upstream.lastReq.Header.Get("User-Agent")) } func TestOpenAIGatewayService_CodexCLIOnly_RejectsNonCodexClient(t *testing.T) { diff --git a/backend/internal/service/openai_oauth_service.go b/backend/internal/service/openai_oauth_service.go index 087ad4ec..07cb5472 100644 --- a/backend/internal/service/openai_oauth_service.go +++ b/backend/internal/service/openai_oauth_service.go @@ -5,8 +5,12 @@ import ( "crypto/subtle" "encoding/json" "io" + "log/slog" "net/http" "net/url" + "regexp" + "sort" + "strconv" "strings" "time" @@ -16,6 +20,13 @@ import ( var openAISoraSessionAuthURL = "https://sora.chatgpt.com/api/auth/session" +var soraSessionCookiePattern = regexp.MustCompile(`(?i)(?:^|[\n\r;])\s*(?:(?:set-cookie|cookie)\s*:\s*)?__Secure-(?:next-auth|authjs)\.session-token(?:\.(\d+))?=([^;\r\n]+)`) + +type soraSessionChunk struct { + index int + value string +} + // OpenAIOAuthService handles OpenAI OAuth authentication flows type OpenAIOAuthService struct { sessionStore *openai.SessionStore @@ -39,7 +50,7 @@ type OpenAIAuthURLResult struct { } // GenerateAuthURL generates an OpenAI OAuth authorization URL -func (s *OpenAIOAuthService) GenerateAuthURL(ctx 
context.Context, proxyID *int64, redirectURI string) (*OpenAIAuthURLResult, error) { +func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, platform string) (*OpenAIAuthURLResult, error) { // Generate PKCE values state, err := openai.GenerateState() if err != nil { @@ -75,11 +86,14 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 if redirectURI == "" { redirectURI = openai.DefaultRedirectURI } + normalizedPlatform := normalizeOpenAIOAuthPlatform(platform) + clientID, _ := openai.OAuthClientConfigByPlatform(normalizedPlatform) // Store session session := &openai.OAuthSession{ State: state, CodeVerifier: codeVerifier, + ClientID: clientID, RedirectURI: redirectURI, ProxyURL: proxyURL, CreatedAt: time.Now(), @@ -87,7 +101,7 @@ func (s *OpenAIOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64 s.sessionStore.Set(sessionID, session) // Build authorization URL - authURL := openai.BuildAuthorizationURL(state, codeChallenge, redirectURI) + authURL := openai.BuildAuthorizationURLForPlatform(state, codeChallenge, redirectURI, normalizedPlatform) return &OpenAIAuthURLResult{ AuthURL: authURL, @@ -111,6 +125,7 @@ type OpenAITokenInfo struct { IDToken string `json:"id_token,omitempty"` ExpiresIn int64 `json:"expires_in"` ExpiresAt int64 `json:"expires_at"` + ClientID string `json:"client_id,omitempty"` Email string `json:"email,omitempty"` ChatGPTAccountID string `json:"chatgpt_account_id,omitempty"` ChatGPTUserID string `json:"chatgpt_user_id,omitempty"` @@ -148,9 +163,13 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch if input.RedirectURI != "" { redirectURI = input.RedirectURI } + clientID := strings.TrimSpace(session.ClientID) + if clientID == "" { + clientID = openai.ClientID + } // Exchange code for token - tokenResp, err := s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL) + tokenResp, err := 
s.oauthClient.ExchangeCode(ctx, input.Code, session.CodeVerifier, redirectURI, proxyURL, clientID) if err != nil { return nil, err } @@ -158,8 +177,10 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -173,6 +194,7 @@ func (s *OpenAIOAuthService) ExchangeCode(ctx context.Context, input *OpenAIExch IDToken: tokenResp.IDToken, ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), + ClientID: clientID, } if userInfo != nil { @@ -200,8 +222,10 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre // Parse ID token to get user info var userInfo *openai.UserInfo if tokenResp.IDToken != "" { - claims, err := openai.ParseIDToken(tokenResp.IDToken) - if err == nil { + claims, parseErr := openai.ParseIDToken(tokenResp.IDToken) + if parseErr != nil { + slog.Warn("openai_oauth_id_token_parse_failed", "error", parseErr) + } else { userInfo = claims.GetUserInfo() } } @@ -213,6 +237,9 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre ExpiresIn: int64(tokenResp.ExpiresIn), ExpiresAt: time.Now().Unix() + int64(tokenResp.ExpiresIn), } + if trimmed := strings.TrimSpace(clientID); trimmed != "" { + tokenInfo.ClientID = trimmed + } if userInfo != nil { tokenInfo.Email = userInfo.Email @@ -226,6 +253,7 @@ func (s *OpenAIOAuthService) RefreshTokenWithClientID(ctx context.Context, refre // ExchangeSoraSessionToken exchanges Sora session_token to access_token. 
func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessionToken string, proxyID *int64) (*OpenAITokenInfo, error) { + sessionToken = normalizeSoraSessionTokenInput(sessionToken) if strings.TrimSpace(sessionToken) == "" { return nil, infraerrors.New(http.StatusBadRequest, "SORA_SESSION_TOKEN_REQUIRED", "session_token is required") } @@ -287,10 +315,141 @@ func (s *OpenAIOAuthService) ExchangeSoraSessionToken(ctx context.Context, sessi AccessToken: strings.TrimSpace(sessionResp.AccessToken), ExpiresIn: expiresIn, ExpiresAt: expiresAt, + ClientID: openai.SoraClientID, Email: strings.TrimSpace(sessionResp.User.Email), }, nil } +func normalizeSoraSessionTokenInput(raw string) string { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "" + } + + matches := soraSessionCookiePattern.FindAllStringSubmatch(trimmed, -1) + if len(matches) == 0 { + return sanitizeSessionToken(trimmed) + } + + chunkMatches := make([]soraSessionChunk, 0, len(matches)) + singleValues := make([]string, 0, len(matches)) + + for _, match := range matches { + if len(match) < 3 { + continue + } + + value := sanitizeSessionToken(match[2]) + if value == "" { + continue + } + + if strings.TrimSpace(match[1]) == "" { + singleValues = append(singleValues, value) + continue + } + + idx, err := strconv.Atoi(strings.TrimSpace(match[1])) + if err != nil || idx < 0 { + continue + } + chunkMatches = append(chunkMatches, soraSessionChunk{ + index: idx, + value: value, + }) + } + + if merged := mergeLatestSoraSessionChunks(chunkMatches); merged != "" { + return merged + } + + if len(singleValues) > 0 { + return singleValues[len(singleValues)-1] + } + + return "" +} + +func mergeSoraSessionChunkSegment(chunks []soraSessionChunk, requiredMaxIndex int, requireComplete bool) string { + if len(chunks) == 0 { + return "" + } + + byIndex := make(map[int]string, len(chunks)) + for _, chunk := range chunks { + byIndex[chunk.index] = chunk.value + } + + if _, ok := byIndex[0]; !ok { + 
return "" + } + if requireComplete { + for idx := 0; idx <= requiredMaxIndex; idx++ { + if _, ok := byIndex[idx]; !ok { + return "" + } + } + } + + orderedIndexes := make([]int, 0, len(byIndex)) + for idx := range byIndex { + orderedIndexes = append(orderedIndexes, idx) + } + sort.Ints(orderedIndexes) + + var builder strings.Builder + for _, idx := range orderedIndexes { + if _, err := builder.WriteString(byIndex[idx]); err != nil { + return "" + } + } + return sanitizeSessionToken(builder.String()) +} + +func mergeLatestSoraSessionChunks(chunks []soraSessionChunk) string { + if len(chunks) == 0 { + return "" + } + + requiredMaxIndex := 0 + for _, chunk := range chunks { + if chunk.index > requiredMaxIndex { + requiredMaxIndex = chunk.index + } + } + + groupStarts := make([]int, 0, len(chunks)) + for idx, chunk := range chunks { + if chunk.index == 0 { + groupStarts = append(groupStarts, idx) + } + } + + if len(groupStarts) == 0 { + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) + } + + for i := len(groupStarts) - 1; i >= 0; i-- { + start := groupStarts[i] + end := len(chunks) + if i+1 < len(groupStarts) { + end = groupStarts[i+1] + } + if merged := mergeSoraSessionChunkSegment(chunks[start:end], requiredMaxIndex, true); merged != "" { + return merged + } + } + + return mergeSoraSessionChunkSegment(chunks, requiredMaxIndex, false) +} + +func sanitizeSessionToken(raw string) string { + token := strings.TrimSpace(raw) + token = strings.Trim(token, "\"'`") + token = strings.TrimSuffix(token, ";") + return strings.TrimSpace(token) +} + // RefreshAccountToken refreshes token for an OpenAI/Sora OAuth account func (s *OpenAIOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*OpenAITokenInfo, error) { if account.Platform != PlatformOpenAI && account.Platform != PlatformSora { @@ -322,9 +481,12 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) expiresAt := time.Unix(tokenInfo.ExpiresAt, 
0).Format(time.RFC3339) creds := map[string]any{ - "access_token": tokenInfo.AccessToken, - "refresh_token": tokenInfo.RefreshToken, - "expires_at": expiresAt, + "access_token": tokenInfo.AccessToken, + "expires_at": expiresAt, + } + // 仅在刷新响应返回了新的 refresh_token 时才更新,防止用空值覆盖已有令牌 + if strings.TrimSpace(tokenInfo.RefreshToken) != "" { + creds["refresh_token"] = tokenInfo.RefreshToken } if tokenInfo.IDToken != "" { @@ -342,6 +504,9 @@ func (s *OpenAIOAuthService) BuildAccountCredentials(tokenInfo *OpenAITokenInfo) if tokenInfo.OrganizationID != "" { creds["organization_id"] = tokenInfo.OrganizationID } + if strings.TrimSpace(tokenInfo.ClientID) != "" { + creds["client_id"] = strings.TrimSpace(tokenInfo.ClientID) + } return creds } @@ -377,3 +542,12 @@ func newOpenAIOAuthHTTPClient(proxyURL string) *http.Client { Transport: transport, } } + +func normalizeOpenAIOAuthPlatform(platform string) string { + switch strings.ToLower(strings.TrimSpace(platform)) { + case PlatformSora: + return openai.OAuthPlatformSora + default: + return openai.OAuthPlatformOpenAI + } +} diff --git a/backend/internal/service/openai_oauth_service_auth_url_test.go b/backend/internal/service/openai_oauth_service_auth_url_test.go new file mode 100644 index 00000000..5f26903d --- /dev/null +++ b/backend/internal/service/openai_oauth_service_auth_url_test.go @@ -0,0 +1,67 @@ +package service + +import ( + "context" + "errors" + "net/url" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/stretchr/testify/require" +) + +type openaiOAuthClientAuthURLStub struct{} + +func (s *openaiOAuthClientAuthURLStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s *openaiOAuthClientAuthURLStub) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func (s 
*openaiOAuthClientAuthURLStub) RefreshTokenWithClientID(ctx context.Context, refreshToken, proxyURL string, clientID string) (*openai.TokenResponse, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIOAuthService_GenerateAuthURL_OpenAIKeepsCodexFlow(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformOpenAI) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Equal(t, "true", q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} + +// TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient 验证 Sora 平台复用 Codex CLI 的 +// client_id(支持 localhost redirect_uri),但不启用 codex_cli_simplified_flow。 +func TestOpenAIOAuthService_GenerateAuthURL_SoraUsesCodexClient(t *testing.T) { + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientAuthURLStub{}) + defer svc.Stop() + + result, err := svc.GenerateAuthURL(context.Background(), nil, "", PlatformSora) + require.NoError(t, err) + require.NotEmpty(t, result.AuthURL) + require.NotEmpty(t, result.SessionID) + + parsed, err := url.Parse(result.AuthURL) + require.NoError(t, err) + q := parsed.Query() + require.Equal(t, openai.ClientID, q.Get("client_id")) + require.Empty(t, q.Get("codex_cli_simplified_flow")) + + session, ok := svc.sessionStore.Get(result.SessionID) + require.True(t, ok) + require.Equal(t, openai.ClientID, session.ClientID) +} diff --git a/backend/internal/service/openai_oauth_service_sora_session_test.go b/backend/internal/service/openai_oauth_service_sora_session_test.go index fb76f6c1..08da8557 100644 --- 
a/backend/internal/service/openai_oauth_service_sora_session_test.go +++ b/backend/internal/service/openai_oauth_service_sora_session_test.go @@ -5,6 +5,7 @@ import ( "errors" "net/http" "net/http/httptest" + "strings" "testing" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" @@ -13,7 +14,7 @@ import ( type openaiOAuthClientNoopStub struct{} -func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { +func (s *openaiOAuthClientNoopStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { return nil, errors.New("not implemented") } @@ -67,3 +68,106 @@ func TestOpenAIOAuthService_ExchangeSoraSessionToken_MissingAccessToken(t *testi require.Error(t, err) require.Contains(t, err.Error(), "missing access token") } + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_AcceptsSetCookieLine(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=st-cookie-value") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := "__Secure-next-auth.session-token.0=st-cookie-value; Domain=.chatgpt.com; Path=/; HttpOnly; Secure; SameSite=Lax" + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func 
TestOpenAIOAuthService_ExchangeSoraSessionToken_MergesChunkedSetCookieLines(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=chunk-0chunk-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.1=chunk-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=chunk-0; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_PrefersLatestDuplicateChunks(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=new-0new-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "Set-Cookie: __Secure-next-auth.session-token.0=old-0; Path=/; HttpOnly", + "Set-Cookie: 
__Secure-next-auth.session-token.1=old-1; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.0=new-0; Path=/; HttpOnly", + "Set-Cookie: __Secure-next-auth.session-token.1=new-1; Path=/; HttpOnly", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} + +func TestOpenAIOAuthService_ExchangeSoraSessionToken_UsesLatestCompleteChunkGroup(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, http.MethodGet, r.Method) + require.Contains(t, r.Header.Get("Cookie"), "__Secure-next-auth.session-token=ok-0ok-1") + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"accessToken":"at-token","expires":"2099-01-01T00:00:00Z","user":{"email":"demo@example.com"}}`)) + })) + defer server.Close() + + origin := openAISoraSessionAuthURL + openAISoraSessionAuthURL = server.URL + defer func() { openAISoraSessionAuthURL = origin }() + + svc := NewOpenAIOAuthService(nil, &openaiOAuthClientNoopStub{}) + defer svc.Stop() + + raw := strings.Join([]string{ + "set-cookie", + "__Secure-next-auth.session-token.0=ok-0; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.1=ok-1; Domain=.chatgpt.com; Path=/", + "set-cookie", + "__Secure-next-auth.session-token.0=partial-0; Domain=.chatgpt.com; Path=/", + }, "\n") + info, err := svc.ExchangeSoraSessionToken(context.Background(), raw, nil) + require.NoError(t, err) + require.Equal(t, "at-token", info.AccessToken) +} diff --git a/backend/internal/service/openai_oauth_service_state_test.go b/backend/internal/service/openai_oauth_service_state_test.go index 0a2a195f..29252328 100644 --- a/backend/internal/service/openai_oauth_service_state_test.go +++ b/backend/internal/service/openai_oauth_service_state_test.go @@ -13,10 +13,12 @@ import ( type openaiOAuthClientStateStub struct { exchangeCalled 
int32 + lastClientID string } -func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL string) (*openai.TokenResponse, error) { +func (s *openaiOAuthClientStateStub) ExchangeCode(ctx context.Context, code, codeVerifier, redirectURI, proxyURL, clientID string) (*openai.TokenResponse, error) { atomic.AddInt32(&s.exchangeCalled, 1) + s.lastClientID = clientID return &openai.TokenResponse{ AccessToken: "at", RefreshToken: "rt", @@ -95,6 +97,8 @@ func TestOpenAIOAuthService_ExchangeCode_StateMatch(t *testing.T) { require.NoError(t, err) require.NotNil(t, info) require.Equal(t, "at", info.AccessToken) + require.Equal(t, openai.ClientID, info.ClientID) + require.Equal(t, openai.ClientID, client.lastClientID) require.Equal(t, int32(1), atomic.LoadInt32(&client.exchangeCalled)) _, ok := svc.sessionStore.Get("sid") diff --git a/backend/internal/service/openai_previous_response_id.go b/backend/internal/service/openai_previous_response_id.go new file mode 100644 index 00000000..95865086 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id.go @@ -0,0 +1,37 @@ +package service + +import ( + "regexp" + "strings" +) + +const ( + OpenAIPreviousResponseIDKindEmpty = "empty" + OpenAIPreviousResponseIDKindResponseID = "response_id" + OpenAIPreviousResponseIDKindMessageID = "message_id" + OpenAIPreviousResponseIDKindUnknown = "unknown" +) + +var ( + openAIResponseIDPattern = regexp.MustCompile(`^resp_[A-Za-z0-9_-]{1,256}$`) + openAIMessageIDPattern = regexp.MustCompile(`^(msg|message|item|chatcmpl)_[A-Za-z0-9_-]{1,256}$`) +) + +// ClassifyOpenAIPreviousResponseIDKind classifies previous_response_id to improve diagnostics. 
+func ClassifyOpenAIPreviousResponseIDKind(id string) string { + trimmed := strings.TrimSpace(id) + if trimmed == "" { + return OpenAIPreviousResponseIDKindEmpty + } + if openAIResponseIDPattern.MatchString(trimmed) { + return OpenAIPreviousResponseIDKindResponseID + } + if openAIMessageIDPattern.MatchString(strings.ToLower(trimmed)) { + return OpenAIPreviousResponseIDKindMessageID + } + return OpenAIPreviousResponseIDKindUnknown +} + +func IsOpenAIPreviousResponseIDLikelyMessageID(id string) bool { + return ClassifyOpenAIPreviousResponseIDKind(id) == OpenAIPreviousResponseIDKindMessageID +} diff --git a/backend/internal/service/openai_previous_response_id_test.go b/backend/internal/service/openai_previous_response_id_test.go new file mode 100644 index 00000000..7867b864 --- /dev/null +++ b/backend/internal/service/openai_previous_response_id_test.go @@ -0,0 +1,34 @@ +package service + +import "testing" + +func TestClassifyOpenAIPreviousResponseIDKind(t *testing.T) { + tests := []struct { + name string + id string + want string + }{ + {name: "empty", id: " ", want: OpenAIPreviousResponseIDKindEmpty}, + {name: "response_id", id: "resp_0906a621bc423a8d0169a108637ef88197b74b0e2f37ba358f", want: OpenAIPreviousResponseIDKindResponseID}, + {name: "message_id", id: "msg_123456", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "item_id", id: "item_abcdef", want: OpenAIPreviousResponseIDKindMessageID}, + {name: "unknown", id: "foo_123456", want: OpenAIPreviousResponseIDKindUnknown}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := ClassifyOpenAIPreviousResponseIDKind(tc.id); got != tc.want { + t.Fatalf("ClassifyOpenAIPreviousResponseIDKind(%q)=%q want=%q", tc.id, got, tc.want) + } + }) + } +} + +func TestIsOpenAIPreviousResponseIDLikelyMessageID(t *testing.T) { + if !IsOpenAIPreviousResponseIDLikelyMessageID("msg_123") { + t.Fatal("expected msg_123 to be identified as message id") + } + if 
IsOpenAIPreviousResponseIDLikelyMessageID("resp_123") { + t.Fatal("expected resp_123 not to be identified as message id") + } +} diff --git a/backend/internal/service/openai_sticky_compat.go b/backend/internal/service/openai_sticky_compat.go new file mode 100644 index 00000000..e897debc --- /dev/null +++ b/backend/internal/service/openai_sticky_compat.go @@ -0,0 +1,214 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync/atomic" + "time" + + "github.com/cespare/xxhash/v2" + "github.com/gin-gonic/gin" +) + +type openAILegacySessionHashContextKey struct{} + +var openAILegacySessionHashKey = openAILegacySessionHashContextKey{} + +var ( + openAIStickyLegacyReadFallbackTotal atomic.Int64 + openAIStickyLegacyReadFallbackHit atomic.Int64 + openAIStickyLegacyDualWriteTotal atomic.Int64 +) + +func openAIStickyCompatStats() (legacyReadFallbackTotal, legacyReadFallbackHit, legacyDualWriteTotal int64) { + return openAIStickyLegacyReadFallbackTotal.Load(), + openAIStickyLegacyReadFallbackHit.Load(), + openAIStickyLegacyDualWriteTotal.Load() +} + +func deriveOpenAISessionHashes(sessionID string) (currentHash string, legacyHash string) { + normalized := strings.TrimSpace(sessionID) + if normalized == "" { + return "", "" + } + + currentHash = fmt.Sprintf("%016x", xxhash.Sum64String(normalized)) + sum := sha256.Sum256([]byte(normalized)) + legacyHash = hex.EncodeToString(sum[:]) + return currentHash, legacyHash +} + +func withOpenAILegacySessionHash(ctx context.Context, legacyHash string) context.Context { + if ctx == nil { + return nil + } + trimmed := strings.TrimSpace(legacyHash) + if trimmed == "" { + return ctx + } + return context.WithValue(ctx, openAILegacySessionHashKey, trimmed) +} + +func openAILegacySessionHashFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + value, _ := ctx.Value(openAILegacySessionHashKey).(string) + return strings.TrimSpace(value) +} + +func 
attachOpenAILegacySessionHashToGin(c *gin.Context, legacyHash string) { + if c == nil || c.Request == nil { + return + } + c.Request = c.Request.WithContext(withOpenAILegacySessionHash(c.Request.Context(), legacyHash)) +} + +func (s *OpenAIGatewayService) openAISessionHashReadOldFallbackEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashReadOldFallback +} + +func (s *OpenAIGatewayService) openAISessionHashDualWriteOldEnabled() bool { + if s == nil || s.cfg == nil { + return true + } + return s.cfg.Gateway.OpenAIWS.SessionHashDualWriteOld +} + +func (s *OpenAIGatewayService) openAISessionCacheKey(sessionHash string) string { + normalized := strings.TrimSpace(sessionHash) + if normalized == "" { + return "" + } + return "openai:" + normalized +} + +func (s *OpenAIGatewayService) openAILegacySessionCacheKey(ctx context.Context, sessionHash string) string { + legacyHash := openAILegacySessionHashFromContext(ctx) + if legacyHash == "" { + return "" + } + legacyKey := "openai:" + legacyHash + if legacyKey == s.openAISessionCacheKey(sessionHash) { + return "" + } + return legacyKey +} + +func (s *OpenAIGatewayService) openAIStickyLegacyTTL(ttl time.Duration) time.Duration { + legacyTTL := ttl + if legacyTTL <= 0 { + legacyTTL = openaiStickySessionTTL + } + if legacyTTL > 10*time.Minute { + return 10 * time.Minute + } + return legacyTTL +} + +func (s *OpenAIGatewayService) getStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) (int64, error) { + if s == nil || s.cache == nil { + return 0, nil + } + + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return 0, nil + } + + accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if err == nil && accountID > 0 { + return accountID, nil + } + if !s.openAISessionHashReadOldFallbackEnabled() { + return accountID, err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + 
if legacyKey == "" { + return accountID, err + } + + openAIStickyLegacyReadFallbackTotal.Add(1) + legacyAccountID, legacyErr := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + if legacyErr == nil && legacyAccountID > 0 { + openAIStickyLegacyReadFallbackHit.Add(1) + return legacyAccountID, nil + } + return accountID, err +} + +func (s *OpenAIGatewayService) setStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string, accountID int64, ttl time.Duration) error { + if s == nil || s.cache == nil || accountID <= 0 { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), primaryKey, accountID, ttl); err != nil { + return err + } + + if !s.openAISessionHashDualWriteOldEnabled() { + return nil + } + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey == "" { + return nil + } + if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), legacyKey, accountID, s.openAIStickyLegacyTTL(ttl)); err != nil { + return err + } + openAIStickyLegacyDualWriteTotal.Add(1) + return nil +} + +func (s *OpenAIGatewayService) refreshStickySessionTTL(ctx context.Context, groupID *int64, sessionHash string, ttl time.Duration) error { + if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), primaryKey, ttl) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), legacyKey, s.openAIStickyLegacyTTL(ttl)) + } + return err +} + +func (s *OpenAIGatewayService) deleteStickySessionAccountID(ctx context.Context, groupID *int64, sessionHash string) error { 
+ if s == nil || s.cache == nil { + return nil + } + primaryKey := s.openAISessionCacheKey(sessionHash) + if primaryKey == "" { + return nil + } + + err := s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), primaryKey) + if !s.openAISessionHashReadOldFallbackEnabled() && !s.openAISessionHashDualWriteOldEnabled() { + return err + } + + legacyKey := s.openAILegacySessionCacheKey(ctx, sessionHash) + if legacyKey != "" { + _ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), legacyKey) + } + return err +} diff --git a/backend/internal/service/openai_sticky_compat_test.go b/backend/internal/service/openai_sticky_compat_test.go new file mode 100644 index 00000000..9f57c358 --- /dev/null +++ b/backend/internal/service/openai_sticky_compat_test.go @@ -0,0 +1,96 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestGetStickySessionAccountID_FallbackToLegacyKey(t *testing.T) { + beforeFallbackTotal, beforeFallbackHit, _ := openAIStickyCompatStats() + + cache := &stubGatewayCache{ + sessionBindings: map[string]int64{ + "openai:legacy-hash": 42, + }, + } + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashReadOldFallback: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + accountID, err := svc.getStickySessionAccountID(ctx, nil, "new-hash") + require.NoError(t, err) + require.Equal(t, int64(42), accountID) + + afterFallbackTotal, afterFallbackHit, _ := openAIStickyCompatStats() + require.Equal(t, beforeFallbackTotal+1, afterFallbackTotal) + require.Equal(t, beforeFallbackHit+1, afterFallbackHit) +} + +func TestSetStickySessionAccountID_DualWriteOldEnabled(t *testing.T) { + _, _, beforeDualWriteTotal := openAIStickyCompatStats() + + cache := 
&stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: true, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + require.Equal(t, int64(9), cache.sessionBindings["openai:legacy-hash"]) + + _, _, afterDualWriteTotal := openAIStickyCompatStats() + require.Equal(t, beforeDualWriteTotal+1, afterDualWriteTotal) +} + +func TestSetStickySessionAccountID_DualWriteOldDisabled(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + svc := &OpenAIGatewayService{ + cache: cache, + cfg: &config.Config{ + Gateway: config.GatewayConfig{ + OpenAIWS: config.GatewayOpenAIWSConfig{ + SessionHashDualWriteOld: false, + }, + }, + }, + } + + ctx := withOpenAILegacySessionHash(context.Background(), "legacy-hash") + err := svc.setStickySessionAccountID(ctx, nil, "new-hash", 9, openaiStickySessionTTL) + require.NoError(t, err) + require.Equal(t, int64(9), cache.sessionBindings["openai:new-hash"]) + _, exists := cache.sessionBindings["openai:legacy-hash"] + require.False(t, exists) +} + +func TestSnapshotOpenAICompatibilityFallbackMetrics(t *testing.T) { + before := SnapshotOpenAICompatibilityFallbackMetrics() + + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, true) + _, _ = ThinkingEnabledFromContext(ctx) + + after := SnapshotOpenAICompatibilityFallbackMetrics() + require.GreaterOrEqual(t, after.MetadataLegacyFallbackTotal, before.MetadataLegacyFallbackTotal+1) + require.GreaterOrEqual(t, after.MetadataLegacyFallbackThinkingEnabledTotal, before.MetadataLegacyFallbackThinkingEnabledTotal+1) +} diff --git 
a/backend/internal/service/openai_tool_continuation.go b/backend/internal/service/openai_tool_continuation.go index e59082b2..dea3c172 100644 --- a/backend/internal/service/openai_tool_continuation.go +++ b/backend/internal/service/openai_tool_continuation.go @@ -2,6 +2,24 @@ package service import "strings" +// ToolContinuationSignals 聚合工具续链相关信号,避免重复遍历 input。 +type ToolContinuationSignals struct { + HasFunctionCallOutput bool + HasFunctionCallOutputMissingCallID bool + HasToolCallContext bool + HasItemReference bool + HasItemReferenceForAllCallIDs bool + FunctionCallOutputCallIDs []string +} + +// FunctionCallOutputValidation 汇总 function_call_output 关联性校验结果。 +type FunctionCallOutputValidation struct { + HasFunctionCallOutput bool + HasToolCallContext bool + HasFunctionCallOutputMissingCallID bool + HasItemReferenceForAllCallIDs bool +} + // NeedsToolContinuation 判定请求是否需要工具调用续链处理。 // 满足以下任一信号即视为续链:previous_response_id、input 内包含 function_call_output/item_reference、 // 或显式声明 tools/tool_choice。 @@ -18,107 +36,191 @@ func NeedsToolContinuation(reqBody map[string]any) bool { if hasToolChoiceSignal(reqBody) { return true } - if inputHasType(reqBody, "function_call_output") { - return true + input, ok := reqBody["input"].([]any) + if !ok { + return false } - if inputHasType(reqBody, "item_reference") { - return true + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + if itemType == "function_call_output" || itemType == "item_reference" { + return true + } } return false } +// AnalyzeToolContinuationSignals 单次遍历 input,提取 function_call_output/tool_call/item_reference 相关信号。 +func AnalyzeToolContinuationSignals(reqBody map[string]any) ToolContinuationSignals { + signals := ToolContinuationSignals{} + if reqBody == nil { + return signals + } + input, ok := reqBody["input"].([]any) + if !ok { + return signals + } + + var callIDs map[string]struct{} + var referenceIDs map[string]struct{} + + 
for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + signals.HasToolCallContext = true + } + case "function_call_output": + signals.HasFunctionCallOutput = true + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + signals.HasFunctionCallOutputMissingCallID = true + continue + } + if callIDs == nil { + callIDs = make(map[string]struct{}) + } + callIDs[callID] = struct{}{} + case "item_reference": + signals.HasItemReference = true + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + if referenceIDs == nil { + referenceIDs = make(map[string]struct{}) + } + referenceIDs[idValue] = struct{}{} + } + } + + if len(callIDs) == 0 { + return signals + } + signals.FunctionCallOutputCallIDs = make([]string, 0, len(callIDs)) + allReferenced := len(referenceIDs) > 0 + for callID := range callIDs { + signals.FunctionCallOutputCallIDs = append(signals.FunctionCallOutputCallIDs, callID) + if allReferenced { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + } + } + } + signals.HasItemReferenceForAllCallIDs = allReferenced + return signals +} + +// ValidateFunctionCallOutputContext 为 handler 提供低开销校验结果: +// 1) 无 function_call_output 直接返回 +// 2) 若已存在 tool_call/function_call 上下文则提前返回 +// 3) 仅在无工具上下文时才构建 call_id / item_reference 集合 +func ValidateFunctionCallOutputContext(reqBody map[string]any) FunctionCallOutputValidation { + result := FunctionCallOutputValidation{} + if reqBody == nil { + return result + } + input, ok := reqBody["input"].([]any) + if !ok { + return result + } + + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case 
"function_call_output": + result.HasFunctionCallOutput = true + case "tool_call", "function_call": + callID, _ := itemMap["call_id"].(string) + if strings.TrimSpace(callID) != "" { + result.HasToolCallContext = true + } + } + if result.HasFunctionCallOutput && result.HasToolCallContext { + return result + } + } + + if !result.HasFunctionCallOutput || result.HasToolCallContext { + return result + } + + callIDs := make(map[string]struct{}) + referenceIDs := make(map[string]struct{}) + for _, item := range input { + itemMap, ok := item.(map[string]any) + if !ok { + continue + } + itemType, _ := itemMap["type"].(string) + switch itemType { + case "function_call_output": + callID, _ := itemMap["call_id"].(string) + callID = strings.TrimSpace(callID) + if callID == "" { + result.HasFunctionCallOutputMissingCallID = true + continue + } + callIDs[callID] = struct{}{} + case "item_reference": + idValue, _ := itemMap["id"].(string) + idValue = strings.TrimSpace(idValue) + if idValue == "" { + continue + } + referenceIDs[idValue] = struct{}{} + } + } + + if len(callIDs) == 0 || len(referenceIDs) == 0 { + return result + } + allReferenced := true + for callID := range callIDs { + if _, ok := referenceIDs[callID]; !ok { + allReferenced = false + break + } + } + result.HasItemReferenceForAllCallIDs = allReferenced + return result +} + // HasFunctionCallOutput 判断 input 是否包含 function_call_output,用于触发续链校验。 func HasFunctionCallOutput(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - return inputHasType(reqBody, "function_call_output") + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutput } // HasToolCallContext 判断 input 是否包含带 call_id 的 tool_call/function_call, // 用于判断 function_call_output 是否具备可关联的上下文。 func HasToolCallContext(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok 
{ - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "tool_call" && itemType != "function_call" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - return true - } - } - return false + return AnalyzeToolContinuationSignals(reqBody).HasToolCallContext } // FunctionCallOutputCallIDs 提取 input 中 function_call_output 的 call_id 集合。 // 仅返回非空 call_id,用于与 item_reference.id 做匹配校验。 func FunctionCallOutputCallIDs(reqBody map[string]any) []string { - if reqBody == nil { - return nil - } - input, ok := reqBody["input"].([]any) - if !ok { - return nil - } - ids := make(map[string]struct{}) - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - if callID, ok := itemMap["call_id"].(string); ok && strings.TrimSpace(callID) != "" { - ids[callID] = struct{}{} - } - } - if len(ids) == 0 { - return nil - } - result := make([]string, 0, len(ids)) - for id := range ids { - result = append(result, id) - } - return result + return AnalyzeToolContinuationSignals(reqBody).FunctionCallOutputCallIDs } // HasFunctionCallOutputMissingCallID 判断是否存在缺少 call_id 的 function_call_output。 func HasFunctionCallOutputMissingCallID(reqBody map[string]any) bool { - if reqBody == nil { - return false - } - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType != "function_call_output" { - continue - } - callID, _ := itemMap["call_id"].(string) - if strings.TrimSpace(callID) == "" { - return true - } - } - return false + return AnalyzeToolContinuationSignals(reqBody).HasFunctionCallOutputMissingCallID } // HasItemReferenceForCallIDs 判断 item_reference.id 是否覆盖所有 call_id。 @@ -152,32 +254,13 @@ func HasItemReferenceForCallIDs(reqBody 
map[string]any, callIDs []string) bool { return false } for _, callID := range callIDs { - if _, ok := referenceIDs[callID]; !ok { + if _, ok := referenceIDs[strings.TrimSpace(callID)]; !ok { return false } } return true } -// inputHasType 判断 input 中是否存在指定类型的 item。 -func inputHasType(reqBody map[string]any, want string) bool { - input, ok := reqBody["input"].([]any) - if !ok { - return false - } - for _, item := range input { - itemMap, ok := item.(map[string]any) - if !ok { - continue - } - itemType, _ := itemMap["type"].(string) - if itemType == want { - return true - } - } - return false -} - // hasNonEmptyString 判断字段是否为非空字符串。 func hasNonEmptyString(value any) bool { stringValue, ok := value.(string) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index deec80fa..348723a6 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -1,11 +1,15 @@ package service import ( - "encoding/json" + "bytes" "fmt" + "strconv" + "strings" "sync" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" ) // codexToolNameMapping 定义 Codex 原生工具名称到 OpenCode 工具名称的映射 @@ -62,169 +66,201 @@ func (c *CodexToolCorrector) CorrectToolCallsInSSEData(data string) (string, boo if data == "" || data == "\n" { return data, false } + correctedBytes, corrected := c.CorrectToolCallsInSSEBytes([]byte(data)) + if !corrected { + return data, false + } + return string(correctedBytes), true +} - // 尝试解析 JSON - var payload map[string]any - if err := json.Unmarshal([]byte(data), &payload); err != nil { - // 不是有效的 JSON,直接返回原数据 +// CorrectToolCallsInSSEBytes 修正 SSE JSON 数据中的工具调用(字节路径)。 +// 返回修正后的数据和是否进行了修正。 +func (c *CodexToolCorrector) CorrectToolCallsInSSEBytes(data []byte) ([]byte, bool) { + if len(bytes.TrimSpace(data)) == 0 { + return data, false + } + if !mayContainToolCallPayload(data) { + return data, false + } + if 
!gjson.ValidBytes(data) { + // 不是有效 JSON,直接返回原数据 return data, false } + updated := data corrected := false - - // 处理 tool_calls 数组 - if toolCalls, ok := payload["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { + collect := func(changed bool, next []byte) { + if changed { corrected = true + updated = next } } - // 处理 function_call 对象 - if functionCall, ok := payload["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctToolCallsArrayAtPath(updated, "tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "function_call"); changed { + collect(changed, next) + } + if next, changed := c.correctToolCallsArrayAtPath(updated, "delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, "delta.function_call"); changed { + collect(changed, next) } - // 处理 delta.tool_calls - if delta, ok := payload["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } + choicesCount := int(gjson.GetBytes(updated, "choices.#").Int()) + for i := 0; i < choicesCount; i++ { + prefix := "choices." 
+ strconv.Itoa(i) + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".message.tool_calls"); changed { + collect(changed, next) } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } + if next, changed := c.correctFunctionAtPath(updated, prefix+".message.function_call"); changed { + collect(changed, next) } - } - - // 处理 choices[].message.tool_calls 和 choices[].delta.tool_calls - if choices, ok := payload["choices"].([]any); ok { - for _, choice := range choices { - if choiceMap, ok := choice.(map[string]any); ok { - // 处理 message 中的工具调用 - if message, ok := choiceMap["message"].(map[string]any); ok { - if toolCalls, ok := message["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := message["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - // 处理 delta 中的工具调用 - if delta, ok := choiceMap["delta"].(map[string]any); ok { - if toolCalls, ok := delta["tool_calls"].([]any); ok { - if c.correctToolCallsArray(toolCalls) { - corrected = true - } - } - if functionCall, ok := delta["function_call"].(map[string]any); ok { - if c.correctFunctionCall(functionCall) { - corrected = true - } - } - } - } + if next, changed := c.correctToolCallsArrayAtPath(updated, prefix+".delta.tool_calls"); changed { + collect(changed, next) + } + if next, changed := c.correctFunctionAtPath(updated, prefix+".delta.function_call"); changed { + collect(changed, next) } } if !corrected { return data, false } + return updated, true +} - // 序列化回 JSON - correctedBytes, err := json.Marshal(payload) - if err != nil { - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Failed to marshal corrected data: %v", err) +func mayContainToolCallPayload(data []byte) bool { + // 快速路径:多数 token / 文本事件不包含工具字段,避免进入 JSON 解析热路径。 + return bytes.Contains(data, 
[]byte(`"tool_calls"`)) || + bytes.Contains(data, []byte(`"function_call"`)) || + bytes.Contains(data, []byte(`"function":{"name"`)) +} + +// correctToolCallsArrayAtPath 修正指定路径下 tool_calls 数组中的工具名称。 +func (c *CodexToolCorrector) correctToolCallsArrayAtPath(data []byte, toolCallsPath string) ([]byte, bool) { + count := int(gjson.GetBytes(data, toolCallsPath+".#").Int()) + if count <= 0 { return data, false } - - return string(correctedBytes), true -} - -// correctToolCallsArray 修正工具调用数组中的工具名称 -func (c *CodexToolCorrector) correctToolCallsArray(toolCalls []any) bool { + updated := data corrected := false - for _, toolCall := range toolCalls { - if toolCallMap, ok := toolCall.(map[string]any); ok { - if function, ok := toolCallMap["function"].(map[string]any); ok { - if c.correctFunctionCall(function) { - corrected = true - } - } + for i := 0; i < count; i++ { + functionPath := toolCallsPath + "." + strconv.Itoa(i) + ".function" + if next, changed := c.correctFunctionAtPath(updated, functionPath); changed { + updated = next + corrected = true } } - return corrected + return updated, corrected } -// correctFunctionCall 修正单个函数调用的工具名称和参数 -func (c *CodexToolCorrector) correctFunctionCall(functionCall map[string]any) bool { - name, ok := functionCall["name"].(string) - if !ok || name == "" { - return false +// correctFunctionAtPath 修正指定路径下单个函数调用的工具名称和参数。 +func (c *CodexToolCorrector) correctFunctionAtPath(data []byte, functionPath string) ([]byte, bool) { + namePath := functionPath + ".name" + nameResult := gjson.GetBytes(data, namePath) + if !nameResult.Exists() || nameResult.Type != gjson.String { + return data, false } - + name := strings.TrimSpace(nameResult.Str) + if name == "" { + return data, false + } + updated := data corrected := false // 查找并修正工具名称 if correctName, found := codexToolNameMapping[name]; found { - functionCall["name"] = correctName - c.recordCorrection(name, correctName) - corrected = true - name = correctName // 使用修正后的名称进行参数修正 + if next, err := 
sjson.SetBytes(updated, namePath, correctName); err == nil { + updated = next + c.recordCorrection(name, correctName) + corrected = true + name = correctName // 使用修正后的名称进行参数修正 + } } // 修正工具参数(基于工具名称) - if c.correctToolParameters(name, functionCall) { + if next, changed := c.correctToolParametersAtPath(updated, functionPath+".arguments", name); changed { + updated = next corrected = true } - - return corrected + return updated, corrected } -// correctToolParameters 修正工具参数以符合 OpenCode 规范 -func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall map[string]any) bool { - arguments, ok := functionCall["arguments"] - if !ok { - return false +// correctToolParametersAtPath 修正指定路径下 arguments 参数。 +func (c *CodexToolCorrector) correctToolParametersAtPath(data []byte, argumentsPath, toolName string) ([]byte, bool) { + if toolName != "bash" && toolName != "edit" { + return data, false } - // arguments 可能是字符串(JSON)或已解析的 map - var argsMap map[string]any - switch v := arguments.(type) { - case string: - // 解析 JSON 字符串 - if err := json.Unmarshal([]byte(v), &argsMap); err != nil { - return false + args := gjson.GetBytes(data, argumentsPath) + if !args.Exists() { + return data, false + } + + switch args.Type { + case gjson.String: + argsJSON := strings.TrimSpace(args.Str) + if !gjson.Valid(argsJSON) { + return data, false } - case map[string]any: - argsMap = v + if !gjson.Parse(argsJSON).IsObject() { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(argsJSON, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetBytes(data, argumentsPath, nextArgsJSON) + if err != nil { + return data, false + } + return next, true + case gjson.JSON: + if !args.IsObject() || !gjson.Valid(args.Raw) { + return data, false + } + nextArgsJSON, corrected := c.correctToolArgumentsJSON(args.Raw, toolName) + if !corrected { + return data, false + } + next, err := sjson.SetRawBytes(data, argumentsPath, []byte(nextArgsJSON)) + if err 
!= nil { + return data, false + } + return next, true default: - return false + return data, false + } +} + +// correctToolArgumentsJSON 修正工具参数 JSON(对象字符串),返回修正后的 JSON 与是否变更。 +func (c *CodexToolCorrector) correctToolArgumentsJSON(argsJSON, toolName string) (string, bool) { + if !gjson.Valid(argsJSON) { + return argsJSON, false + } + if !gjson.Parse(argsJSON).IsObject() { + return argsJSON, false } + updated := argsJSON corrected := false // 根据工具名称应用特定的参数修正规则 switch toolName { case "bash": // OpenCode bash 支持 workdir;有些来源会输出 work_dir。 - if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir { - if workDir, exists := argsMap["work_dir"]; exists { - argsMap["workdir"] = workDir - delete(argsMap, "work_dir") + if !gjson.Get(updated, "workdir").Exists() { + if next, changed := moveJSONField(updated, "work_dir", "workdir"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") } } else { - if _, exists := argsMap["work_dir"]; exists { - delete(argsMap, "work_dir") + if next, changed := deleteJSONField(updated, "work_dir"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") } @@ -232,67 +268,71 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall case "edit": // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。 - if _, exists := argsMap["filePath"]; !exists { - if filePath, exists := argsMap["file_path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file_path") + if !gjson.Get(updated, "filePath").Exists() { + if next, changed := moveJSONField(updated, "file_path", "filePath"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") - } else if filePath, exists := 
argsMap["path"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "path") + } else if next, changed := moveJSONField(updated, "path", "filePath"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") - } else if filePath, exists := argsMap["file"]; exists { - argsMap["filePath"] = filePath - delete(argsMap, "file") + } else if next, changed := moveJSONField(updated, "file", "filePath"); changed { + updated = next corrected = true logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") } } - if _, exists := argsMap["oldString"]; !exists { - if oldString, exists := argsMap["old_string"]; exists { - argsMap["oldString"] = oldString - delete(argsMap, "old_string") - corrected = true - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") - } + if next, changed := moveJSONField(updated, "old_string", "oldString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") } - if _, exists := argsMap["newString"]; !exists { - if newString, exists := argsMap["new_string"]; exists { - argsMap["newString"] = newString - delete(argsMap, "new_string") - corrected = true - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") - } + if next, changed := moveJSONField(updated, "new_string", "newString"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") } - if _, exists := argsMap["replaceAll"]; !exists { - if replaceAll, exists := argsMap["replace_all"]; exists { - argsMap["replaceAll"] = replaceAll - delete(argsMap, 
"replace_all") - corrected = true - logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") - } + if next, changed := moveJSONField(updated, "replace_all", "replaceAll"); changed { + updated = next + corrected = true + logger.LegacyPrintf("service.openai_tool_corrector", "[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } + return updated, corrected +} - // 如果修正了参数,需要重新序列化 - if corrected { - if _, wasString := arguments.(string); wasString { - // 原本是字符串,序列化回字符串 - if newArgsJSON, err := json.Marshal(argsMap); err == nil { - functionCall["arguments"] = string(newArgsJSON) - } - } else { - // 原本是 map,直接赋值 - functionCall["arguments"] = argsMap - } +func moveJSONField(input, from, to string) (string, bool) { + if gjson.Get(input, to).Exists() { + return input, false } + src := gjson.Get(input, from) + if !src.Exists() { + return input, false + } + next, err := sjson.SetRaw(input, to, src.Raw) + if err != nil { + return input, false + } + next, err = sjson.Delete(next, from) + if err != nil { + return input, false + } + return next, true +} - return corrected +func deleteJSONField(input, path string) (string, bool) { + if !gjson.Get(input, path).Exists() { + return input, false + } + next, err := sjson.Delete(input, path) + if err != nil { + return input, false + } + return next, true } // recordCorrection 记录一次工具名称修正 diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index ff518ea6..7c83de9e 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -5,6 +5,15 @@ import ( "testing" ) +func TestMayContainToolCallPayload(t *testing.T) { + if mayContainToolCallPayload([]byte(`{"type":"response.output_text.delta","delta":"hello"}`)) { + t.Fatalf("plain text event should not trigger tool-call parsing") + } + if 
!mayContainToolCallPayload([]byte(`{"tool_calls":[{"function":{"name":"apply_patch"}}]}`)) { + t.Fatalf("tool_calls event should trigger tool-call parsing") + } +} + func TestCorrectToolCallsInSSEData(t *testing.T) { corrector := NewCodexToolCorrector() diff --git a/backend/internal/service/openai_ws_account_sticky_test.go b/backend/internal/service/openai_ws_account_sticky_test.go new file mode 100644 index 00000000..3fe08179 --- /dev/null +++ b/backend/internal/service/openai_ws_account_sticky_test.go @@ -0,0 +1,190 @@ +package service + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Hit(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 2, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_1", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_1", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, account.ID, selection.Account.ID) + require.True(t, selection.Acquired) + if selection.ReleaseFunc != nil { + selection.ReleaseFunc() + } +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_Excluded(t *testing.T) { + ctx := context.Background() + groupID := 
int64(23) + account := Account{ + ID: 8, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_2", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_2", "gpt-5.1", map[int64]struct{}{account.ID: {}}) + require.NoError(t, err) + require.Nil(t, selection) +} + +func TestOpenAIGatewayService_SelectAccountByPreviousResponseID_ForceHTTPIgnored(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + account := Account{ + ID: 11, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Extra: map[string]any{ + "openai_ws_force_http": true, + "responses_websockets_v2_enabled": true, + }, + } + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: []Account{account}}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(stubConcurrencyCache{}), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_force_http", account.ID, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_force_http", "gpt-5.1", nil) + require.NoError(t, err) + require.Nil(t, selection, "force_http 场景应忽略 previous_response_id 粘连") +} + +func 
TestOpenAIGatewayService_SelectAccountByPreviousResponseID_BusyKeepsSticky(t *testing.T) { + ctx := context.Background() + groupID := int64(23) + accounts := []Account{ + { + ID: 21, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 0, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + { + ID: 22, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Priority: 9, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + }, + } + + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + cfg := newOpenAIWSV2TestConfig() + cfg.Gateway.Scheduling.StickySessionMaxWaiting = 2 + cfg.Gateway.Scheduling.StickySessionWaitTimeout = 30 * time.Second + + concurrencyCache := stubConcurrencyCache{ + acquireResults: map[int64]bool{ + 21: false, // previous_response 命中的账号繁忙 + 22: true, // 次优账号可用(若回退会命中) + }, + waitCounts: map[int64]int{ + 21: 999, + }, + } + + svc := &OpenAIGatewayService{ + accountRepo: stubOpenAIAccountRepo{accounts: accounts}, + cache: cache, + cfg: cfg, + concurrencyService: NewConcurrencyService(concurrencyCache), + openaiWSStateStore: store, + } + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_prev_busy", 21, time.Hour)) + + selection, err := svc.SelectAccountByPreviousResponseID(ctx, &groupID, "resp_prev_busy", "gpt-5.1", nil) + require.NoError(t, err) + require.NotNil(t, selection) + require.NotNil(t, selection.Account) + require.Equal(t, int64(21), selection.Account.ID, "busy previous_response sticky account should remain selected") + require.False(t, selection.Acquired) + require.NotNil(t, selection.WaitPlan) + require.Equal(t, int64(21), selection.WaitPlan.AccountID) +} + +func newOpenAIWSV2TestConfig() *config.Config { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + 
cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + return cfg +} diff --git a/backend/internal/service/openai_ws_client.go b/backend/internal/service/openai_ws_client.go new file mode 100644 index 00000000..9f3c47b7 --- /dev/null +++ b/backend/internal/service/openai_ws_client.go @@ -0,0 +1,285 @@ +package service + +import ( + "context" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + coderws "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" +) + +const openAIWSMessageReadLimitBytes int64 = 16 * 1024 * 1024 +const ( + openAIWSProxyTransportMaxIdleConns = 128 + openAIWSProxyTransportMaxIdleConnsPerHost = 64 + openAIWSProxyTransportIdleConnTimeout = 90 * time.Second + openAIWSProxyClientCacheMaxEntries = 256 + openAIWSProxyClientCacheIdleTTL = 15 * time.Minute +) + +type OpenAIWSTransportMetricsSnapshot struct { + ProxyClientCacheHits int64 `json:"proxy_client_cache_hits"` + ProxyClientCacheMisses int64 `json:"proxy_client_cache_misses"` + TransportReuseRatio float64 `json:"transport_reuse_ratio"` +} + +// openAIWSClientConn 抽象 WS 客户端连接,便于替换底层实现。 +type openAIWSClientConn interface { + WriteJSON(ctx context.Context, value any) error + ReadMessage(ctx context.Context) ([]byte, error) + Ping(ctx context.Context) error + Close() error +} + +// openAIWSClientDialer 抽象 WS 建连器。 +type openAIWSClientDialer interface { + Dial(ctx context.Context, wsURL string, headers http.Header, proxyURL string) (openAIWSClientConn, int, http.Header, error) +} + +type openAIWSTransportMetricsDialer interface { + SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot +} + +func newDefaultOpenAIWSClientDialer() openAIWSClientDialer { + return &coderOpenAIWSClientDialer{ + proxyClients: make(map[string]*openAIWSProxyClientEntry), + } +} + +type coderOpenAIWSClientDialer struct { + 
proxyMu sync.Mutex + proxyClients map[string]*openAIWSProxyClientEntry + proxyHits atomic.Int64 + proxyMisses atomic.Int64 +} + +type openAIWSProxyClientEntry struct { + client *http.Client + lastUsedUnixNano int64 +} + +func (d *coderOpenAIWSClientDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + targetURL := strings.TrimSpace(wsURL) + if targetURL == "" { + return nil, 0, nil, errors.New("ws url is empty") + } + + opts := &coderws.DialOptions{ + HTTPHeader: cloneHeader(headers), + CompressionMode: coderws.CompressionContextTakeover, + } + if proxy := strings.TrimSpace(proxyURL); proxy != "" { + proxyClient, err := d.proxyHTTPClient(proxy) + if err != nil { + return nil, 0, nil, err + } + opts.HTTPClient = proxyClient + } + + conn, resp, err := coderws.Dial(ctx, targetURL, opts) + if err != nil { + status := 0 + respHeaders := http.Header(nil) + if resp != nil { + status = resp.StatusCode + respHeaders = cloneHeader(resp.Header) + } + return nil, status, respHeaders, err + } + // coder/websocket 默认单消息读取上限为 32KB,Codex WS 事件(如 rate_limits/大 delta) + // 可能超过该阈值,需显式提高上限,避免本地 read_fail(message too big)。 + conn.SetReadLimit(openAIWSMessageReadLimitBytes) + respHeaders := http.Header(nil) + if resp != nil { + respHeaders = cloneHeader(resp.Header) + } + return &coderOpenAIWSClientConn{conn: conn}, 0, respHeaders, nil +} + +func (d *coderOpenAIWSClientDialer) proxyHTTPClient(proxy string) (*http.Client, error) { + if d == nil { + return nil, errors.New("openai ws dialer is nil") + } + normalizedProxy := strings.TrimSpace(proxy) + if normalizedProxy == "" { + return nil, errors.New("proxy url is empty") + } + parsedProxyURL, err := url.Parse(normalizedProxy) + if err != nil { + return nil, fmt.Errorf("invalid proxy url: %w", err) + } + now := time.Now().UnixNano() + + d.proxyMu.Lock() + defer d.proxyMu.Unlock() + if entry, ok := d.proxyClients[normalizedProxy]; ok && entry != 
nil && entry.client != nil { + entry.lastUsedUnixNano = now + d.proxyHits.Add(1) + return entry.client, nil + } + d.cleanupProxyClientsLocked(now) + transport := &http.Transport{ + Proxy: http.ProxyURL(parsedProxyURL), + MaxIdleConns: openAIWSProxyTransportMaxIdleConns, + MaxIdleConnsPerHost: openAIWSProxyTransportMaxIdleConnsPerHost, + IdleConnTimeout: openAIWSProxyTransportIdleConnTimeout, + TLSHandshakeTimeout: 10 * time.Second, + ForceAttemptHTTP2: true, + } + client := &http.Client{Transport: transport} + d.proxyClients[normalizedProxy] = &openAIWSProxyClientEntry{ + client: client, + lastUsedUnixNano: now, + } + d.ensureProxyClientCapacityLocked() + d.proxyMisses.Add(1) + return client, nil +} + +func (d *coderOpenAIWSClientDialer) cleanupProxyClientsLocked(nowUnixNano int64) { + if d == nil || len(d.proxyClients) == 0 { + return + } + idleTTL := openAIWSProxyClientCacheIdleTTL + if idleTTL <= 0 { + return + } + now := time.Unix(0, nowUnixNano) + for key, entry := range d.proxyClients { + if entry == nil || entry.client == nil { + delete(d.proxyClients, key) + continue + } + lastUsed := time.Unix(0, entry.lastUsedUnixNano) + if now.Sub(lastUsed) > idleTTL { + closeOpenAIWSProxyClient(entry.client) + delete(d.proxyClients, key) + } + } +} + +func (d *coderOpenAIWSClientDialer) ensureProxyClientCapacityLocked() { + if d == nil { + return + } + maxEntries := openAIWSProxyClientCacheMaxEntries + if maxEntries <= 0 { + return + } + for len(d.proxyClients) > maxEntries { + var oldestKey string + var oldestLastUsed int64 + hasOldest := false + for key, entry := range d.proxyClients { + lastUsed := int64(0) + if entry != nil { + lastUsed = entry.lastUsedUnixNano + } + if !hasOldest || lastUsed < oldestLastUsed { + hasOldest = true + oldestKey = key + oldestLastUsed = lastUsed + } + } + if !hasOldest { + return + } + if entry := d.proxyClients[oldestKey]; entry != nil { + closeOpenAIWSProxyClient(entry.client) + } + delete(d.proxyClients, oldestKey) + } +} + +func 
closeOpenAIWSProxyClient(client *http.Client) { + if client == nil || client.Transport == nil { + return + } + if transport, ok := client.Transport.(*http.Transport); ok && transport != nil { + transport.CloseIdleConnections() + } +} + +func (d *coderOpenAIWSClientDialer) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if d == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + hits := d.proxyHits.Load() + misses := d.proxyMisses.Load() + total := hits + misses + reuseRatio := 0.0 + if total > 0 { + reuseRatio = float64(hits) / float64(total) + } + return OpenAIWSTransportMetricsSnapshot{ + ProxyClientCacheHits: hits, + ProxyClientCacheMisses: misses, + TransportReuseRatio: reuseRatio, + } +} + +type coderOpenAIWSClientConn struct { + conn *coderws.Conn +} + +func (c *coderOpenAIWSClientConn) WriteJSON(ctx context.Context, value any) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return wsjson.Write(ctx, c.conn, value) +} + +func (c *coderOpenAIWSClientConn) ReadMessage(ctx context.Context) ([]byte, error) { + if c == nil || c.conn == nil { + return nil, errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + + msgType, payload, err := c.conn.Read(ctx) + if err != nil { + return nil, err + } + switch msgType { + case coderws.MessageText, coderws.MessageBinary: + return payload, nil + default: + return nil, errOpenAIWSConnClosed + } +} + +func (c *coderOpenAIWSClientConn) Ping(ctx context.Context) error { + if c == nil || c.conn == nil { + return errOpenAIWSConnClosed + } + if ctx == nil { + ctx = context.Background() + } + return c.conn.Ping(ctx) +} + +func (c *coderOpenAIWSClientConn) Close() error { + if c == nil || c.conn == nil { + return nil + } + // Close 为幂等,忽略重复关闭错误。 + _ = c.conn.Close(coderws.StatusNormalClosure, "") + _ = c.conn.CloseNow() + return nil +} diff --git a/backend/internal/service/openai_ws_client_test.go 
b/backend/internal/service/openai_ws_client_test.go new file mode 100644 index 00000000..a88d6266 --- /dev/null +++ b/backend/internal/service/openai_ws_client_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "fmt" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientReuse(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + c1, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + c2, err := impl.proxyHTTPClient("http://127.0.0.1:8080") + require.NoError(t, err) + require.Same(t, c1, c2, "同一代理地址应复用同一个 HTTP 客户端") + + c3, err := impl.proxyHTTPClient("http://127.0.0.1:8081") + require.NoError(t, err) + require.NotSame(t, c1, c3, "不同代理地址应分离客户端") +} + +func TestCoderOpenAIWSClientDialer_ProxyHTTPClientInvalidURL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("://bad") + require.Error(t, err) +} + +func TestCoderOpenAIWSClientDialer_TransportMetricsSnapshot(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18080") + require.NoError(t, err) + _, err = impl.proxyHTTPClient("http://127.0.0.1:18081") + require.NoError(t, err) + + snapshot := impl.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheCapacity(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + 
require.True(t, ok) + + total := openAIWSProxyClientCacheMaxEntries + 32 + for i := 0; i < total; i++ { + _, err := impl.proxyHTTPClient(fmt.Sprintf("http://127.0.0.1:%d", 20000+i)) + require.NoError(t, err) + } + + impl.proxyMu.Lock() + cacheSize := len(impl.proxyClients) + impl.proxyMu.Unlock() + + require.LessOrEqual(t, cacheSize, openAIWSProxyClientCacheMaxEntries, "代理客户端缓存应受容量上限约束") +} + +func TestCoderOpenAIWSClientDialer_ProxyClientCacheIdleTTL(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + oldProxy := "http://127.0.0.1:28080" + _, err := impl.proxyHTTPClient(oldProxy) + require.NoError(t, err) + + impl.proxyMu.Lock() + oldEntry := impl.proxyClients[oldProxy] + require.NotNil(t, oldEntry) + oldEntry.lastUsedUnixNano = time.Now().Add(-openAIWSProxyClientCacheIdleTTL - time.Minute).UnixNano() + impl.proxyMu.Unlock() + + // 触发一次新的代理获取,驱动 TTL 清理。 + _, err = impl.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + impl.proxyMu.Lock() + _, exists := impl.proxyClients[oldProxy] + impl.proxyMu.Unlock() + + require.False(t, exists, "超过空闲 TTL 的代理客户端应被回收") +} + +func TestCoderOpenAIWSClientDialer_ProxyTransportTLSHandshakeTimeout(t *testing.T) { + dialer := newDefaultOpenAIWSClientDialer() + impl, ok := dialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + client, err := impl.proxyHTTPClient("http://127.0.0.1:38080") + require.NoError(t, err) + require.NotNil(t, client) + + transport, ok := client.Transport.(*http.Transport) + require.True(t, ok) + require.NotNil(t, transport) + require.Equal(t, 10*time.Second, transport.TLSHandshakeTimeout) +} diff --git a/backend/internal/service/openai_ws_fallback_test.go b/backend/internal/service/openai_ws_fallback_test.go new file mode 100644 index 00000000..ce06f6a2 --- /dev/null +++ b/backend/internal/service/openai_ws_fallback_test.go @@ -0,0 +1,251 @@ +package service + +import ( + "context" + "errors" + 
"net/http" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" +) + +func TestClassifyOpenAIWSAcquireError(t *testing.T) { + t.Run("dial_426_upgrade_required", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")} + require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("queue_full", func(t *testing.T) { + require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull)) + }) + + t.Run("preferred_conn_unavailable", func(t *testing.T) { + require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable)) + }) + + t.Run("acquire_timeout", func(t *testing.T) { + require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded)) + }) + + t.Run("auth_failed_401", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")} + require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_rate_limited", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")} + require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("upstream_5xx", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")} + require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("dial_failed_other_status", func(t *testing.T) { + err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")} + require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err)) + }) + + t.Run("other", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x"))) + }) + + t.Run("nil", func(t *testing.T) { + require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil)) + }) +} + +func 
TestClassifyOpenAIWSDialError(t *testing.T) { + t.Run("handshake_not_finished", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err)) + }) + + t.Run("context_deadline", func(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: 0, + Err: context.DeadlineExceeded, + } + require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err)) + }) +} + +func TestSummarizeOpenAIWSDialError(t *testing.T) { + err := &openAIWSDialError{ + StatusCode: http.StatusBadGateway, + ResponseHeaders: http.Header{ + "Server": []string{"cloudflare"}, + "Via": []string{"1.1 example"}, + "Cf-Ray": []string{"abcd1234"}, + "X-Request-Id": []string{"req_123"}, + }, + Err: errors.New("WebSocket protocol error: Handshake not finished"), + } + + status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err) + require.Equal(t, http.StatusBadGateway, status) + require.Equal(t, "handshake_not_finished", class) + require.Equal(t, "-", closeStatus) + require.Equal(t, "-", closeReason) + require.Equal(t, "cloudflare", server) + require.Equal(t, "1.1 example", via) + require.Equal(t, "abcd1234", cfRay) + require.Equal(t, "req_123", reqID) +} + +func TestClassifyOpenAIWSErrorEvent(t *testing.T) { + reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`)) + require.Equal(t, "upgrade_required", reason) + require.True(t, recoverable) + + reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`)) + require.Equal(t, "previous_response_not_found", reason) + require.True(t, recoverable) +} + +func TestClassifyOpenAIWSReconnectReason(t *testing.T) { + reason, retryable := 
classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy"))) + require.Equal(t, "policy_violation", reason) + require.False(t, retryable) + + reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io"))) + require.Equal(t, "read_event", reason) + require.True(t, retryable) +} + +func TestOpenAIWSErrorHTTPStatus(t *testing.T) { + require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`))) + require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`))) + require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`))) + require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`))) + require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`))) +} + +func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) { + t.Run("previous_response_not_found", func(t *testing.T) { + statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")), + ) + require.True(t, ok) + require.Equal(t, http.StatusBadRequest, statusCode) + require.Equal(t, "invalid_request_error", errType) + require.Equal(t, "previous response not found", clientMessage) + require.Equal(t, "previous response not found", upstreamMessage) + }) + + t.Run("auth_failed_uses_dial_status", func(t *testing.T) { + statusCode, errType, 
clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse( + wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{ + StatusCode: http.StatusForbidden, + Err: errors.New("forbidden"), + }), + ) + require.True(t, ok) + require.Equal(t, http.StatusForbidden, statusCode) + require.Equal(t, "upstream_error", errType) + require.Equal(t, "forbidden", clientMessage) + require.Equal(t, "forbidden", upstreamMessage) + }) + + t.Run("non_fallback_error_not_resolved", func(t *testing.T) { + _, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error")) + require.False(t, ok) + }) +} + +func TestOpenAIWSFallbackCooling(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + svc.markOpenAIWSFallbackCooling(1, "upgrade_required") + require.True(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.clearOpenAIWSFallbackCooling(1) + require.False(t, svc.isOpenAIWSFallbackCooling(1)) + + svc.markOpenAIWSFallbackCooling(2, "x") + time.Sleep(1200 * time.Millisecond) + require.False(t, svc.isOpenAIWSFallbackCooling(2)) +} + +func TestOpenAIWSRetryBackoff(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100 + svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400 + svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1)) + require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3)) + require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4)) +} + +func TestOpenAIWSRetryTotalBudget(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200 + require.Equal(t, 1200*time.Millisecond, 
svc.openAIWSRetryTotalBudget()) + + svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0 + require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget()) +} + +func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) { + require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation})) + require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig})) + require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io"))) +} + +func TestOpenAIWSStoreDisabledConnMode(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive" + require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode()) + + svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "" + svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode()) +} + +func TestShouldForceNewConnOnStoreDisabled(t *testing.T) { + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, "")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation")) + + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation")) + require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big")) + require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event")) +} + +func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) { + svc := &OpenAIGatewayService{} + svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond) + svc.recordOpenAIWSRetryAttempt(0) + 
svc.recordOpenAIWSRetryExhausted() + svc.recordOpenAIWSNonRetryableFastFallback() + + snapshot := svc.SnapshotOpenAIWSRetryMetrics() + require.Equal(t, int64(2), snapshot.RetryAttemptsTotal) + require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal) + require.Equal(t, int64(1), snapshot.RetryExhaustedTotal) + require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal) +} + +func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) { + svc := &OpenAIGatewayService{cfg: &config.Config{}} + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema") + require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2)) + + svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1 + require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2)) +} diff --git a/backend/internal/service/openai_ws_forwarder.go b/backend/internal/service/openai_ws_forwarder.go new file mode 100644 index 00000000..74ba472f --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder.go @@ -0,0 +1,3955 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "math/rand" + "net" + "net/http" + "net/url" + "sort" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/pkg/openai" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "go.uber.org/zap" +) + +const ( + openAIWSBetaV1Value = "responses_websockets=2026-02-04" + openAIWSBetaV2Value = "responses_websockets=2026-02-06" + + openAIWSTurnStateHeader = "x-codex-turn-state" + openAIWSTurnMetadataHeader = "x-codex-turn-metadata" + + openAIWSLogValueMaxLen = 160 + openAIWSHeaderValueMaxLen = 120 + openAIWSIDValueMaxLen = 64 + openAIWSEventLogHeadLimit = 20 + openAIWSEventLogEveryN = 50 + openAIWSBufferLogHeadLimit = 8 + openAIWSBufferLogEveryN = 20 + 
openAIWSPrewarmEventLogHead = 10 + openAIWSPayloadKeySizeTopN = 6 + + openAIWSPayloadSizeEstimateDepth = 3 + openAIWSPayloadSizeEstimateMaxBytes = 64 * 1024 + openAIWSPayloadSizeEstimateMaxItems = 16 + + openAIWSEventFlushBatchSizeDefault = 4 + openAIWSEventFlushIntervalDefault = 25 * time.Millisecond + openAIWSPayloadLogSampleDefault = 0.2 + + openAIWSStoreDisabledConnModeStrict = "strict" + openAIWSStoreDisabledConnModeAdaptive = "adaptive" + openAIWSStoreDisabledConnModeOff = "off" + + openAIWSIngressStagePreviousResponseNotFound = "previous_response_not_found" + openAIWSMaxPrevResponseIDDeletePasses = 8 +) + +var openAIWSLogValueReplacer = strings.NewReplacer( + "error", "err", + "fallback", "fb", + "warning", "warnx", + "failed", "fail", +) + +var openAIWSIngressPreflightPingIdle = 20 * time.Second + +// openAIWSFallbackError 表示可安全回退到 HTTP 的 WS 错误(尚未写下游)。 +type openAIWSFallbackError struct { + Reason string + Err error +} + +func (e *openAIWSFallbackError) Error() string { + if e == nil { + return "" + } + if e.Err == nil { + return fmt.Sprintf("openai ws fallback: %s", strings.TrimSpace(e.Reason)) + } + return fmt.Sprintf("openai ws fallback: %s: %v", strings.TrimSpace(e.Reason), e.Err) +} + +func (e *openAIWSFallbackError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +func wrapOpenAIWSFallback(reason string, err error) error { + return &openAIWSFallbackError{Reason: strings.TrimSpace(reason), Err: err} +} + +// OpenAIWSClientCloseError 表示应以指定 WebSocket close code 主动关闭客户端连接的错误。 +type OpenAIWSClientCloseError struct { + statusCode coderws.StatusCode + reason string + err error +} + +type openAIWSIngressTurnError struct { + stage string + cause error + wroteDownstream bool +} + +func (e *openAIWSIngressTurnError) Error() string { + if e == nil { + return "" + } + if e.cause == nil { + return strings.TrimSpace(e.stage) + } + return e.cause.Error() +} + +func (e *openAIWSIngressTurnError) Unwrap() error { + if e == nil { + return nil + } 
+ return e.cause +} + +func wrapOpenAIWSIngressTurnError(stage string, cause error, wroteDownstream bool) error { + if cause == nil { + return nil + } + return &openAIWSIngressTurnError{ + stage: strings.TrimSpace(stage), + cause: cause, + wroteDownstream: wroteDownstream, + } +} + +func isOpenAIWSIngressTurnRetryable(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if errors.Is(turnErr.cause, context.Canceled) || errors.Is(turnErr.cause, context.DeadlineExceeded) { + return false + } + if turnErr.wroteDownstream { + return false + } + switch turnErr.stage { + case "write_upstream", "read_upstream": + return true + default: + return false + } +} + +func openAIWSIngressTurnRetryReason(err error) string { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return "unknown" + } + if turnErr.stage == "" { + return "unknown" + } + return turnErr.stage +} + +func isOpenAIWSIngressPreviousResponseNotFound(err error) bool { + var turnErr *openAIWSIngressTurnError + if !errors.As(err, &turnErr) || turnErr == nil { + return false + } + if strings.TrimSpace(turnErr.stage) != openAIWSIngressStagePreviousResponseNotFound { + return false + } + return !turnErr.wroteDownstream +} + +// NewOpenAIWSClientCloseError 创建一个客户端 WS 关闭错误。 +func NewOpenAIWSClientCloseError(statusCode coderws.StatusCode, reason string, err error) error { + return &OpenAIWSClientCloseError{ + statusCode: statusCode, + reason: strings.TrimSpace(reason), + err: err, + } +} + +func (e *OpenAIWSClientCloseError) Error() string { + if e == nil { + return "" + } + if e.err == nil { + return fmt.Sprintf("openai ws client close: %d %s", int(e.statusCode), strings.TrimSpace(e.reason)) + } + return fmt.Sprintf("openai ws client close: %d %s: %v", int(e.statusCode), strings.TrimSpace(e.reason), e.err) +} + +func (e *OpenAIWSClientCloseError) Unwrap() error { + if e == nil { + return nil + } + return 
e.err +} + +func (e *OpenAIWSClientCloseError) StatusCode() coderws.StatusCode { + if e == nil { + return coderws.StatusInternalError + } + return e.statusCode +} + +func (e *OpenAIWSClientCloseError) Reason() string { + if e == nil { + return "" + } + return strings.TrimSpace(e.reason) +} + +// OpenAIWSIngressHooks 定义入站 WS 每个 turn 的生命周期回调。 +type OpenAIWSIngressHooks struct { + BeforeTurn func(turn int) error + AfterTurn func(turn int, result *OpenAIForwardResult, turnErr error) +} + +func normalizeOpenAIWSLogValue(value string) string { + trimmed := strings.TrimSpace(value) + if trimmed == "" { + return "-" + } + return openAIWSLogValueReplacer.Replace(trimmed) +} + +func truncateOpenAIWSLogValue(value string, maxLen int) string { + normalized := normalizeOpenAIWSLogValue(value) + if normalized == "-" || maxLen <= 0 { + return normalized + } + if len(normalized) <= maxLen { + return normalized + } + return normalized[:maxLen] + "..." +} + +func openAIWSHeaderValueForLog(headers http.Header, key string) string { + if headers == nil { + return "-" + } + return truncateOpenAIWSLogValue(headers.Get(key), openAIWSHeaderValueMaxLen) +} + +func hasOpenAIWSHeader(headers http.Header, key string) bool { + if headers == nil { + return false + } + return strings.TrimSpace(headers.Get(key)) != "" +} + +type openAIWSSessionHeaderResolution struct { + SessionID string + ConversationID string + SessionSource string + ConversationSource string +} + +func resolveOpenAIWSSessionHeaders(c *gin.Context, promptCacheKey string) openAIWSSessionHeaderResolution { + resolution := openAIWSSessionHeaderResolution{ + SessionSource: "none", + ConversationSource: "none", + } + if c != nil && c.Request != nil { + if sessionID := strings.TrimSpace(c.Request.Header.Get("session_id")); sessionID != "" { + resolution.SessionID = sessionID + resolution.SessionSource = "header_session_id" + } + if conversationID := strings.TrimSpace(c.Request.Header.Get("conversation_id")); conversationID != "" { + 
resolution.ConversationID = conversationID + resolution.ConversationSource = "header_conversation_id" + if resolution.SessionID == "" { + resolution.SessionID = conversationID + resolution.SessionSource = "header_conversation_id" + } + } + } + + cacheKey := strings.TrimSpace(promptCacheKey) + if cacheKey != "" { + if resolution.SessionID == "" { + resolution.SessionID = cacheKey + resolution.SessionSource = "prompt_cache_key" + } + } + return resolution +} + +func shouldLogOpenAIWSEvent(idx int, eventType string) bool { + if idx <= openAIWSEventLogHeadLimit { + return true + } + if openAIWSEventLogEveryN > 0 && idx%openAIWSEventLogEveryN == 0 { + return true + } + if eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + return true + } + return false +} + +func shouldLogOpenAIWSBufferedEvent(idx int) bool { + if idx <= openAIWSBufferLogHeadLimit { + return true + } + if openAIWSBufferLogEveryN > 0 && idx%openAIWSBufferLogEveryN == 0 { + return true + } + return false +} + +func openAIWSEventMayContainModel(eventType string) bool { + switch eventType { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + trimmed := strings.TrimSpace(eventType) + if trimmed == eventType { + return false + } + switch trimmed { + case "response.created", + "response.in_progress", + "response.completed", + "response.done", + "response.failed", + "response.incomplete", + "response.cancelled", + "response.canceled": + return true + default: + return false + } + } +} + +func openAIWSEventMayContainToolCalls(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return false + } + if strings.Contains(eventType, "function_call") || strings.Contains(eventType, "tool_call") { + return true + } + switch eventType { + case "response.output_item.added", "response.output_item.done", 
"response.completed", "response.done": + return true + default: + return false + } +} + +func openAIWSEventShouldParseUsage(eventType string) bool { + return eventType == "response.completed" || strings.TrimSpace(eventType) == "response.completed" +} + +func parseOpenAIWSEventEnvelope(message []byte) (eventType string, responseID string, response gjson.Result) { + if len(message) == 0 { + return "", "", gjson.Result{} + } + values := gjson.GetManyBytes(message, "type", "response.id", "id", "response") + eventType = strings.TrimSpace(values[0].String()) + if id := strings.TrimSpace(values[1].String()); id != "" { + responseID = id + } else { + responseID = strings.TrimSpace(values[2].String()) + } + return eventType, responseID, values[3] +} + +func openAIWSMessageLikelyContainsToolCalls(message []byte) bool { + if len(message) == 0 { + return false + } + return bytes.Contains(message, []byte(`"tool_calls"`)) || + bytes.Contains(message, []byte(`"tool_call"`)) || + bytes.Contains(message, []byte(`"function_call"`)) +} + +func parseOpenAIWSResponseUsageFromCompletedEvent(message []byte, usage *OpenAIUsage) { + if usage == nil || len(message) == 0 { + return + } + values := gjson.GetManyBytes( + message, + "response.usage.input_tokens", + "response.usage.output_tokens", + "response.usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func parseOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "", "", "" + } + values := gjson.GetManyBytes(message, "error.code", "error.type", "error.message") + return strings.TrimSpace(values[0].String()), strings.TrimSpace(values[1].String()), strings.TrimSpace(values[2].String()) +} + +func summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMessageRaw string) (code string, errType string, errMessage string) { 
+ code = truncateOpenAIWSLogValue(codeRaw, openAIWSLogValueMaxLen) + errType = truncateOpenAIWSLogValue(errTypeRaw, openAIWSLogValueMaxLen) + errMessage = truncateOpenAIWSLogValue(errMessageRaw, openAIWSLogValueMaxLen) + return code, errType, errMessage +} + +func summarizeOpenAIWSErrorEventFields(message []byte) (code string, errType string, errMessage string) { + if len(message) == 0 { + return "-", "-", "-" + } + return summarizeOpenAIWSErrorEventFieldsFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func summarizeOpenAIWSPayloadKeySizes(payload map[string]any, topN int) string { + if len(payload) == 0 { + return "-" + } + type keySize struct { + Key string + Size int + } + sizes := make([]keySize, 0, len(payload)) + for key, value := range payload { + size := estimateOpenAIWSPayloadValueSize(value, openAIWSPayloadSizeEstimateDepth) + sizes = append(sizes, keySize{Key: key, Size: size}) + } + sort.Slice(sizes, func(i, j int) bool { + if sizes[i].Size == sizes[j].Size { + return sizes[i].Key < sizes[j].Key + } + return sizes[i].Size > sizes[j].Size + }) + + if topN <= 0 || topN > len(sizes) { + topN = len(sizes) + } + parts := make([]string, 0, topN) + for idx := 0; idx < topN; idx++ { + item := sizes[idx] + parts = append(parts, fmt.Sprintf("%s:%d", item.Key, item.Size)) + } + return strings.Join(parts, ",") +} + +func estimateOpenAIWSPayloadValueSize(value any, depth int) int { + if depth <= 0 { + return -1 + } + switch v := value.(type) { + case nil: + return 0 + case string: + return len(v) + case []byte: + return len(v) + case bool: + return 1 + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: + return 8 + case float32, float64: + return 8 + case map[string]any: + if len(v) == 0 { + return 2 + } + total := 2 + count := 0 + for key, item := range v { + count++ + if count > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + itemSize := estimateOpenAIWSPayloadValueSize(item, depth-1) + if itemSize < 0 { + return -1 + } + total 
+= len(key) + itemSize + 3 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + case []any: + if len(v) == 0 { + return 2 + } + total := 2 + limit := len(v) + if limit > openAIWSPayloadSizeEstimateMaxItems { + return -1 + } + for i := 0; i < limit; i++ { + itemSize := estimateOpenAIWSPayloadValueSize(v[i], depth-1) + if itemSize < 0 { + return -1 + } + total += itemSize + 1 + if total > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + } + return total + default: + raw, err := json.Marshal(v) + if err != nil { + return -1 + } + if len(raw) > openAIWSPayloadSizeEstimateMaxBytes { + return -1 + } + return len(raw) + } +} + +func openAIWSPayloadString(payload map[string]any, key string) string { + if len(payload) == 0 { + return "" + } + raw, ok := payload[key] + if !ok { + return "" + } + switch v := raw.(type) { + case nil: + return "" + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func openAIWSPayloadStringFromRaw(payload []byte, key string) string { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return "" + } + return strings.TrimSpace(gjson.GetBytes(payload, key).String()) +} + +func openAIWSPayloadBoolFromRaw(payload []byte, key string, defaultValue bool) bool { + if len(payload) == 0 || strings.TrimSpace(key) == "" { + return defaultValue + } + value := gjson.GetBytes(payload, key) + if !value.Exists() { + return defaultValue + } + if value.Type != gjson.True && value.Type != gjson.False { + return defaultValue + } + return value.Bool() +} + +func openAIWSSessionHashesFromID(sessionID string) (string, string) { + return deriveOpenAISessionHashes(sessionID) +} + +func extractOpenAIWSImageURL(value any) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case map[string]any: + if raw, ok := v["url"].(string); ok { + return strings.TrimSpace(raw) + } + } + return "" +} + +func 
summarizeOpenAIWSInput(input any) string {
	items, ok := input.([]any)
	if !ok || len(items) == 0 {
		return "-"
	}

	itemCount := len(items)
	textChars := 0
	imageDataURLs := 0
	imageDataURLChars := 0
	imageRemoteURLs := 0

	// handleContentItem accumulates counters for a single text/image content
	// part. Note: len(text) counts bytes, not runes — "text_chars" is really a
	// byte count for multi-byte input; TODO confirm intended semantics.
	handleContentItem := func(contentItem map[string]any) {
		contentType, _ := contentItem["type"].(string)
		switch strings.TrimSpace(contentType) {
		case "input_text", "output_text", "text":
			if text, ok := contentItem["text"].(string); ok {
				textChars += len(text)
			}
		case "input_image":
			imageURL := extractOpenAIWSImageURL(contentItem["image_url"])
			if imageURL == "" {
				return
			}
			if strings.HasPrefix(strings.ToLower(imageURL), "data:image/") {
				imageDataURLs++
				imageDataURLChars += len(imageURL)
				return
			}
			imageRemoteURLs++
		}
	}

	// handleInputItem walks an input item's content list; an item without a
	// content list is treated as a single content part itself (the previous
	// fallback duplicated handleContentItem verbatim — deduplicated here).
	handleInputItem := func(inputItem map[string]any) {
		if content, ok := inputItem["content"].([]any); ok {
			for _, rawContent := range content {
				if contentItem, ok := rawContent.(map[string]any); ok {
					handleContentItem(contentItem)
				}
			}
			return
		}
		handleContentItem(inputItem)
	}

	for _, rawItem := range items {
		if inputItem, ok := rawItem.(map[string]any); ok {
			handleInputItem(inputItem)
		}
	}

	return fmt.Sprintf(
		"items=%d,text_chars=%d,image_data_urls=%d,image_data_url_chars=%d,image_remote_urls=%d",
		itemCount,
		textChars,
		imageDataURLs,
		imageDataURLChars,
		imageRemoteURLs,
	)
}

// dropOpenAIWSPayloadKey removes key from payload when present and records
// the removal in *removed.
func dropOpenAIWSPayloadKey(payload map[string]any, key string, removed *[]string) {
	if len(payload) == 0 || strings.TrimSpace(key) == "" {
		return
	}
	if _, exists := payload[key]; !exists {
		return
	}
	delete(payload, key)
	*removed = append(*removed, key)
}

// applyOpenAIWSRetryPayloadStrategy strips only semantics-free fields when WS
// attempts keep failing, so a retry that succeeds has not changed the meaning
// of the original request.
// Note: prompt_cache_key must NOT be removed on retry; it is commonly used as
// a stable session identifier (session_id fallback).
func applyOpenAIWSRetryPayloadStrategy(payload map[string]any, attempt int) (strategy string, removedKeys []string) {
	if len(payload) == 0 {
		return "empty", nil
	}
	if attempt <= 1 {
		return "full", nil
	}

	// attempt >= 2 is guaranteed here (the old inner guard was always true).
	removed := make([]string, 0, 2)
	dropOpenAIWSPayloadKey(payload, "include", &removed)

	if len(removed) == 0 {
		return "full", nil
	}
	sort.Strings(removed)
	return "trim_optional_fields", removed
}

// logOpenAIWSModeInfo emits an info-level WS-mode log line.
func logOpenAIWSModeInfo(format string, args ...any) {
	logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
}

// isOpenAIWSModeDebugEnabled reports whether debug-level logging is active.
func isOpenAIWSModeDebugEnabled() bool {
	return logger.L().Core().Enabled(zap.DebugLevel)
}

// logOpenAIWSModeDebug emits a debug-level WS-mode log line; it is a no-op
// unless debug logging is enabled.
func logOpenAIWSModeDebug(format string, args ...any) {
	if !isOpenAIWSModeDebugEnabled() {
		return
	}
	logger.LegacyPrintf("service.openai_gateway", "[debug] [OpenAI WS Mode][openai_ws_mode=true] "+format, args...)
+} + +func logOpenAIWSBindResponseAccountWarn(groupID, accountID int64, responseID string, err error) { + if err == nil { + return + } + logger.L().Warn( + "openai.ws_bind_response_account_failed", + zap.Int64("group_id", groupID), + zap.Int64("account_id", accountID), + zap.String("response_id", truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen)), + zap.Error(err), + ) +} + +func summarizeOpenAIWSReadCloseError(err error) (status string, reason string) { + if err == nil { + return "-", "-" + } + statusCode := coderws.CloseStatus(err) + if statusCode == -1 { + return "-", "-" + } + closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String()) + closeReason := "-" + var closeErr coderws.CloseError + if errors.As(err, &closeErr) { + reasonText := strings.TrimSpace(closeErr.Reason) + if reasonText != "" { + closeReason = normalizeOpenAIWSLogValue(reasonText) + } + } + return normalizeOpenAIWSLogValue(closeStatus), closeReason +} + +func unwrapOpenAIWSDialBaseError(err error) error { + if err == nil { + return nil + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil && dialErr.Err != nil { + return dialErr.Err + } + return err +} + +func openAIWSDialRespHeaderForLog(err error, key string) string { + var dialErr *openAIWSDialError + if !errors.As(err, &dialErr) || dialErr == nil || dialErr.ResponseHeaders == nil { + return "-" + } + return truncateOpenAIWSLogValue(dialErr.ResponseHeaders.Get(key), openAIWSHeaderValueMaxLen) +} + +func classifyOpenAIWSDialError(err error) string { + if err == nil { + return "-" + } + baseErr := unwrapOpenAIWSDialBaseError(err) + if baseErr == nil { + return "-" + } + if errors.Is(baseErr, context.DeadlineExceeded) { + return "ctx_deadline_exceeded" + } + if errors.Is(baseErr, context.Canceled) { + return "ctx_canceled" + } + var netErr net.Error + if errors.As(baseErr, &netErr) && netErr.Timeout() { + return "net_timeout" + } + if status := coderws.CloseStatus(baseErr); status != -1 { 
+ return normalizeOpenAIWSLogValue(fmt.Sprintf("ws_close_%d", int(status))) + } + message := strings.ToLower(strings.TrimSpace(baseErr.Error())) + switch { + case strings.Contains(message, "handshake not finished"): + return "handshake_not_finished" + case strings.Contains(message, "bad handshake"): + return "bad_handshake" + case strings.Contains(message, "connection refused"): + return "connection_refused" + case strings.Contains(message, "no such host"): + return "dns_not_found" + case strings.Contains(message, "tls"): + return "tls_error" + case strings.Contains(message, "i/o timeout"): + return "io_timeout" + case strings.Contains(message, "context deadline exceeded"): + return "ctx_deadline_exceeded" + default: + return "dial_error" + } +} + +func summarizeOpenAIWSDialError(err error) ( + statusCode int, + dialClass string, + closeStatus string, + closeReason string, + respServer string, + respVia string, + respCFRay string, + respRequestID string, +) { + dialClass = "-" + closeStatus = "-" + closeReason = "-" + respServer = "-" + respVia = "-" + respCFRay = "-" + respRequestID = "-" + if err == nil { + return + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) && dialErr != nil { + statusCode = dialErr.StatusCode + respServer = openAIWSDialRespHeaderForLog(err, "server") + respVia = openAIWSDialRespHeaderForLog(err, "via") + respCFRay = openAIWSDialRespHeaderForLog(err, "cf-ray") + respRequestID = openAIWSDialRespHeaderForLog(err, "x-request-id") + } + dialClass = normalizeOpenAIWSLogValue(classifyOpenAIWSDialError(err)) + closeStatus, closeReason = summarizeOpenAIWSReadCloseError(unwrapOpenAIWSDialBaseError(err)) + return +} + +func isOpenAIWSClientDisconnectError(err error) bool { + if err == nil { + return false + } + if errors.Is(err, io.EOF) || errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled) { + return true + } + switch coderws.CloseStatus(err) { + case coderws.StatusNormalClosure, coderws.StatusGoingAway, 
coderws.StatusNoStatusRcvd, coderws.StatusAbnormalClosure: + return true + } + message := strings.ToLower(strings.TrimSpace(err.Error())) + if message == "" { + return false + } + return strings.Contains(message, "failed to read frame header: eof") || + strings.Contains(message, "unexpected eof") || + strings.Contains(message, "use of closed network connection") || + strings.Contains(message, "connection reset by peer") || + strings.Contains(message, "broken pipe") +} + +func classifyOpenAIWSReadFallbackReason(err error) string { + if err == nil { + return "read_event" + } + switch coderws.CloseStatus(err) { + case coderws.StatusPolicyViolation: + return "policy_violation" + case coderws.StatusMessageTooBig: + return "message_too_big" + default: + return "read_event" + } +} + +func sortedKeys(m map[string]any) []string { + if len(m) == 0 { + return nil + } + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + sort.Strings(keys) + return keys +} + +func (s *OpenAIGatewayService) getOpenAIWSConnPool() *openAIWSConnPool { + if s == nil { + return nil + } + s.openaiWSPoolOnce.Do(func() { + if s.openaiWSPool == nil { + s.openaiWSPool = newOpenAIWSConnPool(s.cfg) + } + }) + return s.openaiWSPool +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPoolMetrics() OpenAIWSPoolMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + if pool == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return pool.SnapshotMetrics() +} + +type OpenAIWSPerformanceMetricsSnapshot struct { + Pool OpenAIWSPoolMetricsSnapshot `json:"pool"` + Retry OpenAIWSRetryMetricsSnapshot `json:"retry"` + Transport OpenAIWSTransportMetricsSnapshot `json:"transport"` +} + +func (s *OpenAIGatewayService) SnapshotOpenAIWSPerformanceMetrics() OpenAIWSPerformanceMetricsSnapshot { + pool := s.getOpenAIWSConnPool() + snapshot := OpenAIWSPerformanceMetricsSnapshot{ + Retry: s.SnapshotOpenAIWSRetryMetrics(), + } + if pool == nil { + return snapshot + } + snapshot.Pool = 
pool.SnapshotMetrics() + snapshot.Transport = pool.SnapshotTransportMetrics() + return snapshot +} + +func (s *OpenAIGatewayService) getOpenAIWSStateStore() OpenAIWSStateStore { + if s == nil { + return nil + } + s.openaiWSStateStoreOnce.Do(func() { + if s.openaiWSStateStore == nil { + s.openaiWSStateStore = NewOpenAIWSStateStore(s.cache) + } + }) + return s.openaiWSStateStore +} + +func (s *OpenAIGatewayService) openAIWSResponseStickyTTL() time.Duration { + if s != nil && s.cfg != nil { + seconds := s.cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds + if seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + return time.Hour +} + +func (s *OpenAIGatewayService) openAIWSIngressPreviousResponseRecoveryEnabled() bool { + if s != nil && s.cfg != nil { + return s.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled + } + return true +} + +func (s *OpenAIGatewayService) openAIWSReadTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.ReadTimeoutSeconds) * time.Second + } + return 15 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSWriteTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.WriteTimeoutSeconds) * time.Second + } + return 2 * time.Minute +} + +func (s *OpenAIGatewayService) openAIWSEventFlushBatchSize() int { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushBatchSize > 0 { + return s.cfg.Gateway.OpenAIWS.EventFlushBatchSize + } + return openAIWSEventFlushBatchSizeDefault +} + +func (s *OpenAIGatewayService) openAIWSEventFlushInterval() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS >= 0 { + if s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS == 0 { + return 0 + } + return time.Duration(s.cfg.Gateway.OpenAIWS.EventFlushIntervalMS) * time.Millisecond + } + return 
openAIWSEventFlushIntervalDefault +} + +func (s *OpenAIGatewayService) openAIWSPayloadLogSampleRate() float64 { + if s != nil && s.cfg != nil { + rate := s.cfg.Gateway.OpenAIWS.PayloadLogSampleRate + if rate < 0 { + return 0 + } + if rate > 1 { + return 1 + } + return rate + } + return openAIWSPayloadLogSampleDefault +} + +func (s *OpenAIGatewayService) shouldLogOpenAIWSPayloadSchema(attempt int) bool { + // 首次尝试保留一条完整 payload_schema 便于排障。 + if attempt <= 1 { + return true + } + rate := s.openAIWSPayloadLogSampleRate() + if rate <= 0 { + return false + } + if rate >= 1 { + return true + } + return rand.Float64() < rate +} + +func (s *OpenAIGatewayService) shouldEmitOpenAIWSPayloadSchema(attempt int) bool { + if !s.shouldLogOpenAIWSPayloadSchema(attempt) { + return false + } + return logger.L().Core().Enabled(zap.DebugLevel) +} + +func (s *OpenAIGatewayService) openAIWSDialTimeout() time.Duration { + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(s.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *OpenAIGatewayService) openAIWSAcquireTimeout() time.Duration { + // Acquire 覆盖“连接复用命中/排队/新建连接”三个阶段。 + // 这里不再叠加 write_timeout,避免高并发排队时把 TTFT 长尾拉到分钟级。 + dial := s.openAIWSDialTimeout() + if dial <= 0 { + dial = 10 * time.Second + } + return dial + 2*time.Second +} + +func (s *OpenAIGatewayService) buildOpenAIResponsesWSURL(account *Account) (string, error) { + if account == nil { + return "", errors.New("account is nil") + } + var targetURL string + switch account.Type { + case AccountTypeOAuth: + targetURL = chatgptCodexURL + case AccountTypeAPIKey: + baseURL := account.GetOpenAIBaseURL() + if baseURL == "" { + targetURL = openaiPlatformAPIURL + } else { + validatedURL, err := s.validateUpstreamBaseURL(baseURL) + if err != nil { + return "", err + } + targetURL = buildOpenAIResponsesURL(validatedURL) + } + default: + targetURL = openaiPlatformAPIURL + } + + parsed, err 
:= url.Parse(strings.TrimSpace(targetURL)) + if err != nil { + return "", fmt.Errorf("invalid target url: %w", err) + } + switch strings.ToLower(parsed.Scheme) { + case "https": + parsed.Scheme = "wss" + case "http": + parsed.Scheme = "ws" + case "wss", "ws": + // 保持不变 + default: + return "", fmt.Errorf("unsupported scheme for ws: %s", parsed.Scheme) + } + return parsed.String(), nil +} + +func (s *OpenAIGatewayService) buildOpenAIWSHeaders( + c *gin.Context, + account *Account, + token string, + decision OpenAIWSProtocolDecision, + isCodexCLI bool, + turnState string, + turnMetadata string, + promptCacheKey string, +) (http.Header, openAIWSSessionHeaderResolution) { + headers := make(http.Header) + headers.Set("authorization", "Bearer "+token) + + sessionResolution := resolveOpenAIWSSessionHeaders(c, promptCacheKey) + if c != nil && c.Request != nil { + if v := strings.TrimSpace(c.Request.Header.Get("accept-language")); v != "" { + headers.Set("accept-language", v) + } + } + if sessionResolution.SessionID != "" { + headers.Set("session_id", sessionResolution.SessionID) + } + if sessionResolution.ConversationID != "" { + headers.Set("conversation_id", sessionResolution.ConversationID) + } + if state := strings.TrimSpace(turnState); state != "" { + headers.Set(openAIWSTurnStateHeader, state) + } + if metadata := strings.TrimSpace(turnMetadata); metadata != "" { + headers.Set(openAIWSTurnMetadataHeader, metadata) + } + + if account != nil && account.Type == AccountTypeOAuth { + if chatgptAccountID := account.GetChatGPTAccountID(); chatgptAccountID != "" { + headers.Set("chatgpt-account-id", chatgptAccountID) + } + if isCodexCLI { + headers.Set("originator", "codex_cli_rs") + } else { + headers.Set("originator", "opencode") + } + } + + betaValue := openAIWSBetaV2Value + if decision.Transport == OpenAIUpstreamTransportResponsesWebsocket { + betaValue = openAIWSBetaV1Value + } + headers.Set("OpenAI-Beta", betaValue) + + customUA := "" + if account != nil { + customUA = 
account.GetOpenAIUserAgent() + } + if strings.TrimSpace(customUA) != "" { + headers.Set("user-agent", customUA) + } else if c != nil { + if ua := strings.TrimSpace(c.GetHeader("User-Agent")); ua != "" { + headers.Set("user-agent", ua) + } + } + if s != nil && s.cfg != nil && s.cfg.Gateway.ForceCodexCLI { + headers.Set("user-agent", codexCLIUserAgent) + } + if account != nil && account.Type == AccountTypeOAuth && !openai.IsCodexCLIRequest(headers.Get("user-agent")) { + headers.Set("user-agent", codexCLIUserAgent) + } + + return headers, sessionResolution +} + +func (s *OpenAIGatewayService) buildOpenAIWSCreatePayload(reqBody map[string]any, account *Account) map[string]any { + // OpenAI WS Mode 协议:response.create 字段与 HTTP /responses 基本一致。 + // 保留 stream 字段(与 Codex CLI 一致),仅移除 background。 + payload := make(map[string]any, len(reqBody)+1) + for k, v := range reqBody { + payload[k] = v + } + + delete(payload, "background") + if _, exists := payload["stream"]; !exists { + payload["stream"] = true + } + payload["type"] = "response.create" + + // OAuth 默认保持 store=false,避免误依赖服务端历史。 + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + payload["store"] = false + } + return payload +} + +func setOpenAIWSTurnMetadata(payload map[string]any, turnMetadata string) { + if len(payload) == 0 { + return + } + metadata := strings.TrimSpace(turnMetadata) + if metadata == "" { + return + } + + switch existing := payload["client_metadata"].(type) { + case map[string]any: + existing[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = existing + case map[string]string: + next := make(map[string]any, len(existing)+1) + for k, v := range existing { + next[k] = v + } + next[openAIWSTurnMetadataHeader] = metadata + payload["client_metadata"] = next + default: + payload["client_metadata"] = map[string]any{ + openAIWSTurnMetadataHeader: metadata, + } + } +} + +func (s *OpenAIGatewayService) 
isOpenAIWSStoreRecoveryAllowed(account *Account) bool { + if account != nil && account.IsOpenAIWSAllowStoreRecoveryEnabled() { + return true + } + if s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.AllowStoreRecovery { + return true + } + return false +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequest(reqBody map[string]any, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + rawStore, ok := reqBody["store"] + if !ok { + return false + } + storeEnabled, ok := rawStore.(bool) + if !ok { + return false + } + return !storeEnabled +} + +func (s *OpenAIGatewayService) isOpenAIWSStoreDisabledInRequestRaw(reqBody []byte, account *Account) bool { + if account != nil && account.Type == AccountTypeOAuth && !s.isOpenAIWSStoreRecoveryAllowed(account) { + return true + } + if len(reqBody) == 0 { + return false + } + storeValue := gjson.GetBytes(reqBody, "store") + if !storeValue.Exists() { + return false + } + if storeValue.Type != gjson.True && storeValue.Type != gjson.False { + return false + } + return !storeValue.Bool() +} + +func (s *OpenAIGatewayService) openAIWSStoreDisabledConnMode() string { + if s == nil || s.cfg == nil { + return openAIWSStoreDisabledConnModeStrict + } + mode := strings.ToLower(strings.TrimSpace(s.cfg.Gateway.OpenAIWS.StoreDisabledConnMode)) + switch mode { + case openAIWSStoreDisabledConnModeStrict, openAIWSStoreDisabledConnModeAdaptive, openAIWSStoreDisabledConnModeOff: + return mode + case "": + // 兼容旧配置:仅配置了布尔开关时按旧语义推导。 + if s.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn { + return openAIWSStoreDisabledConnModeStrict + } + return openAIWSStoreDisabledConnModeOff + default: + return openAIWSStoreDisabledConnModeStrict + } +} + +func shouldForceNewConnOnStoreDisabled(mode, lastFailureReason string) bool { + switch mode { + case openAIWSStoreDisabledConnModeOff: + return false + case 
openAIWSStoreDisabledConnModeAdaptive:
		reason := strings.TrimPrefix(strings.TrimSpace(lastFailureReason), "prewarm_")
		switch reason {
		case "policy_violation", "message_too_big", "auth_failed", "write_request", "write":
			return true
		default:
			return false
		}
	default:
		return true
	}
}

// dropPreviousResponseIDFromRawPayload removes previous_response_id from a
// raw JSON payload using the default sjson delete function.
func dropPreviousResponseIDFromRawPayload(payload []byte) ([]byte, bool, error) {
	return dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, sjson.DeleteBytes)
}

// dropPreviousResponseIDFromRawPayloadWithDeleteFn deletes every occurrence
// of previous_response_id (bounded passes guard against duplicate keys). It
// returns the updated payload and whether the key is fully gone.
func dropPreviousResponseIDFromRawPayloadWithDeleteFn(
	payload []byte,
	deleteFn func([]byte, string) ([]byte, error),
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	if !gjson.GetBytes(payload, "previous_response_id").Exists() {
		return payload, false, nil
	}
	if deleteFn == nil {
		deleteFn = sjson.DeleteBytes
	}

	updated := payload
	for pass := 0; pass < openAIWSMaxPrevResponseIDDeletePasses &&
		gjson.GetBytes(updated, "previous_response_id").Exists(); pass++ {
		next, err := deleteFn(updated, "previous_response_id")
		if err != nil {
			return payload, false, err
		}
		updated = next
	}
	return updated, !gjson.GetBytes(updated, "previous_response_id").Exists(), nil
}

// setPreviousResponseIDToRawPayload writes previous_response_id into a raw
// JSON payload, falling back to decode/re-encode when sjson fails.
func setPreviousResponseIDToRawPayload(payload []byte, previousResponseID string) ([]byte, error) {
	normalizedPrevID := strings.TrimSpace(previousResponseID)
	if len(payload) == 0 || normalizedPrevID == "" {
		return payload, nil
	}
	updated, err := sjson.SetBytes(payload, "previous_response_id", normalizedPrevID)
	if err == nil {
		return updated, nil
	}

	var reqBody map[string]any
	if unmarshalErr := json.Unmarshal(payload, &reqBody); unmarshalErr != nil {
		// Report the original sjson failure; the fallback could not decode.
		return nil, err
	}
	reqBody["previous_response_id"] = normalizedPrevID
	rebuilt, marshalErr := json.Marshal(reqBody)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return rebuilt, nil
}

// shouldInferIngressFunctionCallOutputPreviousResponseID reports whether a
// missing previous_response_id should be inferred for a function-call-output
// turn while store is disabled.
func shouldInferIngressFunctionCallOutputPreviousResponseID(
	storeDisabled bool,
	turn int,
	hasFunctionCallOutput bool,
	currentPreviousResponseID string,
	expectedPreviousResponseID string,
) bool {
	if !storeDisabled || turn <= 1 || !hasFunctionCallOutput {
		return false
	}
	if strings.TrimSpace(currentPreviousResponseID) != "" {
		return false
	}
	return strings.TrimSpace(expectedPreviousResponseID) != ""
}

// alignStoreDisabledPreviousResponseID rewrites previous_response_id to the
// expected value when the payload carries a different, non-empty one.
func alignStoreDisabledPreviousResponseID(
	payload []byte,
	expectedPreviousResponseID string,
) ([]byte, bool, error) {
	if len(payload) == 0 {
		return payload, false, nil
	}
	expected := strings.TrimSpace(expectedPreviousResponseID)
	if expected == "" {
		return payload, false, nil
	}
	current := openAIWSPayloadStringFromRaw(payload, "previous_response_id")
	if current == "" || current == expected {
		return payload, false, nil
	}

	withoutPrev, removed, dropErr := dropPreviousResponseIDFromRawPayload(payload)
	if dropErr != nil {
		return payload, false, dropErr
	}
	if !removed {
		return payload, false, nil
	}
	updated, setErr := setPreviousResponseIDToRawPayload(withoutPrev, expected)
	if setErr != nil {
		return payload, false, setErr
	}
	return updated, true, nil
}

// cloneOpenAIWSPayloadBytes returns an independent copy of payload (nil for
// empty input).
func cloneOpenAIWSPayloadBytes(payload []byte) []byte {
	if len(payload) == 0 {
		return nil
	}
	cloned := make([]byte, len(payload))
	copy(cloned, payload)
	return cloned
}

// cloneOpenAIWSRawMessages deep-copies a slice of raw JSON messages while
// preserving nil-ness.
func cloneOpenAIWSRawMessages(items []json.RawMessage) []json.RawMessage {
	if items == nil {
		return nil
	}
	cloned := make([]json.RawMessage, 0, len(items))
	for idx := range items {
		cloned = append(cloned, json.RawMessage(cloneOpenAIWSPayloadBytes(items[idx])))
	}
	return cloned
}

// normalizeOpenAIWSJSONForCompare re-encodes raw JSON into canonical form so
// two documents can be compared byte-wise.
func normalizeOpenAIWSJSONForCompare(raw []byte) ([]byte, error) {
	trimmed := bytes.TrimSpace(raw)
	if len(trimmed) == 0 {
		return nil, errors.New("json is empty")
	}
	var decoded any
	if err := json.Unmarshal(trimmed, &decoded); err != nil {
		return nil, err
	}
	return json.Marshal(decoded)
}

// normalizeOpenAIWSJSONForCompareOrRaw normalizes when possible, otherwise
// falls back to the trimmed input.
func normalizeOpenAIWSJSONForCompareOrRaw(raw []byte) []byte {
	normalized, err := normalizeOpenAIWSJSONForCompare(raw)
	if err != nil {
		return bytes.TrimSpace(raw)
	}
	return normalized
}

// normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID canonicalizes a
// payload with "input" and "previous_response_id" stripped, for comparing the
// non-input portion of consecutive turns.
func normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload []byte) ([]byte, error) {
	if len(payload) == 0 {
		return nil, errors.New("payload is empty")
	}
	var decoded map[string]any
	if err := json.Unmarshal(payload, &decoded); err != nil {
		return nil, err
	}
	delete(decoded, "input")
	delete(decoded, "previous_response_id")
	return json.Marshal(decoded)
}

// openAIWSExtractNormalizedInputSequence extracts "input" as a slice of raw
// items; the bool reports whether the key exists at all (arrays are split,
// strings are JSON-encoded, and scalars/objects become single items).
func openAIWSExtractNormalizedInputSequence(payload []byte) ([]json.RawMessage, bool, error) {
	if len(payload) == 0 {
		return nil, false, nil
	}
	inputValue := gjson.GetBytes(payload, "input")
	if !inputValue.Exists() {
		return nil, false, nil
	}
	if inputValue.Type == gjson.JSON {
		raw := strings.TrimSpace(inputValue.Raw)
		if strings.HasPrefix(raw, "[") {
			var items []json.RawMessage
			if err := json.Unmarshal([]byte(raw), &items); err != nil {
				return nil, true, err
			}
			return items, true, nil
		}
		return []json.RawMessage{json.RawMessage(raw)}, true, nil
	}
	if inputValue.Type == gjson.String {
		encoded, _ := json.Marshal(inputValue.String())
		return []json.RawMessage{encoded}, true, nil
	}
	return []json.RawMessage{json.RawMessage(inputValue.Raw)}, true, nil
}

// openAIWSInputIsPrefixExtended reports whether currentPayload's input starts
// with previousPayload's input, i.e. the turn only appended items.
func openAIWSInputIsPrefixExtended(previousPayload, currentPayload []byte) (bool, error) {
	previousItems, previousExists, prevErr := openAIWSExtractNormalizedInputSequence(previousPayload)
	if prevErr != nil {
		return false, prevErr
	}
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return false, currentErr
	}
	if !previousExists && !currentExists {
		return true, nil
	}
	if !previousExists {
		return len(currentItems) == 0, nil
	}
	if !currentExists {
		return len(previousItems) == 0, nil
	}
	if len(currentItems) < len(previousItems) {
		return false, nil
	}

	for idx :=
range previousItems {
		want := normalizeOpenAIWSJSONForCompareOrRaw(previousItems[idx])
		got := normalizeOpenAIWSJSONForCompareOrRaw(currentItems[idx])
		if !bytes.Equal(want, got) {
			return false, nil
		}
	}
	return true, nil
}

// openAIWSRawItemsHasPrefix reports whether items begins with prefix, using
// normalized JSON comparison per element.
func openAIWSRawItemsHasPrefix(items []json.RawMessage, prefix []json.RawMessage) bool {
	if len(prefix) == 0 {
		return true
	}
	if len(items) < len(prefix) {
		return false
	}
	for idx := range prefix {
		want := normalizeOpenAIWSJSONForCompareOrRaw(prefix[idx])
		got := normalizeOpenAIWSJSONForCompareOrRaw(items[idx])
		if !bytes.Equal(want, got) {
			return false
		}
	}
	return true
}

// buildOpenAIWSReplayInputSequence reconstructs the full input sequence used
// to replay a turn: when the current payload chains on previous_response_id
// and does not already extend the previous full input, the previous input is
// prepended.
func buildOpenAIWSReplayInputSequence(
	previousFullInput []json.RawMessage,
	previousFullInputExists bool,
	currentPayload []byte,
	hasPreviousResponseID bool,
) ([]json.RawMessage, bool, error) {
	currentItems, currentExists, currentErr := openAIWSExtractNormalizedInputSequence(currentPayload)
	if currentErr != nil {
		return nil, false, currentErr
	}
	// No chaining or no recorded history: the current input stands alone.
	if !hasPreviousResponseID || !previousFullInputExists {
		return cloneOpenAIWSRawMessages(currentItems), currentExists, nil
	}
	if !currentExists || len(currentItems) == 0 {
		return cloneOpenAIWSRawMessages(previousFullInput), true, nil
	}
	if openAIWSRawItemsHasPrefix(currentItems, previousFullInput) {
		return cloneOpenAIWSRawMessages(currentItems), true, nil
	}
	merged := make([]json.RawMessage, 0, len(previousFullInput)+len(currentItems))
	merged = append(merged, cloneOpenAIWSRawMessages(previousFullInput)...)
	merged = append(merged, cloneOpenAIWSRawMessages(currentItems)...)
	return merged, true, nil
}

// setOpenAIWSPayloadInputSequence writes the replay input back into payload;
// a nil-but-present input is encoded as [] to preserve [] vs null semantics.
func setOpenAIWSPayloadInputSequence(
	payload []byte,
	fullInput []json.RawMessage,
	fullInputExists bool,
) ([]byte, error) {
	if !fullInputExists {
		return payload, nil
	}
	inputForMarshal := fullInput
	if inputForMarshal == nil {
		inputForMarshal = []json.RawMessage{}
	}
	inputRaw, marshalErr := json.Marshal(inputForMarshal)
	if marshalErr != nil {
		return nil, marshalErr
	}
	return sjson.SetRawBytes(payload, "input", inputRaw)
}

// shouldKeepIngressPreviousResponseID decides whether the incoming turn may
// keep its previous_response_id, returning the decision reason for logging.
// The ID is kept only for function-call-output turns or when the chain is a
// strict incremental continuation of the previous turn.
func shouldKeepIngressPreviousResponseID(
	previousPayload []byte,
	currentPayload []byte,
	lastTurnResponseID string,
	hasFunctionCallOutput bool,
) (bool, string, error) {
	if hasFunctionCallOutput {
		return true, "has_function_call_output", nil
	}
	currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id"))
	if currentPreviousResponseID == "" {
		return false, "missing_previous_response_id", nil
	}
	expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID)
	if expectedPreviousResponseID == "" {
		return false, "missing_last_turn_response_id", nil
	}
	if currentPreviousResponseID != expectedPreviousResponseID {
		return false, "previous_response_id_mismatch", nil
	}
	if len(previousPayload) == 0 {
		return false, "missing_previous_turn_payload", nil
	}

	previousComparable, previousComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(previousPayload)
	if previousComparableErr != nil {
		return false, "non_input_compare_error", previousComparableErr
	}
	currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload)
	if currentComparableErr != nil {
		return false, "non_input_compare_error", currentComparableErr
	}
	if !bytes.Equal(previousComparable, currentComparable) {
		return false, "non_input_changed", nil
	}
	return true,
"strict_incremental_ok", nil +} + +type openAIWSIngressPreviousTurnStrictState struct { + nonInputComparable []byte +} + +func buildOpenAIWSIngressPreviousTurnStrictState(payload []byte) (*openAIWSIngressPreviousTurnStrictState, error) { + if len(payload) == 0 { + return nil, nil + } + nonInputComparable, nonInputErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(payload) + if nonInputErr != nil { + return nil, nonInputErr + } + return &openAIWSIngressPreviousTurnStrictState{ + nonInputComparable: nonInputComparable, + }, nil +} + +func shouldKeepIngressPreviousResponseIDWithStrictState( + previousState *openAIWSIngressPreviousTurnStrictState, + currentPayload []byte, + lastTurnResponseID string, + hasFunctionCallOutput bool, +) (bool, string, error) { + if hasFunctionCallOutput { + return true, "has_function_call_output", nil + } + currentPreviousResponseID := strings.TrimSpace(openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id")) + if currentPreviousResponseID == "" { + return false, "missing_previous_response_id", nil + } + expectedPreviousResponseID := strings.TrimSpace(lastTurnResponseID) + if expectedPreviousResponseID == "" { + return false, "missing_last_turn_response_id", nil + } + if currentPreviousResponseID != expectedPreviousResponseID { + return false, "previous_response_id_mismatch", nil + } + if previousState == nil { + return false, "missing_previous_turn_payload", nil + } + + currentComparable, currentComparableErr := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(currentPayload) + if currentComparableErr != nil { + return false, "non_input_compare_error", currentComparableErr + } + if !bytes.Equal(previousState.nonInputComparable, currentComparable) { + return false, "non_input_changed", nil + } + return true, "strict_incremental_ok", nil +} + +func (s *OpenAIGatewayService) forwardOpenAIWSV2( + ctx context.Context, + c *gin.Context, + account *Account, + reqBody map[string]any, + token string, + decision 
OpenAIWSProtocolDecision, + isCodexCLI bool, + reqStream bool, + originalModel string, + mappedModel string, + startTime time.Time, + attempt int, + lastFailureReason string, +) (*OpenAIForwardResult, error) { + if s == nil || account == nil { + return nil, wrapOpenAIWSFallback("invalid_state", errors.New("service or account is nil")) + } + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return nil, wrapOpenAIWSFallback("build_ws_url", err) + } + wsHost := "-" + wsPath := "-" + if parsed, parseErr := url.Parse(wsURL); parseErr == nil && parsed != nil { + if h := strings.TrimSpace(parsed.Host); h != "" { + wsHost = normalizeOpenAIWSLogValue(h) + } + if p := strings.TrimSpace(parsed.Path); p != "" { + wsPath = normalizeOpenAIWSLogValue(p) + } + } + logOpenAIWSModeDebug( + "dial_target account_id=%d account_type=%s ws_host=%s ws_path=%s", + account.ID, + account.Type, + wsHost, + wsPath, + ) + + payload := s.buildOpenAIWSCreatePayload(reqBody, account) + payloadStrategy, removedKeys := applyOpenAIWSRetryPayloadStrategy(payload, attempt) + previousResponseID := openAIWSPayloadString(payload, "previous_response_id") + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + promptCacheKey := openAIWSPayloadString(payload, "prompt_cache_key") + _, hasTools := payload["tools"] + debugEnabled := isOpenAIWSModeDebugEnabled() + payloadBytes := -1 + resolvePayloadBytes := func() int { + if payloadBytes >= 0 { + return payloadBytes + } + payloadBytes = len(payloadAsJSONBytes(payload)) + return payloadBytes + } + streamValue := "-" + if raw, ok := payload["stream"]; ok { + streamValue = normalizeOpenAIWSLogValue(strings.TrimSpace(fmt.Sprintf("%v", raw))) + } + turnState := "" + turnMetadata := "" + if c != nil && c.Request != nil { + turnState = strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + turnMetadata = strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)) + } + setOpenAIWSTurnMetadata(payload, 
turnMetadata) + payloadEventType := openAIWSPayloadString(payload, "type") + if payloadEventType == "" { + payloadEventType = "response.create" + } + if s.shouldEmitOpenAIWSPayloadSchema(attempt) { + logOpenAIWSModeInfo( + "[debug] payload_schema account_id=%d attempt=%d event=%s payload_keys=%s payload_bytes=%d payload_key_sizes=%s input_summary=%s stream=%s payload_strategy=%s removed_keys=%s has_previous_response_id=%v has_prompt_cache_key=%v has_tools=%v", + account.ID, + attempt, + payloadEventType, + normalizeOpenAIWSLogValue(strings.Join(sortedKeys(payload), ",")), + resolvePayloadBytes(), + normalizeOpenAIWSLogValue(summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN)), + normalizeOpenAIWSLogValue(summarizeOpenAIWSInput(payload["input"])), + streamValue, + normalizeOpenAIWSLogValue(payloadStrategy), + normalizeOpenAIWSLogValue(strings.Join(removedKeys, ",")), + previousResponseID != "", + promptCacheKey != "", + hasTools, + ) + } + + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, nil) + if sessionHash == "" { + var legacySessionHash string + sessionHash, legacySessionHash = openAIWSSessionHashesFromID(promptCacheKey) + attachOpenAILegacySessionHashToGin(c, legacySessionHash) + } + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + turnState = savedTurnState + } + } + preferredConnID := "" + if stateStore != nil && previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(previousResponseID); ok { + preferredConnID = connID + } + } + storeDisabled := s.isOpenAIWSStoreDisabledInRequest(reqBody, account) + if stateStore != nil && storeDisabled && previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + storeDisabledConnMode := 
s.openAIWSStoreDisabledConnMode() + forceNewConnByPolicy := shouldForceNewConnOnStoreDisabled(storeDisabledConnMode, lastFailureReason) + forceNewConn := forceNewConnByPolicy && storeDisabled && previousResponseID == "" && sessionHash != "" && preferredConnID == "" + wsHeaders, sessionResolution := s.buildOpenAIWSHeaders(c, account, token, decision, isCodexCLI, turnState, turnMetadata, promptCacheKey) + logOpenAIWSModeDebug( + "acquire_start account_id=%d account_type=%s transport=%s preferred_conn_id=%s has_previous_response_id=%v session_hash=%s has_turn_state=%v turn_state_len=%d has_turn_metadata=%v turn_metadata_len=%d store_disabled=%v store_disabled_conn_mode=%s retry_last_reason=%s force_new_conn=%v header_user_agent=%s header_openai_beta=%s header_originator=%s header_accept_language=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_prompt_cache_key=%v has_chatgpt_account_id=%v has_authorization=%v has_session_id=%v has_conversation_id=%v proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + previousResponseID != "", + truncateOpenAIWSLogValue(sessionHash, 12), + turnState != "", + len(turnState), + turnMetadata != "", + len(turnMetadata), + storeDisabled, + normalizeOpenAIWSLogValue(storeDisabledConnMode), + truncateOpenAIWSLogValue(lastFailureReason, openAIWSLogValueMaxLen), + forceNewConn, + openAIWSHeaderValueForLog(wsHeaders, "user-agent"), + openAIWSHeaderValueForLog(wsHeaders, "openai-beta"), + openAIWSHeaderValueForLog(wsHeaders, "originator"), + openAIWSHeaderValueForLog(wsHeaders, "accept-language"), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + promptCacheKey != "", + 
hasOpenAIWSHeader(wsHeaders, "chatgpt-account-id"), + hasOpenAIWSHeader(wsHeaders, "authorization"), + hasOpenAIWSHeader(wsHeaders, "session_id"), + hasOpenAIWSHeader(wsHeaders, "conversation_id"), + account.ProxyID != nil && account.Proxy != nil, + ) + + acquireCtx, acquireCancel := context.WithTimeout(ctx, s.openAIWSAcquireTimeout()) + defer acquireCancel() + + lease, err := s.getOpenAIWSConnPool().Acquire(acquireCtx, openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + PreferredConnID: preferredConnID, + ForceNewConn: forceNewConn, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + }) + if err != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(err) + logOpenAIWSModeInfo( + "acquire_fail account_id=%d account_type=%s transport=%s reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_new_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(err)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + forceNewConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err) + } + defer lease.Release() + connID := strings.TrimSpace(lease.ConnID()) + logOpenAIWSModeDebug( + "connected account_id=%d account_type=%s 
transport=%s conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(decision.Transport)), + connID, + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + previousResponseID != "", + ) + if previousResponseID != "" { + logOpenAIWSModeInfo( + "continuation_probe account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s conn_reused=%v store_disabled=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v", + account.ID, + account.Type, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + lease.Reused(), + storeDisabled, + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + ) + } + if c != nil { + SetOpsLatencyMs(c, OpsOpenAIWSConnPickMsKey, lease.ConnPickDuration().Milliseconds()) + SetOpsLatencyMs(c, OpsOpenAIWSQueueWaitMsKey, lease.QueueWaitDuration().Milliseconds()) + c.Set(OpsOpenAIWSConnReusedKey, lease.Reused()) + if connID != "" { + c.Set(OpsOpenAIWSConnIDKey, connID) + } + } + + handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)) + logOpenAIWSModeDebug( + "handshake account_id=%d conn_id=%s has_turn_state=%v turn_state_len=%d", + account.ID, + connID, + handshakeTurnState != "", + len(handshakeTurnState), + ) + if 
handshakeTurnState != "" { + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + if c != nil { + c.Header(http.CanonicalHeaderKey(openAIWSTurnStateHeader), handshakeTurnState) + } + } + + if err := s.performOpenAIWSGeneratePrewarm( + ctx, + lease, + decision, + payload, + previousResponseID, + reqBody, + account, + stateStore, + groupID, + ); err != nil { + return nil, err + } + + if err := lease.WriteJSONWithContextTimeout(ctx, payload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + "write_request_fail account_id=%d conn_id=%s cause=%s payload_bytes=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + resolvePayloadBytes(), + ) + return nil, wrapOpenAIWSFallback("write_request", err) + } + if debugEnabled { + logOpenAIWSModeDebug( + "write_request_sent account_id=%d conn_id=%s stream=%v payload_bytes=%d previous_response_id=%s", + account.ID, + connID, + reqStream, + resolvePayloadBytes(), + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + } + + usage := &OpenAIUsage{} + var firstTokenMs *int + responseID := "" + var finalResponse []byte + wroteDownstream := false + needModelReplace := originalModel != mappedModel + var mappedModelBytes []byte + if needModelReplace && mappedModel != "" { + mappedModelBytes = []byte(mappedModel) + } + bufferedStreamEvents := make([][]byte, 0, 4) + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + bufferedEventCount := 0 + flushedBufferedEventCount := 0 + firstEventType := "" + lastEventType := "" + + var flusher http.Flusher + if reqStream { + if s.responseHeaderFilter != nil { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), http.Header{}, s.responseHeaderFilter) + } + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + 
c.Header("X-Accel-Buffering", "no") + f, ok := c.Writer.(http.Flusher) + if !ok { + lease.MarkBroken() + return nil, wrapOpenAIWSFallback("streaming_not_supported", errors.New("streaming not supported")) + } + flusher = f + } + + clientDisconnected := false + flushBatchSize := s.openAIWSEventFlushBatchSize() + flushInterval := s.openAIWSEventFlushInterval() + pendingFlushEvents := 0 + lastFlushAt := time.Now() + flushStreamWriter := func(force bool) { + if clientDisconnected || flusher == nil || pendingFlushEvents <= 0 { + return + } + if !force && flushBatchSize > 1 && pendingFlushEvents < flushBatchSize { + if flushInterval <= 0 || time.Since(lastFlushAt) < flushInterval { + return + } + } + flusher.Flush() + pendingFlushEvents = 0 + lastFlushAt = time.Now() + } + emitStreamMessage := func(message []byte, forceFlush bool) { + if clientDisconnected { + return + } + frame := make([]byte, 0, len(message)+8) + frame = append(frame, "data: "...) + frame = append(frame, message...) + frame = append(frame, '\n', '\n') + _, wErr := c.Writer.Write(frame) + if wErr == nil { + wroteDownstream = true + pendingFlushEvents++ + flushStreamWriter(forceFlush) + return + } + clientDisconnected = true + logger.LegacyPrintf("service.openai_gateway", "[OpenAI WS Mode] client disconnected, continue draining upstream: account=%d", account.ID) + } + flushBufferedStreamEvents := func(reason string) { + if len(bufferedStreamEvents) == 0 { + return + } + flushed := len(bufferedStreamEvents) + for _, buffered := range bufferedStreamEvents { + emitStreamMessage(buffered, false) + } + bufferedStreamEvents = bufferedStreamEvents[:0] + flushStreamWriter(true) + flushedBufferedEventCount += flushed + if debugEnabled { + logOpenAIWSModeDebug( + "buffer_flush account_id=%d conn_id=%s reason=%s flushed=%d total_flushed=%d client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(reason, openAIWSLogValueMaxLen), + flushed, + flushedBufferedEventCount, + clientDisconnected, + ) + } 
+ } + + readTimeout := s.openAIWSReadTimeout() + + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, readTimeout) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "read_fail account_id=%d conn_id=%s wrote_downstream=%v close_status=%s close_reason=%s cause=%s events=%d token_events=%d terminal_events=%d buffered_pending=%d buffered_flushed=%d first_event=%s last_event=%s", + account.ID, + connID, + wroteDownstream, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + eventCount, + tokenEventCount, + terminalEventCount, + len(bufferedStreamEvents), + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback(classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + if clientDisconnected { + break + } + setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(readErr.Error()), "") + return nil, fmt.Errorf("openai ws read event: %w", readErr) + } + + eventType, eventResponseID, responseField := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + if debugEnabled && shouldLogOpenAIWSEvent(eventCount, eventType) { + logOpenAIWSModeDebug( + "event_received account_id=%d conn_id=%s idx=%d type=%s bytes=%d token=%v terminal=%v 
buffered_pending=%d", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + isTokenEvent, + isTerminalEvent, + len(bufferedStreamEvents), + ) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(message, mappedModelBytes) { + message = replaceOpenAIWSMessageModel(message, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(message) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(message); changed { + message = corrected + } + } + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(message, usage) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "Upstream websocket error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + logOpenAIWSModeInfo( + "error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + if fallbackReason == "previous_response_not_found" { + logOpenAIWSModeInfo( + "previous_response_not_found_diag account_id=%d account_type=%s conn_id=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s event_idx=%d req_stream=%v store_disabled=%v conn_reused=%v session_hash=%s header_session_id=%s header_conversation_id=%s session_id_source=%s conversation_id_source=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v err_code=%s 
err_type=%s err_message=%s", + account.ID, + account.Type, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(previousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + eventCount, + reqStream, + storeDisabled, + lease.Reused(), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(wsHeaders, "session_id"), + openAIWSHeaderValueForLog(wsHeaders, "conversation_id"), + normalizeOpenAIWSLogValue(sessionResolution.SessionSource), + normalizeOpenAIWSLogValue(sessionResolution.ConversationSource), + turnState != "", + len(turnState), + promptCacheKey != "", + errCode, + errType, + errMessage, + ) + } + // error 事件后连接不再可复用,避免回池后污染下一请求。 + lease.MarkBroken() + if !wroteDownstream && canFallback { + return nil, wrapOpenAIWSFallback(fallbackReason, errors.New(errMsg)) + } + statusCode := openAIWSErrorHTTPStatusFromRaw(errCodeRaw, errTypeRaw) + setOpsUpstreamError(c, statusCode, errMsg, "") + if reqStream && !clientDisconnected { + flushBufferedStreamEvents("error_event") + emitStreamMessage(message, true) + } + if !reqStream { + c.JSON(statusCode, gin.H{ + "error": gin.H{ + "type": "upstream_error", + "message": errMsg, + }, + }) + } + return nil, fmt.Errorf("openai ws error event: %s", errMsg) + } + + if reqStream { + // 在首个 token 前先缓冲事件(如 response.created), + // 以便上游早期断连时仍可安全回退到 HTTP,不给下游发送半截流。 + shouldBuffer := firstTokenMs == nil && !isTokenEvent && !isTerminalEvent + if shouldBuffer { + buffered := make([]byte, len(message)) + copy(buffered, message) + bufferedStreamEvents = append(bufferedStreamEvents, buffered) + bufferedEventCount++ + if debugEnabled && shouldLogOpenAIWSBufferedEvent(bufferedEventCount) { + logOpenAIWSModeDebug( + "buffer_enqueue account_id=%d conn_id=%s idx=%d event_idx=%d event_type=%s buffer_size=%d", + account.ID, + connID, + bufferedEventCount, + eventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + 
len(bufferedStreamEvents), + ) + } + } else { + flushBufferedStreamEvents(eventType) + emitStreamMessage(message, isTerminalEvent) + } + } else { + if responseField.Exists() && responseField.Type == gjson.JSON { + finalResponse = []byte(responseField.Raw) + } + } + + if isTerminalEvent { + break + } + } + + if !reqStream { + if len(finalResponse) == 0 { + logOpenAIWSModeInfo( + "missing_final_response account_id=%d conn_id=%s events=%d token_events=%d terminal_events=%d wrote_downstream=%v", + account.ID, + connID, + eventCount, + tokenEventCount, + terminalEventCount, + wroteDownstream, + ) + if !wroteDownstream { + return nil, wrapOpenAIWSFallback("missing_final_response", errors.New("no terminal response payload")) + } + return nil, errors.New("ws finished without final response") + } + + if needModelReplace { + finalResponse = s.replaceModelInResponseBody(finalResponse, mappedModel, originalModel) + } + finalResponse = s.correctToolCallsInResponseBody(finalResponse) + populateOpenAIUsageFromResponseJSON(finalResponse, usage) + if responseID == "" { + responseID = strings.TrimSpace(gjson.GetBytes(finalResponse, "id").String()) + } + + c.Data(http.StatusOK, "application/json", finalResponse) + } else { + flushStreamWriter(true) + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, lease.ConnID(), ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, lease.ConnID(), s.openAIWSSessionStickyTTL()) + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + logOpenAIWSModeDebug( + "completed account_id=%d conn_id=%s response_id=%s stream=%v duration_ms=%d events=%d token_events=%d terminal_events=%d buffered_events=%d buffered_flushed=%d 
first_event=%s last_event=%s first_token_ms=%d wrote_downstream=%v client_disconnected=%v", + account.ID, + connID, + truncateOpenAIWSLogValue(strings.TrimSpace(responseID), openAIWSIDValueMaxLen), + reqStream, + time.Since(startTime).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + bufferedEventCount, + flushedBufferedEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + wroteDownstream, + clientDisconnected, + ) + + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: *usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffort(reqBody, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + }, nil +} + +// ProxyResponsesWebSocketFromClient 处理客户端入站 WebSocket(OpenAI Responses WS Mode)并转发到上游。 +// 当前实现按“单请求 -> 终止事件 -> 下一请求”的顺序代理,适配 Codex CLI 的 turn 模式。 +func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient( + ctx context.Context, + c *gin.Context, + clientConn *coderws.Conn, + account *Account, + token string, + firstClientMessage []byte, + hooks *OpenAIWSIngressHooks, +) error { + if s == nil { + return errors.New("service is nil") + } + if c == nil { + return errors.New("gin context is nil") + } + if clientConn == nil { + return errors.New("client websocket is nil") + } + if account == nil { + return errors.New("account is nil") + } + if strings.TrimSpace(token) == "" { + return errors.New("token is empty") + } + + wsDecision := s.getOpenAIWSProtocolResolver().Resolve(account) + modeRouterV2Enabled := s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + ingressMode := OpenAIWSIngressModeShared + if modeRouterV2Enabled { + ingressMode = account.ResolveOpenAIResponsesWebSocketV2Mode(s.cfg.Gateway.OpenAIWS.IngressModeDefault) + if ingressMode == OpenAIWSIngressModeOff { + return 
NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "websocket mode is disabled for this account", + nil, + ) + } + } + if wsDecision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return fmt.Errorf("websocket ingress requires ws_v2 transport, got=%s", wsDecision.Transport) + } + dedicatedMode := modeRouterV2Enabled && ingressMode == OpenAIWSIngressModeDedicated + + wsURL, err := s.buildOpenAIResponsesWSURL(account) + if err != nil { + return fmt.Errorf("build ws url: %w", err) + } + wsHost := "-" + wsPath := "-" + if parsedURL, parseErr := url.Parse(wsURL); parseErr == nil && parsedURL != nil { + wsHost = normalizeOpenAIWSLogValue(parsedURL.Host) + wsPath = normalizeOpenAIWSLogValue(parsedURL.Path) + } + debugEnabled := isOpenAIWSModeDebugEnabled() + + type openAIWSClientPayload struct { + payloadRaw []byte + rawForHash []byte + promptCacheKey string + previousResponseID string + originalModel string + payloadBytes int + } + + applyPayloadMutation := func(current []byte, path string, value any) ([]byte, error) { + next, err := sjson.SetBytes(current, path, value) + if err == nil { + return next, nil + } + + // 仅在确实需要修改 payload 且 sjson 失败时,退回 map 路径确保兼容性。 + payload := make(map[string]any) + if unmarshalErr := json.Unmarshal(current, &payload); unmarshalErr != nil { + return nil, err + } + switch path { + case "type", "model": + payload[path] = value + case "client_metadata." 
+ openAIWSTurnMetadataHeader: + setOpenAIWSTurnMetadata(payload, fmt.Sprintf("%v", value)) + default: + return nil, err + } + rebuilt, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return nil, marshalErr + } + return rebuilt, nil + } + + parseClientPayload := func(raw []byte) (openAIWSClientPayload, error) { + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "empty websocket request payload", nil) + } + if !gjson.ValidBytes(trimmed) { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", errors.New("invalid json")) + } + + values := gjson.GetManyBytes(trimmed, "type", "model", "prompt_cache_key", "previous_response_id") + eventType := strings.TrimSpace(values[0].String()) + normalized := trimmed + switch eventType { + case "": + eventType = "response.create" + next, setErr := applyPayloadMutation(normalized, "type", eventType) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + case "response.create": + case "response.append": + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "response.append is not supported in ws v2; use response.create with previous_response_id", + nil, + ) + default: + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket request type: %s", eventType), + nil, + ) + } + + originalModel := strings.TrimSpace(values[1].String()) + if originalModel == "" { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "model is required in response.create payload", + nil, + ) + } + promptCacheKey := strings.TrimSpace(values[2].String()) + previousResponseID := 
strings.TrimSpace(values[3].String()) + previousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(previousResponseID) + if previousResponseID != "" && previousResponseIDKind == OpenAIPreviousResponseIDKindMessageID { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "previous_response_id must be a response.id (resp_*), not a message id", + nil, + ) + } + if turnMetadata := strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)); turnMetadata != "" { + next, setErr := applyPayloadMutation(normalized, "client_metadata."+openAIWSTurnMetadataHeader, turnMetadata) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + mappedModel := account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + if mappedModel != originalModel { + next, setErr := applyPayloadMutation(normalized, "model", mappedModel) + if setErr != nil { + return openAIWSClientPayload{}, NewOpenAIWSClientCloseError(coderws.StatusPolicyViolation, "invalid websocket request payload", setErr) + } + normalized = next + } + + return openAIWSClientPayload{ + payloadRaw: normalized, + rawForHash: trimmed, + promptCacheKey: promptCacheKey, + previousResponseID: previousResponseID, + originalModel: originalModel, + payloadBytes: len(normalized), + }, nil + } + + firstPayload, err := parseClientPayload(firstClientMessage) + if err != nil { + return err + } + + turnState := strings.TrimSpace(c.GetHeader(openAIWSTurnStateHeader)) + stateStore := s.getOpenAIWSStateStore() + groupID := getOpenAIGroupIDFromContext(c) + sessionHash := s.GenerateSessionHash(c, firstPayload.rawForHash) + if turnState == "" && stateStore != nil && sessionHash != "" { + if savedTurnState, ok := stateStore.GetSessionTurnState(groupID, sessionHash); ok { + 
turnState = savedTurnState + } + } + + preferredConnID := "" + if stateStore != nil && firstPayload.previousResponseID != "" { + if connID, ok := stateStore.GetResponseConn(firstPayload.previousResponseID); ok { + preferredConnID = connID + } + } + + storeDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(firstPayload.payloadRaw, account) + storeDisabledConnMode := s.openAIWSStoreDisabledConnMode() + if stateStore != nil && storeDisabled && firstPayload.previousResponseID == "" && sessionHash != "" { + if connID, ok := stateStore.GetSessionConn(groupID, sessionHash); ok { + preferredConnID = connID + } + } + + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + wsHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), firstPayload.promptCacheKey) + baseAcquireReq := openAIWSAcquireRequest{ + Account: account, + WSURL: wsURL, + Headers: wsHeaders, + ProxyURL: func() string { + if account.ProxyID != nil && account.Proxy != nil { + return account.Proxy.URL() + } + return "" + }(), + ForceNewConn: false, + } + pool := s.getOpenAIWSConnPool() + if pool == nil { + return errors.New("openai ws conn pool is nil") + } + + logOpenAIWSModeInfo( + "ingress_ws_protocol_confirm account_id=%d account_type=%s transport=%s ws_host=%s ws_path=%s ws_mode=%s store_disabled=%v has_session_hash=%v has_previous_response_id=%v", + account.ID, + account.Type, + normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + wsPath, + normalizeOpenAIWSLogValue(ingressMode), + storeDisabled, + sessionHash != "", + firstPayload.previousResponseID != "", + ) + + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_start account_id=%d account_type=%s transport=%s ws_host=%s preferred_conn_id=%s has_session_hash=%v has_previous_response_id=%v store_disabled=%v", + account.ID, + account.Type, + 
normalizeOpenAIWSLogValue(string(wsDecision.Transport)), + wsHost, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + sessionHash != "", + firstPayload.previousResponseID != "", + storeDisabled, + ) + } + if firstPayload.previousResponseID != "" { + firstPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(firstPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_continuation_probe account_id=%d turn=%d previous_response_id=%s previous_response_id_kind=%s preferred_conn_id=%s session_hash=%s header_session_id=%s header_conversation_id=%s has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + 1, + truncateOpenAIWSLogValue(firstPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(firstPreviousResponseIDKind), + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(sessionHash, 12), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + firstPayload.promptCacheKey != "", + storeDisabled, + ) + } + + acquireTimeout := s.openAIWSAcquireTimeout() + if acquireTimeout <= 0 { + acquireTimeout = 30 * time.Second + } + + acquireTurnLease := func(turn int, preferred string, forcePreferredConn bool) (*openAIWSConnLease, error) { + req := cloneOpenAIWSAcquireRequest(baseAcquireReq) + req.PreferredConnID = strings.TrimSpace(preferred) + req.ForcePreferredConn = forcePreferredConn + // dedicated 模式下每次获取均新建连接,避免跨会话复用残留上下文。 + req.ForceNewConn = dedicatedMode + acquireCtx, acquireCancel := context.WithTimeout(ctx, acquireTimeout) + lease, acquireErr := pool.Acquire(acquireCtx, req) + acquireCancel() + if acquireErr != nil { + dialStatus, dialClass, dialCloseStatus, dialCloseReason, dialRespServer, dialRespVia, dialRespCFRay, dialRespReqID := summarizeOpenAIWSDialError(acquireErr) + logOpenAIWSModeInfo( + 
"ingress_ws_upstream_acquire_fail account_id=%d turn=%d reason=%s dial_status=%d dial_class=%s dial_close_status=%s dial_close_reason=%s dial_resp_server=%s dial_resp_via=%s dial_resp_cf_ray=%s dial_resp_x_request_id=%s cause=%s preferred_conn_id=%s force_preferred_conn=%v ws_host=%s ws_path=%s proxy_enabled=%v", + account.ID, + turn, + normalizeOpenAIWSLogValue(classifyOpenAIWSAcquireError(acquireErr)), + dialStatus, + dialClass, + dialCloseStatus, + truncateOpenAIWSLogValue(dialCloseReason, openAIWSHeaderValueMaxLen), + dialRespServer, + dialRespVia, + dialRespCFRay, + dialRespReqID, + truncateOpenAIWSLogValue(acquireErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + forcePreferredConn, + wsHost, + wsPath, + account.ProxyID != nil && account.Proxy != nil, + ) + if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + acquireErr, + ) + } + if errors.Is(acquireErr, context.DeadlineExceeded) || errors.Is(acquireErr, errOpenAIWSConnQueueFull) { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusTryAgainLater, + "upstream websocket is busy, please retry later", + acquireErr, + ) + } + return nil, acquireErr + } + connID := strings.TrimSpace(lease.ConnID()) + if handshakeTurnState := strings.TrimSpace(lease.HandshakeHeader(openAIWSTurnStateHeader)); handshakeTurnState != "" { + turnState = handshakeTurnState + if stateStore != nil && sessionHash != "" { + stateStore.BindSessionTurnState(groupID, sessionHash, handshakeTurnState, s.openAIWSSessionStickyTTL()) + } + updatedHeaders := cloneHeader(baseAcquireReq.Headers) + if updatedHeaders == nil { + updatedHeaders = make(http.Header) + } + updatedHeaders.Set(openAIWSTurnStateHeader, handshakeTurnState) + baseAcquireReq.Headers = updatedHeaders + } + logOpenAIWSModeInfo( + 
"ingress_ws_upstream_connected account_id=%d turn=%d conn_id=%s conn_reused=%v conn_pick_ms=%d queue_wait_ms=%d preferred_conn_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + lease.Reused(), + lease.ConnPickDuration().Milliseconds(), + lease.QueueWaitDuration().Milliseconds(), + truncateOpenAIWSLogValue(preferred, openAIWSIDValueMaxLen), + ) + return lease, nil + } + + writeClientMessage := func(message []byte) error { + writeCtx, cancel := context.WithTimeout(ctx, s.openAIWSWriteTimeout()) + defer cancel() + return clientConn.Write(writeCtx, coderws.MessageText, message) + } + + readClientMessage := func() ([]byte, error) { + msgType, payload, readErr := clientConn.Read(ctx) + if readErr != nil { + return nil, readErr + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + return nil, NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + fmt.Sprintf("unsupported websocket client message type: %s", msgType.String()), + nil, + ) + } + return payload, nil + } + + sendAndRelay := func(turn int, lease *openAIWSConnLease, payload []byte, payloadBytes int, originalModel string) (*OpenAIForwardResult, error) { + if lease == nil { + return nil, errors.New("upstream websocket lease is nil") + } + turnStart := time.Now() + wroteDownstream := false + if err := lease.WriteJSONWithContextTimeout(ctx, json.RawMessage(payload), s.openAIWSWriteTimeout()); err != nil { + return nil, wrapOpenAIWSIngressTurnError( + "write_upstream", + fmt.Errorf("write upstream websocket request: %w", err), + false, + ) + } + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_turn_request_sent account_id=%d turn=%d conn_id=%s payload_bytes=%d", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + payloadBytes, + ) + } + + responseID := "" + usage := OpenAIUsage{} + var firstTokenMs *int + reqStream := openAIWSPayloadBoolFromRaw(payload, "stream", true) + turnPreviousResponseID := 
openAIWSPayloadStringFromRaw(payload, "previous_response_id") + turnPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(turnPreviousResponseID) + turnPromptCacheKey := openAIWSPayloadStringFromRaw(payload, "prompt_cache_key") + turnStoreDisabled := s.isOpenAIWSStoreDisabledInRequestRaw(payload, account) + turnHasFunctionCallOutput := gjson.GetBytes(payload, `input.#(type=="function_call_output")`).Exists() + eventCount := 0 + tokenEventCount := 0 + terminalEventCount := 0 + firstEventType := "" + lastEventType := "" + needModelReplace := false + clientDisconnected := false + mappedModel := "" + var mappedModelBytes []byte + if originalModel != "" { + mappedModel = account.GetMappedModel(originalModel) + if normalizedModel := normalizeCodexModel(mappedModel); normalizedModel != "" { + mappedModel = normalizedModel + } + needModelReplace = mappedModel != "" && mappedModel != originalModel + if needModelReplace { + mappedModelBytes = []byte(mappedModel) + } + } + for { + upstreamMessage, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + return nil, wrapOpenAIWSIngressTurnError( + "read_upstream", + fmt.Errorf("read upstream websocket event: %w", readErr), + wroteDownstream, + ) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(upstreamMessage) + if responseID == "" && eventResponseID != "" { + responseID = eventResponseID + } + if eventType != "" { + eventCount++ + if firstEventType == "" { + firstEventType = eventType + } + lastEventType = eventType + } + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage) + fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound && + 
turnPreviousResponseID != "" && + !turnHasFunctionCallOutput && + s.openAIWSIngressPreviousResponseRecoveryEnabled() && + !wroteDownstream + if recoverablePrevNotFound { + // 可恢复场景使用非 error 关键字日志,避免被 LegacyPrintf 误判为 ERROR 级别。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recoverable account_id=%d turn=%d conn_id=%s idx=%d reason=%s code=%s type=%s message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_error_event account_id=%d turn=%d conn_id=%s idx=%d fallback_reason=%s err_code=%s err_type=%s err_message=%s previous_response_id=%s previous_response_id_kind=%s response_id=%s store_disabled=%v has_prompt_cache_key=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + eventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + errCode, + errType, + errMessage, + truncateOpenAIWSLogValue(turnPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(turnPreviousResponseIDKind), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + turnStoreDisabled, + turnPromptCacheKey != "", + ) + } + // previous_response_not_found 在 ingress 模式支持单次恢复重试: + // 不把该 error 直接下发客户端,而是由上层去掉 previous_response_id 后重放当前 turn。 + if recoverablePrevNotFound { + lease.MarkBroken() + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "previous response not found" + } + return nil, wrapOpenAIWSIngressTurnError( + 
openAIWSIngressStagePreviousResponseNotFound, + errors.New(errMsg), + false, + ) + } + } + isTokenEvent := isOpenAIWSTokenEvent(eventType) + if isTokenEvent { + tokenEventCount++ + } + isTerminalEvent := isOpenAIWSTerminalEvent(eventType) + if isTerminalEvent { + terminalEventCount++ + } + if firstTokenMs == nil && isTokenEvent { + ms := int(time.Since(turnStart).Milliseconds()) + firstTokenMs = &ms + } + if openAIWSEventShouldParseUsage(eventType) { + parseOpenAIWSResponseUsageFromCompletedEvent(upstreamMessage, &usage) + } + + if !clientDisconnected { + if needModelReplace && len(mappedModelBytes) > 0 && openAIWSEventMayContainModel(eventType) && bytes.Contains(upstreamMessage, mappedModelBytes) { + upstreamMessage = replaceOpenAIWSMessageModel(upstreamMessage, mappedModel, originalModel) + } + if openAIWSEventMayContainToolCalls(eventType) && openAIWSMessageLikelyContainsToolCalls(upstreamMessage) { + if corrected, changed := s.toolCorrector.CorrectToolCallsInSSEBytes(upstreamMessage); changed { + upstreamMessage = corrected + } + } + if err := writeClientMessage(upstreamMessage); err != nil { + if isOpenAIWSClientDisconnectError(err) { + clientDisconnected = true + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(err) + logOpenAIWSModeInfo( + "ingress_ws_client_disconnected_drain account_id=%d turn=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + } else { + return nil, wrapOpenAIWSIngressTurnError( + "write_client", + fmt.Errorf("write client websocket event: %w", err), + wroteDownstream, + ) + } + } else { + wroteDownstream = true + } + } + if isTerminalEvent { + // 客户端已断连时,上游连接的 session 状态不可信,标记 broken 避免回池复用。 + if clientDisconnected { + lease.MarkBroken() + } + firstTokenMsValue := -1 + if firstTokenMs != nil { + firstTokenMsValue = *firstTokenMs + } + if debugEnabled { + 
logOpenAIWSModeDebug( + "ingress_ws_turn_completed account_id=%d turn=%d conn_id=%s response_id=%s duration_ms=%d events=%d token_events=%d terminal_events=%d first_event=%s last_event=%s first_token_ms=%d client_disconnected=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(lease.ConnID(), openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(responseID, openAIWSIDValueMaxLen), + time.Since(turnStart).Milliseconds(), + eventCount, + tokenEventCount, + terminalEventCount, + truncateOpenAIWSLogValue(firstEventType, openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(lastEventType, openAIWSLogValueMaxLen), + firstTokenMsValue, + clientDisconnected, + ) + } + return &OpenAIForwardResult{ + RequestID: responseID, + Usage: usage, + Model: originalModel, + ReasoningEffort: extractOpenAIReasoningEffortFromBody(payload, originalModel), + Stream: reqStream, + OpenAIWSMode: true, + Duration: time.Since(turnStart), + FirstTokenMs: firstTokenMs, + }, nil + } + } + } + + currentPayload := firstPayload.payloadRaw + currentOriginalModel := firstPayload.originalModel + currentPayloadBytes := firstPayload.payloadBytes + isStrictAffinityTurn := func(payload []byte) bool { + if !storeDisabled { + return false + } + return strings.TrimSpace(openAIWSPayloadStringFromRaw(payload, "previous_response_id")) != "" + } + var sessionLease *openAIWSConnLease + sessionConnID := "" + pinnedSessionConnID := "" + unpinSessionConn := func(connID string) { + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID != connID { + return + } + pool.UnpinConn(account.ID, connID) + pinnedSessionConnID = "" + } + pinSessionConn := func(connID string) { + if !storeDisabled { + return + } + connID = strings.TrimSpace(connID) + if connID == "" || pinnedSessionConnID == connID { + return + } + if pinnedSessionConnID != "" { + pool.UnpinConn(account.ID, pinnedSessionConnID) + pinnedSessionConnID = "" + } + if pool.PinConn(account.ID, connID) { + pinnedSessionConnID = connID + } + } + 
releaseSessionLease := func() { + if sessionLease == nil { + return + } + if dedicatedMode { + // dedicated 会话结束后主动标记损坏,确保连接不会跨会话复用。 + sessionLease.MarkBroken() + } + unpinSessionConn(sessionConnID) + sessionLease.Release() + if debugEnabled { + logOpenAIWSModeDebug( + "ingress_ws_upstream_released account_id=%d conn_id=%s", + account.ID, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + ) + } + } + defer releaseSessionLease() + + turn := 1 + turnRetry := 0 + turnPrevRecoveryTried := false + lastTurnFinishedAt := time.Time{} + lastTurnResponseID := "" + lastTurnPayload := []byte(nil) + var lastTurnStrictState *openAIWSIngressPreviousTurnStrictState + lastTurnReplayInput := []json.RawMessage(nil) + lastTurnReplayInputExists := false + currentTurnReplayInput := []json.RawMessage(nil) + currentTurnReplayInputExists := false + skipBeforeTurn := false + resetSessionLease := func(markBroken bool) { + if sessionLease == nil { + return + } + if markBroken { + sessionLease.MarkBroken() + } + releaseSessionLease() + sessionLease = nil + sessionConnID = "" + preferredConnID = "" + } + recoverIngressPrevResponseNotFound := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressPreviousResponseNotFound(relayErr) { + return false + } + if turnPrevRecoveryTried || !s.openAIWSIngressPreviousResponseRecoveryEnabled() { + return false + } + if isStrictAffinityTurn(currentPayload) { + // Layer 2:严格亲和链路命中 previous_response_not_found 时,降级为“去掉 previous_response_id 后重放一次”。 + // 该错误说明续链锚点已失效,继续 strict fail-close 只会直接中断本轮请求。 + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_layer2 account_id=%d turn=%d conn_id=%s store_disabled_conn_mode=%s action=drop_previous_response_id_retry", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(storeDisabledConnMode), + ) + } + turnPrevRecoveryTried = true + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + 
if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + ) + return false + } + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + return false + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id retry=1", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + return true + } + retryIngressTurn := func(relayErr error, turn int, connID string) bool { + if !isOpenAIWSIngressTurnRetryable(relayErr) || turnRetry >= 1 { + return false + } + if isStrictAffinityTurn(currentPayload) { + logOpenAIWSModeInfo( + "ingress_ws_turn_retry_skip account_id=%d turn=%d conn_id=%s reason=strict_affinity", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + return false + } + turnRetry++ + logOpenAIWSModeInfo( + "ingress_ws_turn_retry account_id=%d turn=%d retry=%d reason=%s conn_id=%s", + account.ID, + turn, + turnRetry, + truncateOpenAIWSLogValue(openAIWSIngressTurnRetryReason(relayErr), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + ) + resetSessionLease(true) + skipBeforeTurn = true + 
return true + } + for { + if !skipBeforeTurn && hooks != nil && hooks.BeforeTurn != nil { + if err := hooks.BeforeTurn(turn); err != nil { + return err + } + } + skipBeforeTurn = false + currentPreviousResponseID := openAIWSPayloadStringFromRaw(currentPayload, "previous_response_id") + expectedPrev := strings.TrimSpace(lastTurnResponseID) + hasFunctionCallOutput := gjson.GetBytes(currentPayload, `input.#(type=="function_call_output")`).Exists() + // store=false + function_call_output 场景必须有续链锚点。 + // 若客户端未传 previous_response_id,优先回填上一轮响应 ID,避免上游报 call_id 无法关联。 + if shouldInferIngressFunctionCallOutputPreviousResponseID( + storeDisabled, + turn, + hasFunctionCallOutput, + currentPreviousResponseID, + expectedPrev, + ) { + updatedPayload, setPrevErr := setPreviousResponseIDToRawPayload(currentPayload, expectedPrev) + if setPrevErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer_skip account_id=%d turn=%d conn_id=%s reason=set_previous_response_id_error cause=%s expected_previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setPrevErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } else { + currentPayload = updatedPayload + currentPayloadBytes = len(updatedPayload) + currentPreviousResponseID = expectedPrev + logOpenAIWSModeInfo( + "ingress_ws_function_call_output_prev_infer account_id=%d turn=%d conn_id=%s action=set_previous_response_id previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + ) + } + } + nextReplayInput, nextReplayInputExists, replayInputErr := buildOpenAIWSReplayInputSequence( + lastTurnReplayInput, + lastTurnReplayInputExists, + currentPayload, + currentPreviousResponseID != "", + ) + if replayInputErr != nil { + logOpenAIWSModeInfo( + 
"ingress_ws_replay_input_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(replayInputErr.Error(), openAIWSLogValueMaxLen), + ) + currentTurnReplayInput = nil + currentTurnReplayInputExists = false + } else { + currentTurnReplayInput = nextReplayInput + currentTurnReplayInputExists = nextReplayInputExists + } + if storeDisabled && turn > 1 && currentPreviousResponseID != "" { + shouldKeepPreviousResponseID := false + strictReason := "" + var strictErr error + if lastTurnStrictState != nil { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseIDWithStrictState( + lastTurnStrictState, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } else { + shouldKeepPreviousResponseID, strictReason, strictErr = shouldKeepIngressPreviousResponseID( + lastTurnPayload, + currentPayload, + lastTurnResponseID, + hasFunctionCallOutput, + ) + } + if strictErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s cause=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(strictErr.Error(), openAIWSLogValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else if !shouldKeepPreviousResponseID { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + dropReason := "not_removed" + if dropErr != nil { + dropReason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d 
conn_id=%s action=keep_previous_response_id reason=%s drop_reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + normalizeOpenAIWSLogValue(dropReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=keep_previous_response_id reason=%s drop_reason=set_full_input_error previous_response_id=%s expected_previous_response_id=%s cause=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + hasFunctionCallOutput, + ) + } else { + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + logOpenAIWSModeInfo( + "ingress_ws_prev_response_strict_eval account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_full_create reason=%s previous_response_id=%s expected_previous_response_id=%s has_function_call_output=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(strictReason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + hasFunctionCallOutput, + ) + currentPreviousResponseID = "" + } 
+ } + } + } + forcePreferredConn := isStrictAffinityTurn(currentPayload) + if sessionLease == nil { + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } else { + unpinSessionConn(sessionConnID) + } + } + shouldPreflightPing := turn > 1 && sessionLease != nil && turnRetry == 0 + if shouldPreflightPing && openAIWSIngressPreflightPingIdle > 0 && !lastTurnFinishedAt.IsZero() { + if time.Since(lastTurnFinishedAt) < openAIWSIngressPreflightPingIdle { + shouldPreflightPing = false + } + } + if shouldPreflightPing { + if pingErr := sessionLease.PingWithTimeout(openAIWSConnHealthCheckTO); pingErr != nil { + logOpenAIWSModeInfo( + "ingress_ws_upstream_preflight_ping_fail account_id=%d turn=%d conn_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(pingErr.Error(), openAIWSLogValueMaxLen), + ) + if forcePreferredConn { + if !turnPrevRecoveryTried && currentPreviousResponseID != "" { + updatedPayload, removed, dropErr := dropPreviousResponseIDFromRawPayload(currentPayload) + if dropErr != nil || !removed { + reason := "not_removed" + if dropErr != nil { + reason = "drop_error" + } + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(reason), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + } else { + updatedWithInput, setInputErr := setOpenAIWSPayloadInputSequence( + updatedPayload, + currentTurnReplayInput, + currentTurnReplayInputExists, + ) + if setInputErr != nil { + logOpenAIWSModeInfo( + 
"ingress_ws_preflight_ping_recovery_skip account_id=%d turn=%d conn_id=%s reason=set_full_input_error previous_response_id=%s cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(setInputErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + logOpenAIWSModeInfo( + "ingress_ws_preflight_ping_recovery account_id=%d turn=%d conn_id=%s action=drop_previous_response_id_retry previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + ) + turnPrevRecoveryTried = true + currentPayload = updatedWithInput + currentPayloadBytes = len(updatedWithInput) + resetSessionLease(true) + skipBeforeTurn = true + continue + } + } + } + resetSessionLease(true) + return NewOpenAIWSClientCloseError( + coderws.StatusPolicyViolation, + "upstream continuation connection is unavailable; please restart the conversation", + pingErr, + ) + } + resetSessionLease(true) + + acquiredLease, acquireErr := acquireTurnLease(turn, preferredConnID, forcePreferredConn) + if acquireErr != nil { + return fmt.Errorf("acquire upstream websocket after preflight ping fail: %w", acquireErr) + } + sessionLease = acquiredLease + sessionConnID = strings.TrimSpace(sessionLease.ConnID()) + if storeDisabled { + pinSessionConn(sessionConnID) + } + } + } + connID := sessionConnID + if currentPreviousResponseID != "" { + chainedFromLast := expectedPrev != "" && currentPreviousResponseID == expectedPrev + currentPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(currentPreviousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_turn_chain account_id=%d turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v preferred_conn_id=%s header_session_id=%s header_conversation_id=%s 
has_turn_state=%v turn_state_len=%d has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(currentPreviousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(currentPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + truncateOpenAIWSLogValue(preferredConnID, openAIWSIDValueMaxLen), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "session_id"), + openAIWSHeaderValueForLog(baseAcquireReq.Headers, "conversation_id"), + turnState != "", + len(turnState), + openAIWSPayloadStringFromRaw(currentPayload, "prompt_cache_key") != "", + storeDisabled, + ) + } + + result, relayErr := sendAndRelay(turn, sessionLease, currentPayload, currentPayloadBytes, currentOriginalModel) + if relayErr != nil { + if recoverIngressPrevResponseNotFound(relayErr, turn, connID) { + continue + } + if retryIngressTurn(relayErr, turn, connID) { + continue + } + finalErr := relayErr + if unwrapped := errors.Unwrap(relayErr); unwrapped != nil { + finalErr = unwrapped + } + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, nil, finalErr) + } + sessionLease.MarkBroken() + return finalErr + } + turnRetry = 0 + turnPrevRecoveryTried = false + lastTurnFinishedAt = time.Now() + if hooks != nil && hooks.AfterTurn != nil { + hooks.AfterTurn(turn, result, nil) + } + if result == nil { + return errors.New("websocket turn result is nil") + } + responseID := strings.TrimSpace(result.RequestID) + lastTurnResponseID = responseID + lastTurnPayload = cloneOpenAIWSPayloadBytes(currentPayload) + lastTurnReplayInput = cloneOpenAIWSRawMessages(currentTurnReplayInput) + lastTurnReplayInputExists = currentTurnReplayInputExists + nextStrictState, strictStateErr := buildOpenAIWSIngressPreviousTurnStrictState(currentPayload) + if strictStateErr != nil { + lastTurnStrictState = nil + logOpenAIWSModeInfo( + 
"ingress_ws_prev_response_strict_state_skip account_id=%d turn=%d conn_id=%s reason=build_error cause=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(strictStateErr.Error(), openAIWSLogValueMaxLen), + ) + } else { + lastTurnStrictState = nextStrictState + } + + if responseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, responseID, stateStore.BindResponseAccount(ctx, groupID, responseID, account.ID, ttl)) + stateStore.BindResponseConn(responseID, connID, ttl) + } + if stateStore != nil && storeDisabled && sessionHash != "" { + stateStore.BindSessionConn(groupID, sessionHash, connID, s.openAIWSSessionStickyTTL()) + } + if connID != "" { + preferredConnID = connID + } + + nextClientMessage, readErr := readClientMessage() + if readErr != nil { + if isOpenAIWSClientDisconnectError(readErr) { + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "ingress_ws_client_closed account_id=%d conn_id=%s close_status=%s close_reason=%s", + account.ID, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + closeStatus, + truncateOpenAIWSLogValue(closeReason, openAIWSHeaderValueMaxLen), + ) + return nil + } + return fmt.Errorf("read client websocket request: %w", readErr) + } + + nextPayload, parseErr := parseClientPayload(nextClientMessage) + if parseErr != nil { + return parseErr + } + if nextPayload.promptCacheKey != "" { + // ingress 会话在整个客户端 WS 生命周期内复用同一上游连接; + // prompt_cache_key 对握手头的更新仅在未来需要重新建连时生效。 + updatedHeaders, _ := s.buildOpenAIWSHeaders(c, account, token, wsDecision, isCodexCLI, turnState, strings.TrimSpace(c.GetHeader(openAIWSTurnMetadataHeader)), nextPayload.promptCacheKey) + baseAcquireReq.Headers = updatedHeaders + } + if nextPayload.previousResponseID != "" { + expectedPrev := strings.TrimSpace(lastTurnResponseID) + chainedFromLast := expectedPrev != "" && 
nextPayload.previousResponseID == expectedPrev + nextPreviousResponseIDKind := ClassifyOpenAIPreviousResponseIDKind(nextPayload.previousResponseID) + logOpenAIWSModeInfo( + "ingress_ws_next_turn_chain account_id=%d turn=%d next_turn=%d conn_id=%s previous_response_id=%s previous_response_id_kind=%s last_turn_response_id=%s chained_from_last=%v has_prompt_cache_key=%v store_disabled=%v", + account.ID, + turn, + turn+1, + truncateOpenAIWSLogValue(connID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + normalizeOpenAIWSLogValue(nextPreviousResponseIDKind), + truncateOpenAIWSLogValue(expectedPrev, openAIWSIDValueMaxLen), + chainedFromLast, + nextPayload.promptCacheKey != "", + storeDisabled, + ) + } + if stateStore != nil && nextPayload.previousResponseID != "" { + if stickyConnID, ok := stateStore.GetResponseConn(nextPayload.previousResponseID); ok { + if sessionConnID != "" && stickyConnID != "" && stickyConnID != sessionConnID { + logOpenAIWSModeInfo( + "ingress_ws_keep_session_conn account_id=%d turn=%d conn_id=%s sticky_conn_id=%s previous_response_id=%s", + account.ID, + turn, + truncateOpenAIWSLogValue(sessionConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(stickyConnID, openAIWSIDValueMaxLen), + truncateOpenAIWSLogValue(nextPayload.previousResponseID, openAIWSIDValueMaxLen), + ) + } else { + preferredConnID = stickyConnID + } + } + } + currentPayload = nextPayload.payloadRaw + currentOriginalModel = nextPayload.originalModel + currentPayloadBytes = nextPayload.payloadBytes + storeDisabled = s.isOpenAIWSStoreDisabledInRequestRaw(currentPayload, account) + if !storeDisabled { + unpinSessionConn(sessionConnID) + } + turn++ + } +} + +// isOpenAIWSGeneratePrewarmEnabled reports whether the optional generate=false WS prewarm is enabled, i.e. whether cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled is set; nil-safe on both the receiver and its config. +func (s *OpenAIGatewayService) isOpenAIWSGeneratePrewarmEnabled() bool { + return s != nil && s.cfg != nil && s.cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled +} + +// performOpenAIWSGeneratePrewarm performs the optional generate=false prewarm under WSv2. +// Prewarm is disabled by default and takes effect only once enabled in config; on recoverable errors it falls back to 
HTTP。 +func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm( + ctx context.Context, + lease *openAIWSConnLease, + decision OpenAIWSProtocolDecision, + payload map[string]any, + previousResponseID string, + reqBody map[string]any, + account *Account, + stateStore OpenAIWSStateStore, + groupID int64, +) error { + if s == nil { + return nil + } + if lease == nil || account == nil { + logOpenAIWSModeInfo("prewarm_skip reason=invalid_state has_lease=%v has_account=%v", lease != nil, account != nil) + return nil + } + connID := strings.TrimSpace(lease.ConnID()) + if !s.isOpenAIWSGeneratePrewarmEnabled() { + return nil + } + if decision.Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=transport_not_v2 transport=%s", + account.ID, + connID, + normalizeOpenAIWSLogValue(string(decision.Transport)), + ) + return nil + } + if strings.TrimSpace(previousResponseID) != "" { + logOpenAIWSModeInfo( + "prewarm_skip account_id=%d conn_id=%s reason=has_previous_response_id previous_response_id=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(previousResponseID, openAIWSIDValueMaxLen), + ) + return nil + } + if lease.IsPrewarmed() { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=already_prewarmed", account.ID, connID) + return nil + } + if NeedsToolContinuation(reqBody) { + logOpenAIWSModeInfo("prewarm_skip account_id=%d conn_id=%s reason=tool_continuation", account.ID, connID) + return nil + } + prewarmStart := time.Now() + logOpenAIWSModeInfo("prewarm_start account_id=%d conn_id=%s", account.ID, connID) + + prewarmPayload := make(map[string]any, len(payload)+1) + for k, v := range payload { + prewarmPayload[k] = v + } + prewarmPayload["generate"] = false + prewarmPayloadJSON := payloadAsJSONBytes(prewarmPayload) + + if err := lease.WriteJSONWithContextTimeout(ctx, prewarmPayload, s.openAIWSWriteTimeout()); err != nil { + lease.MarkBroken() + logOpenAIWSModeInfo( + 
"prewarm_write_fail account_id=%d conn_id=%s cause=%s", + account.ID, + connID, + truncateOpenAIWSLogValue(err.Error(), openAIWSLogValueMaxLen), + ) + return wrapOpenAIWSFallback("prewarm_write", err) + } + logOpenAIWSModeInfo("prewarm_write_sent account_id=%d conn_id=%s payload_bytes=%d", account.ID, connID, len(prewarmPayloadJSON)) + + prewarmResponseID := "" + prewarmEventCount := 0 + prewarmTerminalCount := 0 + for { + message, readErr := lease.ReadMessageWithContextTimeout(ctx, s.openAIWSReadTimeout()) + if readErr != nil { + lease.MarkBroken() + closeStatus, closeReason := summarizeOpenAIWSReadCloseError(readErr) + logOpenAIWSModeInfo( + "prewarm_read_fail account_id=%d conn_id=%s close_status=%s close_reason=%s cause=%s events=%d", + account.ID, + connID, + closeStatus, + closeReason, + truncateOpenAIWSLogValue(readErr.Error(), openAIWSLogValueMaxLen), + prewarmEventCount, + ) + return wrapOpenAIWSFallback("prewarm_"+classifyOpenAIWSReadFallbackReason(readErr), readErr) + } + + eventType, eventResponseID, _ := parseOpenAIWSEventEnvelope(message) + if eventType == "" { + continue + } + prewarmEventCount++ + if prewarmResponseID == "" && eventResponseID != "" { + prewarmResponseID = eventResponseID + } + if prewarmEventCount <= openAIWSPrewarmEventLogHead || eventType == "error" || isOpenAIWSTerminalEvent(eventType) { + logOpenAIWSModeInfo( + "prewarm_event account_id=%d conn_id=%s idx=%d type=%s bytes=%d", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(eventType, openAIWSLogValueMaxLen), + len(message), + ) + } + + if eventType == "error" { + errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + errMsg := strings.TrimSpace(errMsgRaw) + if errMsg == "" { + errMsg = "OpenAI websocket prewarm error" + } + fallbackReason, canFallback := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw) + 
logOpenAIWSModeInfo( + "prewarm_error_event account_id=%d conn_id=%s idx=%d fallback_reason=%s can_fallback=%v err_code=%s err_type=%s err_message=%s", + account.ID, + connID, + prewarmEventCount, + truncateOpenAIWSLogValue(fallbackReason, openAIWSLogValueMaxLen), + canFallback, + errCode, + errType, + errMessage, + ) + lease.MarkBroken() + if canFallback { + return wrapOpenAIWSFallback("prewarm_"+fallbackReason, errors.New(errMsg)) + } + return wrapOpenAIWSFallback("prewarm_error_event", errors.New(errMsg)) + } + + if isOpenAIWSTerminalEvent(eventType) { + prewarmTerminalCount++ + break + } + } + + lease.MarkPrewarmed() + if prewarmResponseID != "" && stateStore != nil { + ttl := s.openAIWSResponseStickyTTL() + logOpenAIWSBindResponseAccountWarn(groupID, account.ID, prewarmResponseID, stateStore.BindResponseAccount(ctx, groupID, prewarmResponseID, account.ID, ttl)) + stateStore.BindResponseConn(prewarmResponseID, lease.ConnID(), ttl) + } + logOpenAIWSModeInfo( + "prewarm_done account_id=%d conn_id=%s response_id=%s events=%d terminal_events=%d duration_ms=%d", + account.ID, + connID, + truncateOpenAIWSLogValue(prewarmResponseID, openAIWSIDValueMaxLen), + prewarmEventCount, + prewarmTerminalCount, + time.Since(prewarmStart).Milliseconds(), + ) + return nil +} + +func payloadAsJSON(payload map[string]any) string { + return string(payloadAsJSONBytes(payload)) +} + +func payloadAsJSONBytes(payload map[string]any) []byte { + if len(payload) == 0 { + return []byte("{}") + } + body, err := json.Marshal(payload) + if err != nil { + return []byte("{}") + } + return body +} + +func isOpenAIWSTerminalEvent(eventType string) bool { + switch strings.TrimSpace(eventType) { + case "response.completed", "response.done", "response.failed", "response.incomplete", "response.cancelled", "response.canceled": + return true + default: + return false + } +} + +func isOpenAIWSTokenEvent(eventType string) bool { + eventType = strings.TrimSpace(eventType) + if eventType == "" { + return 
false + } + switch eventType { + case "response.created", "response.in_progress", "response.output_item.added", "response.output_item.done": + return false + } + if strings.Contains(eventType, ".delta") { + return true + } + if strings.HasPrefix(eventType, "response.output_text") { + return true + } + if strings.HasPrefix(eventType, "response.output") { + return true + } + return eventType == "response.completed" || eventType == "response.done" +} + +func replaceOpenAIWSMessageModel(message []byte, fromModel, toModel string) []byte { + if len(message) == 0 { + return message + } + if strings.TrimSpace(fromModel) == "" || strings.TrimSpace(toModel) == "" || fromModel == toModel { + return message + } + if !bytes.Contains(message, []byte(`"model"`)) || !bytes.Contains(message, []byte(fromModel)) { + return message + } + modelValues := gjson.GetManyBytes(message, "model", "response.model") + replaceModel := modelValues[0].Exists() && modelValues[0].Str == fromModel + replaceResponseModel := modelValues[1].Exists() && modelValues[1].Str == fromModel + if !replaceModel && !replaceResponseModel { + return message + } + updated := message + if replaceModel { + if next, err := sjson.SetBytes(updated, "model", toModel); err == nil { + updated = next + } + } + if replaceResponseModel { + if next, err := sjson.SetBytes(updated, "response.model", toModel); err == nil { + updated = next + } + } + return updated +} + +func populateOpenAIUsageFromResponseJSON(body []byte, usage *OpenAIUsage) { + if usage == nil || len(body) == 0 { + return + } + values := gjson.GetManyBytes( + body, + "usage.input_tokens", + "usage.output_tokens", + "usage.input_tokens_details.cached_tokens", + ) + usage.InputTokens = int(values[0].Int()) + usage.OutputTokens = int(values[1].Int()) + usage.CacheReadInputTokens = int(values[2].Int()) +} + +func getOpenAIGroupIDFromContext(c *gin.Context) int64 { + if c == nil { + return 0 + } + value, exists := c.Get("api_key") + if !exists { + return 0 + } + 
apiKey, ok := value.(*APIKey) + if !ok || apiKey == nil || apiKey.GroupID == nil { + return 0 + } + return *apiKey.GroupID +} + +// SelectAccountByPreviousResponseID 按 previous_response_id 命中账号粘连。 +// 未命中或账号不可用时返回 (nil, nil),由调用方继续走常规调度。 +func (s *OpenAIGatewayService) SelectAccountByPreviousResponseID( + ctx context.Context, + groupID *int64, + previousResponseID string, + requestedModel string, + excludedIDs map[int64]struct{}, +) (*AccountSelectionResult, error) { + if s == nil { + return nil, nil + } + responseID := strings.TrimSpace(previousResponseID) + if responseID == "" { + return nil, nil + } + store := s.getOpenAIWSStateStore() + if store == nil { + return nil, nil + } + + accountID, err := store.GetResponseAccount(ctx, derefGroupID(groupID), responseID) + if err != nil || accountID <= 0 { + return nil, nil + } + if excludedIDs != nil { + if _, excluded := excludedIDs[accountID]; excluded { + return nil, nil + } + } + + account, err := s.getSchedulableAccount(ctx, accountID) + if err != nil || account == nil { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + // 非 WSv2 场景(如 force_http/全局关闭)不应使用 previous_response_id 粘连, + // 以保持“回滚到 HTTP”后的历史行为一致性。 + if s.getOpenAIWSProtocolResolver().Resolve(account).Transport != OpenAIUpstreamTransportResponsesWebsocketV2 { + return nil, nil + } + if shouldClearStickySession(account, requestedModel) || !account.IsOpenAI() { + _ = store.DeleteResponseAccount(ctx, derefGroupID(groupID), responseID) + return nil, nil + } + if requestedModel != "" && !account.IsModelSupported(requestedModel) { + return nil, nil + } + + result, acquireErr := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) + if acquireErr == nil && result.Acquired { + logOpenAIWSBindResponseAccountWarn( + derefGroupID(groupID), + accountID, + responseID, + store.BindResponseAccount(ctx, derefGroupID(groupID), responseID, accountID, s.openAIWSResponseStickyTTL()), + ) + return 
&AccountSelectionResult{ + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, + }, nil + } + + cfg := s.schedulingConfig() + if s.concurrencyService != nil { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + return nil, nil +} + +func classifyOpenAIWSAcquireError(err error) string { + if err == nil { + return "acquire_conn" + } + var dialErr *openAIWSDialError + if errors.As(err, &dialErr) { + switch dialErr.StatusCode { + case 426: + return "upgrade_required" + case 401, 403: + return "auth_failed" + case 429: + return "upstream_rate_limited" + } + if dialErr.StatusCode >= 500 { + return "upstream_5xx" + } + return "dial_failed" + } + if errors.Is(err, errOpenAIWSConnQueueFull) { + return "conn_queue_full" + } + if errors.Is(err, errOpenAIWSPreferredConnUnavailable) { + return "preferred_conn_unavailable" + } + if errors.Is(err, context.DeadlineExceeded) { + return "acquire_timeout" + } + return "acquire_conn" +} + +func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + msg := strings.ToLower(strings.TrimSpace(msgRaw)) + + switch code { + case "upgrade_required": + return "upgrade_required", true + case "websocket_not_supported", "websocket_unsupported": + return "ws_unsupported", true + case "websocket_connection_limit_reached": + return "ws_connection_limit_reached", true + case "previous_response_not_found": + return "previous_response_not_found", true + } + if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") { + return "upgrade_required", true + } + if strings.Contains(errType, "upgrade") { + return "upgrade_required", true + } + if strings.Contains(msg, 
"websocket") && strings.Contains(msg, "unsupported") { + return "ws_unsupported", true + } + if strings.Contains(msg, "connection limit") && strings.Contains(msg, "websocket") { + return "ws_connection_limit_reached", true + } + if strings.Contains(msg, "previous_response_not_found") || + (strings.Contains(msg, "previous response") && strings.Contains(msg, "not found")) { + return "previous_response_not_found", true + } + if strings.Contains(errType, "server_error") || strings.Contains(code, "server_error") { + return "upstream_error_event", true + } + return "event_error", false +} + +func classifyOpenAIWSErrorEvent(message []byte) (string, bool) { + if len(message) == 0 { + return "event_error", false + } + return classifyOpenAIWSErrorEventFromRaw(parseOpenAIWSErrorEventFields(message)) +} + +func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int { + code := strings.ToLower(strings.TrimSpace(codeRaw)) + errType := strings.ToLower(strings.TrimSpace(errTypeRaw)) + switch { + case strings.Contains(errType, "invalid_request"), + strings.Contains(code, "invalid_request"), + strings.Contains(code, "bad_request"), + code == "previous_response_not_found": + return http.StatusBadRequest + case strings.Contains(errType, "authentication"), + strings.Contains(code, "invalid_api_key"), + strings.Contains(code, "unauthorized"): + return http.StatusUnauthorized + case strings.Contains(errType, "permission"), + strings.Contains(code, "forbidden"): + return http.StatusForbidden + case strings.Contains(errType, "rate_limit"), + strings.Contains(code, "rate_limit"), + strings.Contains(code, "insufficient_quota"): + return http.StatusTooManyRequests + default: + return http.StatusBadGateway + } +} + +func openAIWSErrorHTTPStatus(message []byte) int { + if len(message) == 0 { + return http.StatusBadGateway + } + codeRaw, errTypeRaw, _ := parseOpenAIWSErrorEventFields(message) + return openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) +} + +func (s 
*OpenAIGatewayService) openAIWSFallbackCooldown() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Second + } + seconds := s.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds + if seconds <= 0 { + return 0 + } + return time.Duration(seconds) * time.Second +} + +func (s *OpenAIGatewayService) isOpenAIWSFallbackCooling(accountID int64) bool { + if s == nil || accountID <= 0 { + return false + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return false + } + rawUntil, ok := s.openaiWSFallbackUntil.Load(accountID) + if !ok || rawUntil == nil { + return false + } + until, ok := rawUntil.(time.Time) + if !ok || until.IsZero() { + s.openaiWSFallbackUntil.Delete(accountID) + return false + } + if time.Now().Before(until) { + return true + } + s.openaiWSFallbackUntil.Delete(accountID) + return false +} + +func (s *OpenAIGatewayService) markOpenAIWSFallbackCooling(accountID int64, _ string) { + if s == nil || accountID <= 0 { + return + } + cooldown := s.openAIWSFallbackCooldown() + if cooldown <= 0 { + return + } + s.openaiWSFallbackUntil.Store(accountID, time.Now().Add(cooldown)) +} + +func (s *OpenAIGatewayService) clearOpenAIWSFallbackCooling(accountID int64) { + if s == nil || accountID <= 0 { + return + } + s.openaiWSFallbackUntil.Delete(accountID) +} diff --git a/backend/internal/service/openai_ws_forwarder_benchmark_test.go b/backend/internal/service/openai_ws_forwarder_benchmark_test.go new file mode 100644 index 00000000..bd03ab5a --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_benchmark_test.go @@ -0,0 +1,127 @@ +package service + +import ( + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +var ( + benchmarkOpenAIWSPayloadJSONSink string + benchmarkOpenAIWSStringSink string + benchmarkOpenAIWSBoolSink bool + benchmarkOpenAIWSBytesSink []byte +) + +func BenchmarkOpenAIWSForwarderHotPath(b *testing.B) { + cfg := &config.Config{} + svc := &OpenAIGatewayService{cfg: cfg} + account := 
&Account{ID: 1, Platform: PlatformOpenAI, Type: AccountTypeOAuth} + reqBody := benchmarkOpenAIWSHotPathRequest() + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + payload := svc.buildOpenAIWSCreatePayload(reqBody, account) + _, _ = applyOpenAIWSRetryPayloadStrategy(payload, 2) + setOpenAIWSTurnMetadata(payload, `{"trace":"bench","turn":"1"}`) + + benchmarkOpenAIWSStringSink = openAIWSPayloadString(payload, "previous_response_id") + benchmarkOpenAIWSBoolSink = payload["tools"] != nil + benchmarkOpenAIWSStringSink = summarizeOpenAIWSPayloadKeySizes(payload, openAIWSPayloadKeySizeTopN) + benchmarkOpenAIWSStringSink = summarizeOpenAIWSInput(payload["input"]) + benchmarkOpenAIWSPayloadJSONSink = payloadAsJSON(payload) + } +} + +func benchmarkOpenAIWSHotPathRequest() map[string]any { + tools := make([]map[string]any, 0, 24) + for i := 0; i < 24; i++ { + tools = append(tools, map[string]any{ + "type": "function", + "name": fmt.Sprintf("tool_%02d", i), + "description": "benchmark tool schema", + "parameters": map[string]any{ + "type": "object", + "properties": map[string]any{ + "query": map[string]any{"type": "string"}, + "limit": map[string]any{"type": "number"}, + }, + "required": []string{"query"}, + }, + }) + } + + input := make([]map[string]any, 0, 16) + for i := 0; i < 16; i++ { + input = append(input, map[string]any{ + "role": "user", + "type": "message", + "content": fmt.Sprintf("benchmark message %d", i), + }) + } + + return map[string]any{ + "type": "response.create", + "model": "gpt-5.3-codex", + "input": input, + "tools": tools, + "parallel_tool_calls": true, + "previous_response_id": "resp_benchmark_prev", + "prompt_cache_key": "bench-cache-key", + "reasoning": map[string]any{"effort": "medium"}, + "instructions": "benchmark instructions", + "store": false, + } +} + +func BenchmarkOpenAIWSEventEnvelopeParse(b *testing.B) { + event := 
[]byte(`{"type":"response.completed","response":{"id":"resp_bench_1","model":"gpt-5.1","usage":{"input_tokens":12,"output_tokens":8}}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + eventType, responseID, response := parseOpenAIWSEventEnvelope(event) + benchmarkOpenAIWSStringSink = eventType + benchmarkOpenAIWSStringSink = responseID + benchmarkOpenAIWSBoolSink = response.Exists() + } +} + +func BenchmarkOpenAIWSErrorEventFieldReuse(b *testing.B) { + event := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(event) + benchmarkOpenAIWSStringSink, benchmarkOpenAIWSBoolSink = classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + code, errType, errMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + benchmarkOpenAIWSStringSink = code + benchmarkOpenAIWSStringSink = errType + benchmarkOpenAIWSStringSink = errMsg + benchmarkOpenAIWSBoolSink = openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) > 0 + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_NoMatchFastPath(b *testing.B) { + event := []byte(`{"type":"response.output_text.delta","delta":"hello world"}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} + +func BenchmarkReplaceOpenAIWSMessageModel_DualReplace(b *testing.B) { + event := []byte(`{"type":"response.completed","model":"gpt-5.1","response":{"id":"resp_1","model":"gpt-5.1"}}`) + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + benchmarkOpenAIWSBytesSink = replaceOpenAIWSMessageModel(event, "gpt-5.1", "custom-model") + } +} diff --git a/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go 
b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go new file mode 100644 index 00000000..76167603 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_hotpath_optimization_test.go @@ -0,0 +1,73 @@ +package service + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseOpenAIWSEventEnvelope(t *testing.T) { + eventType, responseID, response := parseOpenAIWSEventEnvelope([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.1"}}`)) + require.Equal(t, "response.completed", eventType) + require.Equal(t, "resp_1", responseID) + require.True(t, response.Exists()) + require.Equal(t, `{"id":"resp_1","model":"gpt-5.1"}`, response.Raw) + + eventType, responseID, response = parseOpenAIWSEventEnvelope([]byte(`{"type":"response.delta","id":"evt_1"}`)) + require.Equal(t, "response.delta", eventType) + require.Equal(t, "evt_1", responseID) + require.False(t, response.Exists()) +} + +func TestParseOpenAIWSResponseUsageFromCompletedEvent(t *testing.T) { + usage := &OpenAIUsage{} + parseOpenAIWSResponseUsageFromCompletedEvent( + []byte(`{"type":"response.completed","response":{"usage":{"input_tokens":11,"output_tokens":7,"input_tokens_details":{"cached_tokens":3}}}}`), + usage, + ) + require.Equal(t, 11, usage.InputTokens) + require.Equal(t, 7, usage.OutputTokens) + require.Equal(t, 3, usage.CacheReadInputTokens) +} + +func TestOpenAIWSErrorEventHelpers_ConsistentWithWrapper(t *testing.T) { + message := []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`) + codeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message) + + wrappedReason, wrappedRecoverable := classifyOpenAIWSErrorEvent(message) + rawReason, rawRecoverable := classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedReason, rawReason) + require.Equal(t, wrappedRecoverable, rawRecoverable) + + wrappedStatus 
:= openAIWSErrorHTTPStatus(message) + rawStatus := openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw) + require.Equal(t, wrappedStatus, rawStatus) + require.Equal(t, http.StatusBadRequest, rawStatus) + + wrappedCode, wrappedType, wrappedMsg := summarizeOpenAIWSErrorEventFields(message) + rawCode, rawType, rawMsg := summarizeOpenAIWSErrorEventFieldsFromRaw(codeRaw, errTypeRaw, errMsgRaw) + require.Equal(t, wrappedCode, rawCode) + require.Equal(t, wrappedType, rawType) + require.Equal(t, wrappedMsg, rawMsg) +} + +func TestOpenAIWSMessageLikelyContainsToolCalls(t *testing.T) { + require.False(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_text.delta","delta":"hello"}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"tool_calls":[{"id":"tc1"}]}}`))) + require.True(t, openAIWSMessageLikelyContainsToolCalls([]byte(`{"type":"response.output_item.added","item":{"type":"function_call"}}`))) +} + +func TestReplaceOpenAIWSMessageModel_OptimizedStillCorrect(t *testing.T) { + noModel := []byte(`{"type":"response.output_text.delta","delta":"hello"}`) + require.Equal(t, string(noModel), string(replaceOpenAIWSMessageModel(noModel, "gpt-5.1", "custom-model"))) + + rootOnly := []byte(`{"type":"response.created","model":"gpt-5.1"}`) + require.Equal(t, `{"type":"response.created","model":"custom-model"}`, string(replaceOpenAIWSMessageModel(rootOnly, "gpt-5.1", "custom-model"))) + + responseOnly := []byte(`{"type":"response.completed","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"type":"response.completed","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(responseOnly, "gpt-5.1", "custom-model"))) + + both := []byte(`{"model":"gpt-5.1","response":{"model":"gpt-5.1"}}`) + require.Equal(t, `{"model":"custom-model","response":{"model":"custom-model"}}`, string(replaceOpenAIWSMessageModel(both, "gpt-5.1", "custom-model"))) +} diff --git 
a/backend/internal/service/openai_ws_forwarder_ingress_session_test.go b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go new file mode 100644 index 00000000..5a3c12c3 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_session_test.go @@ -0,0 +1,2483 @@ +package service + +import ( + "context" + "errors" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_KeepLeaseAcrossTurns(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_turn_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + 
openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 114, + Name: "openai-ingress-session-lease", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + turnWSModeCh := make(chan bool, 2) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + turnWSModeCh <- result.OpenAIWSMode + } + }, + } + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer 
func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurnEvent, "type").String()) + require.Equal(t, "resp_ingress_turn_1", gjson.GetBytes(firstTurnEvent, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_ingress_turn_1"}`) + secondTurnEvent := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(secondTurnEvent, "type").String()) + require.Equal(t, "resp_ingress_turn_2", gjson.GetBytes(secondTurnEvent, "response.id").String()) + require.True(t, <-turnWSModeCh, "首轮 turn 应标记为 WS 模式") + require.True(t, <-turnWSModeCh, "第二轮 turn 应标记为 WS 模式") + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + metrics := svc.SnapshotOpenAIWSPoolMetrics() + require.Equal(t, int64(1), metrics.AcquireTotal, "同一 ingress 会话多 turn 应只获取一次上游 lease") + require.Equal(t, 1, captureDialer.DialCount(), "同一 ingress 会话应保持同一上游连接") + require.Len(t, captureConn.writes, 2, "应向同一上游连接发送两轮 response.create") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_DedicatedModeDoesNotReuseConnAcrossSessions(t *testing.T) { + 
gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + upstreamConn1 := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + upstreamConn2 := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_dedicated_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{upstreamConn1, upstreamConn2}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 441, + Name: "openai-ingress-dedicated", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + + serverErrCh := make(chan 
error, 2) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + runSingleTurnSession := func(expectedResponseID string) { + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + readCtx, cancelRead := context.WithTimeout(context.Background(), 3*time.Second) + msgType, event, readErr := clientConn.Read(readCtx) + cancelRead() + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + require.Equal(t, expectedResponseID, gjson.GetBytes(event, "response.id").String()) + + require.NoError(t, 
clientConn.Close(coderws.StatusNormalClosure, "done")) + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + } + + runSingleTurnSession("resp_dedicated_1") + runSingleTurnSession("resp_dedicated_2") + + require.Equal(t, 2, dialer.DialCount(), "dedicated 模式下跨客户端会话不应复用上游连接") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ModeOffReturnsPolicyViolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: newOpenAIWSConnPool(cfg), + } + + account := &Account{ + ID: 442, + Name: "openai-ingress-off", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := 
r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Equal(t, "websocket mode is disabled for this account", closeErr.Reason()) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropToFullCreate(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 
= true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_rewrite_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 140, + Name: "openai-ingress-prev-preflight-rewrite", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + 
readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_preflight_rewrite_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_preflight_rewrite_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := 
<-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount(), "严格增量不成立时应在同一连接内降级为 full create") + require.Len(t, captureConn.writes, 2) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格增量不成立时应移除 previous_response_id,改为 full create") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "严格降级为 full create 时应重放完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponseStrictDropBeforePreflightPingFailReconnects(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + 
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_drop_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 142, + Name: "openai-ingress-prev-strict-drop-before-ping", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, 
cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_drop_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_drop_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 严格降级后预检换连超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格降级为 full create 后,预检 ping 失败应允许换连") + require.Equal(t, 1, firstConn.WriteCount(), "首连接在预检失败后不应继续发送第二轮") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) 
+ secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "严格降级后重试应移除 previous_response_id") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array())) + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreEnabledSkipsStrictPrevResponseEval(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_store_enabled_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 
143, + Name: "openai-ingress-store-enabled-skip-strict", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 
3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true}`) + firstTurn := readMessage() + require.Equal(t, "resp_store_enabled_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":true,"previous_response_id":"resp_stale_external"}`) + secondTurn := readMessage() + require.Equal(t, "resp_store_enabled_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 store=true 场景 websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "store=true 场景不应触发 store-disabled strict 规则") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledPrevResponsePreflightSkipForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := 
&openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_preflight_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 141, + Name: "openai-ingress-prev-preflight-skip-fco", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + 
serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_preflight_skip_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_stale_external","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_preflight_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_stale_external", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 场景不应预改写 
previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputAutoAttachPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 143, + Name: "openai-ingress-fco-auto-prev", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_auto_prev_1", 
gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.Equal(t, "resp_auto_prev_1", gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").String(), "function_call_output 缺失 previous_response_id 时应回填上一轮响应 ID") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledFunctionCallOutputSkipsAutoAttachWhenLastResponseIDMissing(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_auto_prev_skip_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + 
}, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 144, + Name: "openai-ingress-fco-auto-prev-skip", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + 
cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(firstTurn, "type").String()) + require.Empty(t, gjson.GetBytes(firstTurn, "response.id").String(), "首轮响应不返回 response.id,模拟无法推导续链锚点") + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"function_call_output","call_id":"call_auto_skip_1","output":"ok"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_auto_prev_skip_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 1, captureDialer.DialCount()) + require.Len(t, captureConn.writes, 2) + require.False(t, gjson.Get(requestToJSONString(captureConn.writes[1]), "previous_response_id").Exists(), "上一轮缺失 response.id 时不应自动补齐 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreflightPingFailReconnectsBeforeTurn(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle 
= 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 116, + Name: "openai-ingress-preflight-ping", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_1", gjson.GetBytes(firstTurn, "response.id").String()) + + 
writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_ping_1"}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 前 ping 失败应触发换连") + require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续向旧连接发送第二轮 turn") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应对旧连接执行 preflight ping") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreflightPingFailAutoRecoveryReconnects(t *testing.T) { + gin.SetMode(gin.TestMode) + prevPreflightPingIdle := openAIWSIngressPreflightPingIdle + openAIWSIngressPreflightPingIdle = 0 + defer func() { + openAIWSIngressPreflightPingIdle = prevPreflightPingIdle + }() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSPreflightFailConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + 
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_ping_strict_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 121, + Name: "openai-ingress-preflight-ping-strict-affinity", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + 
dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_ping_strict_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"previous_response_id":"resp_turn_ping_strict_1","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_ping_strict_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 严格亲和自动恢复后结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格亲和 preflight ping 失败后应自动降级并换连重放") + require.Equal(t, 1, firstConn.WriteCount(), "preflight ping 失败后不应继续在旧连接写第二轮") + require.GreaterOrEqual(t, firstConn.PingCount(), 1, "第二轮前应执行 preflight ping") + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) 
+ secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "自动恢复重放应移除 previous_response_id") + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "自动恢复重放应使用完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_WriteFailBeforeDownstreamRetriesOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSWriteFailAfterFirstTurnConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_write_retry_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: 
NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 117, + Name: "openai-ingress-write-retry", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + var hooksMu sync.Mutex + beforeTurnCalls := make(map[int]int) + afterTurnCalls := make(map[int]int) + hooks := &OpenAIWSIngressHooks{ + BeforeTurn: func(turn int) error { + hooksMu.Lock() + beforeTurnCalls[turn]++ + hooksMu.Unlock() + return nil + }, + AfterTurn: func(turn int, _ *OpenAIForwardResult, _ error) { + hooksMu.Lock() + afterTurnCalls[turn]++ + hooksMu.Unlock() + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, 
_, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_write_retry_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_write_retry_1"}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_write_retry_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + require.Equal(t, 2, dialer.DialCount(), "第二轮 turn 上游写失败且未写下游时应自动重试并换连") + hooksMu.Lock() + beforeTurn1 := beforeTurnCalls[1] + beforeTurn2 := beforeTurnCalls[2] + afterTurn1 := afterTurnCalls[1] + afterTurn2 := afterTurnCalls[2] + hooksMu.Unlock() + require.Equal(t, 1, beforeTurn1, "首轮 turn BeforeTurn 应执行一次") + require.Equal(t, 1, beforeTurn2, "同一 turn 重试不应重复触发 BeforeTurn") + require.Equal(t, 1, afterTurn1, "首轮 turn AfterTurn 应执行一次") + require.Equal(t, 1, afterTurn2, "第二轮 turn AfterTurn 应执行一次") +} + +func 
TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoversByDroppingPrevID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":""}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 118, + Name: "openai-ingress-prev-recovery", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, 
+ Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return 
message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_seed_anchor"}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_prev_recover_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_recover_1"}`) + secondTurn := readMessage() + require.Equal(t, "response.completed", gjson.GetBytes(secondTurn, "type").String()) + require.Equal(t, "resp_turn_prev_recover_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应触发换连重试") + + firstConn.mu.Lock() + firstWrites := append([]map[string]any(nil), firstConn.writes...) + firstConn.mu.Unlock() + require.Len(t, firstWrites, 2, "首个连接应处理首轮与失败的第二轮请求") + require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists(), "失败轮次首发请求应包含 previous_response_id") + + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) 
+ secondConn.mu.Unlock() + require.Len(t, secondWrites, 1, "恢复重试应在第二个连接发送一次请求") + require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_StoreDisabledStrictAffinityPreviousResponseNotFoundLayer2Recovery(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"missing strict anchor"}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_strict_recover_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: 
NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 122, + Name: "openai-ingress-prev-strict-layer2", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, 
coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","input":[{"type":"input_text","text":"hello"}]}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_prev_strict_recover_1", gjson.GetBytes(firstTurn, "response.id").String()) + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"store":false,"prompt_cache_key":"pk_strict_layer2","previous_response_id":"resp_turn_prev_strict_recover_1","input":[{"type":"input_text","text":"world"}]}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_prev_strict_recover_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 严格亲和 Layer2 恢复结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "严格亲和链路命中 previous_response_not_found 应触发 Layer2 恢复重试") + + firstConn.mu.Lock() + firstWrites := append([]map[string]any(nil), firstConn.writes...) + firstConn.mu.Unlock() + require.Len(t, firstWrites, 2, "首连接应收到首轮请求和失败的续链请求") + require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists()) + + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) 
+ secondConn.mu.Unlock() + require.Len(t, secondWrites, 1, "Layer2 恢复应仅重放一次") + secondWrite := requestToJSONString(secondWrites[0]) + require.False(t, gjson.Get(secondWrite, "previous_response_id").Exists(), "Layer2 恢复重放应移除 previous_response_id") + require.True(t, gjson.Get(secondWrite, "store").Exists(), "Layer2 恢复不应改变 store 标志") + require.False(t, gjson.Get(secondWrite, "store").Bool()) + require.Equal(t, 2, len(gjson.Get(secondWrite, "input").Array()), "Layer2 恢复应重放完整 input 上下文") + require.Equal(t, "hello", gjson.Get(secondWrite, "input.0.text").String()) + require.Equal(t, "world", gjson.Get(secondWrite, "input.1.text").String()) +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_PreviousResponseNotFoundRecoveryRemovesDuplicatePrevID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + firstConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"error","error":{"type":"invalid_request_error","code":"previous_response_not_found","message":"first missing"}}`), + }, + } + secondConn := &openAIWSCaptureConn{ + events: [][]byte{ + 
[]byte(`{"type":"response.completed","response":{"id":"resp_turn_prev_once_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + dialer := &openAIWSQueueDialer{ + conns: []openAIWSClientConn{firstConn, secondConn}, + } + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(dialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 120, + Name: "openai-ingress-prev-recovery-once", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial 
:= context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeMessage := func(payload string) { + writeCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + require.NoError(t, clientConn.Write(writeCtx, coderws.MessageText, []byte(payload))) + } + readMessage := func() []byte { + readCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + msgType, message, readErr := clientConn.Read(readCtx) + require.NoError(t, readErr) + require.Equal(t, coderws.MessageText, msgType) + return message + } + + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false}`) + firstTurn := readMessage() + require.Equal(t, "resp_turn_prev_once_1", gjson.GetBytes(firstTurn, "response.id").String()) + + // duplicate previous_response_id: 恢复重试时应删除所有重复键,避免再次 previous_response_not_found。 + writeMessage(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"resp_turn_prev_once_1","input":[],"previous_response_id":"resp_turn_prev_duplicate"}`) + secondTurn := readMessage() + require.Equal(t, "resp_turn_prev_once_2", gjson.GetBytes(secondTurn, "response.id").String()) + + require.NoError(t, clientConn.Close(coderws.StatusNormalClosure, "done")) + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr) + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + require.Equal(t, 2, dialer.DialCount(), "previous_response_not_found 恢复应只重试一次") + + firstConn.mu.Lock() + firstWrites := append([]map[string]any(nil), firstConn.writes...) 
+ firstConn.mu.Unlock() + require.Len(t, firstWrites, 2) + require.True(t, gjson.Get(requestToJSONString(firstWrites[1]), "previous_response_id").Exists()) + + secondConn.mu.Lock() + secondWrites := append([]map[string]any(nil), secondConn.writes...) + secondConn.mu.Unlock() + require.Len(t, secondWrites, 1) + require.False(t, gjson.Get(requestToJSONString(secondWrites[0]), "previous_response_id").Exists(), "重复键场景恢复重试后不应保留 previous_response_id") +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_RejectsMessageIDAsPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 119, + Name: "openai-ingress-prev-validation", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") 
+ ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, nil) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + defer func() { + _ = clientConn.CloseNow() + }() + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"gpt-5.1","stream":false,"previous_response_id":"msg_123456"}`)) + cancelWrite() + require.NoError(t, err) + + select { + case serverErr := <-serverErrCh: + require.Error(t, serverErr) + var closeErr *OpenAIWSClientCloseError + require.ErrorAs(t, serverErr, &closeErr) + require.Equal(t, coderws.StatusPolicyViolation, closeErr.StatusCode()) + require.Contains(t, closeErr.Reason(), "previous_response_id must be a response.id") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } +} + +type openAIWSQueueDialer struct { + mu sync.Mutex + conns []openAIWSClientConn + dialCount int +} + +func (d *openAIWSQueueDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + defer d.mu.Unlock() + d.dialCount++ + if len(d.conns) == 0 { + return nil, 503, nil, errors.New("no test conn") + } + conn := d.conns[0] + if len(d.conns) 
> 1 { + d.conns = d.conns[1:] + } + return conn, 0, nil, nil +} + +func (d *openAIWSQueueDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSPreflightFailConn struct { + mu sync.Mutex + events [][]byte + pingFails bool + writeCount int + pingCount int +} + +func (c *openAIWSPreflightFailConn) WriteJSON(context.Context, any) error { + c.mu.Lock() + c.writeCount++ + c.mu.Unlock() + return nil +} + +func (c *openAIWSPreflightFailConn) ReadMessage(context.Context) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.events) == 0 { + return nil, io.EOF + } + event := c.events[0] + c.events = c.events[1:] + if len(c.events) == 0 { + c.pingFails = true + } + return event, nil +} + +func (c *openAIWSPreflightFailConn) Ping(context.Context) error { + c.mu.Lock() + defer c.mu.Unlock() + c.pingCount++ + if c.pingFails { + return errors.New("preflight ping failed") + } + return nil +} + +func (c *openAIWSPreflightFailConn) Close() error { + return nil +} + +func (c *openAIWSPreflightFailConn) WriteCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.writeCount +} + +func (c *openAIWSPreflightFailConn) PingCount() int { + c.mu.Lock() + defer c.mu.Unlock() + return c.pingCount +} + +type openAIWSWriteFailAfterFirstTurnConn struct { + mu sync.Mutex + events [][]byte + failOnWrite bool +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) WriteJSON(context.Context, any) error { + c.mu.Lock() + defer c.mu.Unlock() + if c.failOnWrite { + return errors.New("write failed on stale conn") + } + return nil +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) ReadMessage(context.Context) ([]byte, error) { + c.mu.Lock() + defer c.mu.Unlock() + if len(c.events) == 0 { + return nil, io.EOF + } + event := c.events[0] + c.events = c.events[1:] + if len(c.events) == 0 { + c.failOnWrite = true + } + return event, nil +} + +func (c *openAIWSWriteFailAfterFirstTurnConn) Ping(context.Context) error { + return nil +} + +func (c 
*openAIWSWriteFailAfterFirstTurnConn) Close() error { + return nil +} + +func TestOpenAIGatewayService_ProxyResponsesWebSocketFromClient_ClientDisconnectStillDrainsUpstream(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + // 多个上游事件:前几个为非 terminal 事件,最后一个为 terminal。 + // 第一个事件延迟 250ms 让客户端 RST 有时间传播,使 writeClientMessage 可靠失败。 + captureConn := &openAIWSCaptureConn{ + readDelays: []time.Duration{250 * time.Millisecond, 0, 0}, + events: [][]byte{ + []byte(`{"type":"response.created","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1"}}`), + []byte(`{"type":"response.output_item.added","response":{"id":"resp_ingress_disconnect"}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_ingress_disconnect","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 115, + Name: "openai-ingress-client-disconnect", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 
1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "custom-original-model": "gpt-5.1", + }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + serverErrCh := make(chan error, 1) + resultCh := make(chan *OpenAIForwardResult, 1) + hooks := &OpenAIWSIngressHooks{ + AfterTurn: func(_ int, result *OpenAIForwardResult, turnErr error) { + if turnErr == nil && result != nil { + resultCh <- result + } + }, + } + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := coderws.Accept(w, r, &coderws.AcceptOptions{ + CompressionMode: coderws.CompressionContextTakeover, + }) + if err != nil { + serverErrCh <- err + return + } + defer func() { + _ = conn.CloseNow() + }() + + rec := httptest.NewRecorder() + ginCtx, _ := gin.CreateTestContext(rec) + req := r.Clone(r.Context()) + req.Header = req.Header.Clone() + req.Header.Set("User-Agent", "unit-test-agent/1.0") + ginCtx.Request = req + + readCtx, cancel := context.WithTimeout(r.Context(), 3*time.Second) + msgType, firstMessage, readErr := conn.Read(readCtx) + cancel() + if readErr != nil { + serverErrCh <- readErr + return + } + if msgType != coderws.MessageText && msgType != coderws.MessageBinary { + serverErrCh <- errors.New("unsupported websocket client message type") + return + } + + serverErrCh <- svc.ProxyResponsesWebSocketFromClient(r.Context(), ginCtx, conn, account, "sk-test", firstMessage, hooks) + })) + defer wsServer.Close() + + dialCtx, cancelDial := context.WithTimeout(context.Background(), 3*time.Second) + clientConn, _, err := coderws.Dial(dialCtx, "ws"+strings.TrimPrefix(wsServer.URL, "http"), nil) + cancelDial() + require.NoError(t, err) + + writeCtx, cancelWrite := context.WithTimeout(context.Background(), 3*time.Second) + err = clientConn.Write(writeCtx, coderws.MessageText, []byte(`{"type":"response.create","model":"custom-original-model","stream":false}`)) + cancelWrite() + 
require.NoError(t, err) + // 立即关闭客户端,模拟客户端在 relay 期间断连。 + require.NoError(t, clientConn.CloseNow(), "模拟 ingress 客户端提前断连") + + select { + case serverErr := <-serverErrCh: + require.NoError(t, serverErr, "客户端断连后应继续 drain 上游直到 terminal 或正常结束") + case <-time.After(5 * time.Second): + t.Fatal("等待 ingress websocket 结束超时") + } + + select { + case result := <-resultCh: + require.Equal(t, "resp_ingress_disconnect", result.RequestID) + require.Equal(t, 2, result.Usage.InputTokens) + require.Equal(t, 1, result.Usage.OutputTokens) + case <-time.After(2 * time.Second): + t.Fatal("未收到断连后的 turn 结果回调") + } +} diff --git a/backend/internal/service/openai_ws_forwarder_ingress_test.go b/backend/internal/service/openai_ws_forwarder_ingress_test.go new file mode 100644 index 00000000..ff35cb01 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_ingress_test.go @@ -0,0 +1,714 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + coderws "github.com/coder/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestIsOpenAIWSClientDisconnectError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "io_eof", err: io.EOF, want: true}, + {name: "net_closed", err: net.ErrClosed, want: true}, + {name: "context_canceled", err: context.Canceled, want: true}, + {name: "ws_normal_closure", err: coderws.CloseError{Code: coderws.StatusNormalClosure}, want: true}, + {name: "ws_going_away", err: coderws.CloseError{Code: coderws.StatusGoingAway}, want: true}, + {name: "ws_no_status", err: coderws.CloseError{Code: coderws.StatusNoStatusRcvd}, want: true}, + {name: "ws_abnormal_1006", err: coderws.CloseError{Code: coderws.StatusAbnormalClosure}, want: true}, + {name: "ws_policy_violation", err: coderws.CloseError{Code: coderws.StatusPolicyViolation}, want: false}, + 
{name: "wrapped_eof_message", err: errors.New("failed to get reader: failed to read frame header: EOF"), want: true}, + {name: "connection_reset_by_peer", err: errors.New("failed to read frame header: read tcp 127.0.0.1:1234->127.0.0.1:5678: read: connection reset by peer"), want: true}, + {name: "broken_pipe", err: errors.New("write tcp 127.0.0.1:1234->127.0.0.1:5678: write: broken pipe"), want: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + require.Equal(t, tt.want, isOpenAIWSClientDisconnectError(tt.err)) + }) + } +} + +func TestIsOpenAIWSIngressPreviousResponseNotFound(t *testing.T) { + t.Parallel() + + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(nil)) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound(errors.New("plain error"))) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError("read_upstream", errors.New("upstream read failed"), false), + )) + require.False(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), true), + )) + require.True(t, isOpenAIWSIngressPreviousResponseNotFound( + wrapOpenAIWSIngressTurnError(openAIWSIngressStagePreviousResponseNotFound, errors.New("previous response not found"), false), + )) +} + +func TestOpenAIWSIngressPreviousResponseRecoveryEnabled(t *testing.T) { + t.Parallel() + + var nilService *OpenAIGatewayService + require.True(t, nilService.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil service should default to enabled") + + svcWithNilCfg := &OpenAIGatewayService{} + require.True(t, svcWithNilCfg.openAIWSIngressPreviousResponseRecoveryEnabled(), "nil config should default to enabled") + + svc := &OpenAIGatewayService{ + cfg: &config.Config{}, + } + require.False(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled(), "explicit config default should be false") + + 
svc.cfg.Gateway.OpenAIWS.IngressPreviousResponseRecoveryEnabled = true + require.True(t, svc.openAIWSIngressPreviousResponseRecoveryEnabled()) +} + +func TestDropPreviousResponseIDFromRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, removed, err := dropPreviousResponseIDFromRawPayload(nil) + require.NoError(t, err) + require.False(t, removed) + require.Empty(t, updated) + }) + + t.Run("payload_without_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("normal_delete_success", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("duplicate_keys_are_removed", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_a","input":[],"previous_response_id":"resp_b"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("nil_delete_fn_uses_default_delete_logic", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, nil) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) + + t.Run("delete_error", func(t *testing.T) { + payload := 
[]byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_abc"}`) + updated, removed, err := dropPreviousResponseIDFromRawPayloadWithDeleteFn(payload, func(_ []byte, _ string) ([]byte, error) { + return nil, errors.New("delete failed") + }) + require.Error(t, err) + require.False(t, removed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("malformed_json_is_still_best_effort_deleted", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_abc"`) + require.True(t, gjson.GetBytes(payload, "previous_response_id").Exists()) + + updated, removed, err := dropPreviousResponseIDFromRawPayload(payload) + require.NoError(t, err) + require.True(t, removed) + require.False(t, gjson.GetBytes(updated, "previous_response_id").Exists()) + }) +} + +func TestAlignStoreDisabledPreviousResponseID(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, changed, err := alignStoreDisabledPreviousResponseID(nil, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Empty(t, updated) + }) + + t.Run("empty_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("already_aligned", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_target"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, 
"resp_target") + require.NoError(t, err) + require.False(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("mismatch_rewrites_to_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old","input":[]}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) + + t.Run("duplicate_keys_rewrites_to_single_expected", func(t *testing.T) { + payload := []byte(`{"type":"response.create","previous_response_id":"resp_old_1","input":[],"previous_response_id":"resp_old_2"}`) + updated, changed, err := alignStoreDisabledPreviousResponseID(payload, "resp_target") + require.NoError(t, err) + require.True(t, changed) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestSetPreviousResponseIDToRawPayload(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + updated, err := setPreviousResponseIDToRawPayload(nil, "resp_target") + require.NoError(t, err) + require.Empty(t, updated) + }) + + t.Run("empty_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "") + require.NoError(t, err) + require.Equal(t, string(payload), string(updated)) + }) + + t.Run("set_previous_response_id_when_missing", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_target") + require.NoError(t, err) + require.Equal(t, "resp_target", gjson.GetBytes(updated, "previous_response_id").String()) + require.Equal(t, "gpt-5.1", gjson.GetBytes(updated, "model").String()) + }) + + 
t.Run("overwrite_existing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_old"}`) + updated, err := setPreviousResponseIDToRawPayload(payload, "resp_new") + require.NoError(t, err) + require.Equal(t, "resp_new", gjson.GetBytes(updated, "previous_response_id").String()) + }) +} + +func TestShouldInferIngressFunctionCallOutputPreviousResponseID(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + storeDisabled bool + turn int + hasFunctionCallOutput bool + currentPreviousResponse string + expectedPrevious string + want bool + }{ + { + name: "infer_when_all_conditions_match", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: true, + }, + { + name: "skip_when_store_enabled", + storeDisabled: false, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_on_first_turn", + storeDisabled: true, + turn: 1, + hasFunctionCallOutput: true, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_without_function_call_output", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: false, + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_request_already_has_previous_response_id", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + currentPreviousResponse: "resp_client", + expectedPrevious: "resp_1", + want: false, + }, + { + name: "skip_when_last_turn_response_id_missing", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: "", + want: false, + }, + { + name: "trim_whitespace_before_judgement", + storeDisabled: true, + turn: 2, + hasFunctionCallOutput: true, + expectedPrevious: " resp_2 ", + want: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := shouldInferIngressFunctionCallOutputPreviousResponseID( + 
tt.storeDisabled, + tt.turn, + tt.hasFunctionCallOutput, + tt.currentPreviousResponse, + tt.expectedPrevious, + ) + require.Equal(t, tt.want, got) + }) + } +} + +func TestOpenAIWSInputIsPrefixExtended(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + previous []byte + current []byte + want bool + expectErr bool + }{ + { + name: "both_missing_input", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","previous_response_id":"resp_1"}`), + want: true, + }, + { + name: "previous_missing_current_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`), + want: true, + }, + { + name: "previous_missing_current_non_empty_array", + previous: []byte(`{"type":"response.create","model":"gpt-5.1"}`), + current: []byte(`{"type":"response.create","model":"gpt-5.1","input":[{"type":"input_text","text":"hello"}]}`), + want: false, + }, + { + name: "array_prefix_match", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}]}`), + want: true, + }, + { + name: "array_prefix_mismatch", + previous: []byte(`{"input":[{"type":"input_text","text":"hello"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"different"}]}`), + want: false, + }, + { + name: "current_shorter_than_previous", + previous: []byte(`{"input":[{"type":"input_text","text":"a"},{"type":"input_text","text":"b"}]}`), + current: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + want: false, + }, + { + name: "previous_has_input_current_missing", + previous: []byte(`{"input":[{"type":"input_text","text":"a"}]}`), + current: []byte(`{"model":"gpt-5.1"}`), + want: false, + }, + { + name: "input_string_treated_as_single_item", + previous: []byte(`{"input":"hello"}`), + current: 
[]byte(`{"input":"hello"}`), + want: true, + }, + { + name: "current_invalid_input_json", + previous: []byte(`{"input":[]}`), + current: []byte(`{"input":[}`), + expectErr: true, + }, + { + name: "invalid_input_json", + previous: []byte(`{"input":[}`), + current: []byte(`{"input":[]}`), + expectErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := openAIWSInputIsPrefixExtended(tt.previous, tt.current) + if tt.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tt.want, got) + }) + } +} + +func TestNormalizeOpenAIWSJSONForCompare(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSJSONForCompare([]byte(`{"b":2,"a":1}`)) + require.NoError(t, err) + require.Equal(t, `{"a":1,"b":2}`, string(normalized)) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(" ")) + require.Error(t, err) + + _, err = normalizeOpenAIWSJSONForCompare([]byte(`{"a":`)) + require.Error(t, err) +} + +func TestNormalizeOpenAIWSJSONForCompareOrRaw(t *testing.T) { + t.Parallel() + + require.Equal(t, `{"a":1,"b":2}`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"b":2,"a":1}`)))) + require.Equal(t, `{"a":`, string(normalizeOpenAIWSJSONForCompareOrRaw([]byte(`{"a":`)))) +} + +func TestNormalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(t *testing.T) { + t.Parallel() + + normalized, err := normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID( + []byte(`{"model":"gpt-5.1","input":[1],"previous_response_id":"resp_x","metadata":{"b":2,"a":1}}`), + ) + require.NoError(t, err) + require.False(t, gjson.GetBytes(normalized, "input").Exists()) + require.False(t, gjson.GetBytes(normalized, "previous_response_id").Exists()) + require.Equal(t, float64(1), gjson.GetBytes(normalized, "metadata.a").Float()) + + _, err = normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID(nil) + require.Error(t, err) + + _, err = 
normalizeOpenAIWSPayloadWithoutInputAndPreviousResponseID([]byte(`[]`)) + require.Error(t, err) +} + +func TestOpenAIWSExtractNormalizedInputSequence(t *testing.T) { + t.Parallel() + + t.Run("empty_payload", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence(nil) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_missing", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"type":"response.create"}`)) + require.NoError(t, err) + require.False(t, exists) + require.Nil(t, items) + }) + + t.Run("input_array", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[{"type":"input_text","text":"hello"}]}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_object", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":{"type":"input_text","text":"hello"}}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + }) + + t.Run("input_string", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":"hello"}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, `"hello"`, string(items[0])) + }) + + t.Run("input_number", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":42}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "42", string(items[0])) + }) + + t.Run("input_bool", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":true}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "true", string(items[0])) + }) + + t.Run("input_null", func(t *testing.T) { + items, exists, err := 
openAIWSExtractNormalizedInputSequence([]byte(`{"input":null}`)) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "null", string(items[0])) + }) + + t.Run("input_invalid_array_json", func(t *testing.T) { + items, exists, err := openAIWSExtractNormalizedInputSequence([]byte(`{"input":[}`)) + require.Error(t, err) + require.True(t, exists) + require.Nil(t, items) + }) +} + +func TestShouldKeepIngressPreviousResponseID(t *testing.T) { + t.Parallel() + + previousPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "input":[{"type":"input_text","text":"hello"}] + }`) + currentStrictPayload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"name":"tool_a","type":"function"}], + "previous_response_id":"resp_turn_1", + "input":[{"text":"hello","type":"input_text"},{"type":"input_text","text":"world"}] + }`) + + t.Run("strict_incremental_keep", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_1", false) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("missing_previous_response_id", func(t *testing.T) { + payload := []byte(`{"type":"response.create","model":"gpt-5.1","input":[]}`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_response_id", reason) + }) + + t.Run("missing_last_turn_response_id", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_last_turn_response_id", reason) + }) + + t.Run("previous_response_id_mismatch", func(t *testing.T) { + keep, 
reason, err := shouldKeepIngressPreviousResponseID(previousPayload, currentStrictPayload, "resp_turn_other", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "previous_response_id_mismatch", reason) + }) + + t.Run("missing_previous_turn_payload", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(nil, currentStrictPayload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "missing_previous_turn_payload", reason) + }) + + t.Run("non_input_changed", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1-mini", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.False(t, keep) + require.Equal(t, "non_input_changed", reason) + }) + + t.Run("delta_input_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "tools":[{"type":"function","name":"tool_a"}], + "previous_response_id":"resp_turn_1", + "input":[{"type":"input_text","text":"different"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", false) + require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "strict_incremental_ok", reason) + }) + + t.Run("function_call_output_keeps_previous_response_id", func(t *testing.T) { + payload := []byte(`{ + "type":"response.create", + "model":"gpt-5.1", + "store":false, + "previous_response_id":"resp_external", + "input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}] + }`) + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, payload, "resp_turn_1", true) + 
require.NoError(t, err) + require.True(t, keep) + require.Equal(t, "has_function_call_output", reason) + }) + + t.Run("non_input_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID([]byte(`[]`), currentStrictPayload, "resp_turn_1", false) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) + + t.Run("current_payload_compare_error", func(t *testing.T) { + keep, reason, err := shouldKeepIngressPreviousResponseID(previousPayload, []byte(`{"previous_response_id":"resp_turn_1","input":[}`), "resp_turn_1", false) + require.Error(t, err) + require.False(t, keep) + require.Equal(t, "non_input_compare_error", reason) + }) +} + +func TestBuildOpenAIWSReplayInputSequence(t *testing.T) { + t.Parallel() + + lastFull := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + } + + t.Run("no_previous_response_id_use_current", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"input":[{"type":"input_text","text":"new"}]}`), + false, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 1) + require.Equal(t, "new", gjson.GetBytes(items[0], "text").String()) + }) + + t.Run("previous_response_id_delta_append", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + []byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) + + t.Run("previous_response_id_full_input_replace", func(t *testing.T) { + items, exists, err := buildOpenAIWSReplayInputSequence( + lastFull, + true, + 
[]byte(`{"previous_response_id":"resp_1","input":[{"type":"input_text","text":"hello"},{"type":"input_text","text":"world"}]}`), + true, + ) + require.NoError(t, err) + require.True(t, exists) + require.Len(t, items, 2) + require.Equal(t, "hello", gjson.GetBytes(items[0], "text").String()) + require.Equal(t, "world", gjson.GetBytes(items[1], "text").String()) + }) +} + +func TestSetOpenAIWSPayloadInputSequence(t *testing.T) { + t.Parallel() + + t.Run("set_items", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + items := []json.RawMessage{ + json.RawMessage(`{"type":"input_text","text":"hello"}`), + json.RawMessage(`{"type":"input_text","text":"world"}`), + } + updated, err := setOpenAIWSPayloadInputSequence(original, items, true) + require.NoError(t, err) + require.Equal(t, "hello", gjson.GetBytes(updated, "input.0.text").String()) + require.Equal(t, "world", gjson.GetBytes(updated, "input.1.text").String()) + }) + + t.Run("preserve_empty_array_not_null", func(t *testing.T) { + original := []byte(`{"type":"response.create","previous_response_id":"resp_1"}`) + updated, err := setOpenAIWSPayloadInputSequence(original, nil, true) + require.NoError(t, err) + require.True(t, gjson.GetBytes(updated, "input").IsArray()) + require.Len(t, gjson.GetBytes(updated, "input").Array(), 0) + require.False(t, gjson.GetBytes(updated, "input").Type == gjson.Null) + }) +} + +func TestCloneOpenAIWSRawMessages(t *testing.T) { + t.Parallel() + + t.Run("nil_slice", func(t *testing.T) { + cloned := cloneOpenAIWSRawMessages(nil) + require.Nil(t, cloned) + }) + + t.Run("empty_slice", func(t *testing.T) { + items := make([]json.RawMessage, 0) + cloned := cloneOpenAIWSRawMessages(items) + require.NotNil(t, cloned) + require.Len(t, cloned, 0) + }) +} diff --git a/backend/internal/service/openai_ws_forwarder_retry_payload_test.go b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go new file mode 100644 index 
00000000..0ea7e1c7 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_retry_payload_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestApplyOpenAIWSRetryPayloadStrategy_KeepPromptCacheKey(t *testing.T) { + payload := map[string]any{ + "model": "gpt-5.3-codex", + "prompt_cache_key": "pcache_123", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{ + "verbosity": "low", + }, + "tools": []any{map[string]any{"type": "function"}}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 3) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_123", payload["prompt_cache_key"]) + require.NotContains(t, payload, "include") + require.Contains(t, payload, "text") +} + +func TestApplyOpenAIWSRetryPayloadStrategy_AttemptSixKeepsSemanticFields(t *testing.T) { + payload := map[string]any{ + "prompt_cache_key": "pcache_456", + "instructions": "long instructions", + "tools": []any{map[string]any{"type": "function"}}, + "parallel_tool_calls": true, + "tool_choice": "auto", + "include": []any{"reasoning.encrypted_content"}, + "text": map[string]any{"verbosity": "high"}, + } + + strategy, removed := applyOpenAIWSRetryPayloadStrategy(payload, 6) + require.Equal(t, "trim_optional_fields", strategy) + require.Contains(t, removed, "include") + require.NotContains(t, removed, "prompt_cache_key") + require.Equal(t, "pcache_456", payload["prompt_cache_key"]) + require.Contains(t, payload, "instructions") + require.Contains(t, payload, "tools") + require.Contains(t, payload, "tool_choice") + require.Contains(t, payload, "parallel_tool_calls") + require.Contains(t, payload, "text") +} diff --git a/backend/internal/service/openai_ws_forwarder_success_test.go b/backend/internal/service/openai_ws_forwarder_success_test.go new file mode 100644 index 
00000000..592801f6 --- /dev/null +++ b/backend/internal/service/openai_ws_forwarder_success_test.go @@ -0,0 +1,1306 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "io" + "net/http" + "net/http/httptest" + "strconv" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_WSv2_SuccessAndBindSticky(t *testing.T) { + gin.SetMode(gin.TestMode) + + type receivedPayload struct { + Type string + PreviousResponseID string + StreamExists bool + Stream bool + } + receivedCh := make(chan receivedPayload, 1) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + requestJSON := requestToJSONString(request) + receivedCh <- receivedPayload{ + Type: strings.TrimSpace(gjson.Get(requestJSON, "type").String()), + PreviousResponseID: strings.TrimSpace(gjson.Get(requestJSON, "previous_response_id").String()), + StreamExists: gjson.Get(requestJSON, "stream").Exists(), + Stream: gjson.Get(requestJSON, "stream").Bool(), + } + + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_new_1", + "model": "gpt-5.1", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_new_1", + "model": "gpt-5.1", + 
"usage": map[string]any{ + "input_tokens": 12, + "output_tokens": 7, + "input_tokens_details": map[string]any{ + "cached_tokens": 3, + }, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "unit-test-agent/1.0") + groupID := int64(1001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = 3600 + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cache := &stubGatewayCache{} + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: cache, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 9, + Name: "openai-ws", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := 
[]byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 12, result.Usage.InputTokens) + require.Equal(t, 7, result.Usage.OutputTokens) + require.Equal(t, 3, result.Usage.CacheReadInputTokens) + require.Equal(t, "resp_new_1", result.RequestID) + require.True(t, result.OpenAIWSMode) + require.False(t, gjson.GetBytes(upstream.lastBody, "model").Exists(), "WSv2 成功时不应回落 HTTP 上游") + + received := <-receivedCh + require.Equal(t, "response.create", received.Type) + require.Equal(t, "resp_prev_1", received.PreviousResponseID) + require.True(t, received.StreamExists, "WS 请求应携带 stream 字段") + require.False(t, received.Stream, "应保持客户端 stream=false 的原始语义") + + store := svc.getOpenAIWSStateStore() + mappedAccountID, getErr := store.GetResponseAccount(context.Background(), groupID, "resp_new_1") + require.NoError(t, getErr) + require.Equal(t, account.ID, mappedAccountID) + connID, ok := store.GetResponseConn("resp_new_1") + require.True(t, ok) + require.NotEmpty(t, connID) + + responseBody := rec.Body.Bytes() + require.Equal(t, "resp_new_1", gjson.GetBytes(responseBody, "id").String()) +} + +func requestToJSONString(payload map[string]any) string { + if len(payload) == 0 { + return "{}" + } + b, err := json.Marshal(payload) + if err != nil { + return "{}" + } + return string(b) +} + +func TestLogOpenAIWSBindResponseAccountWarn(t *testing.T) { + require.NotPanics(t, func() { + logOpenAIWSBindResponseAccountWarn(1, 2, "resp_ok", nil) + }) + require.NotPanics(t, func() { + logOpenAIWSBindResponseAccountWarn(1, 2, "resp_err", errors.New("bind failed")) + }) +} + +func TestOpenAIGatewayService_Forward_WSv2_RewriteModelAndToolCallsOnCompletedEvent(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + groupID := int64(3001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_model_tool_1","model":"gpt-5.1","tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}],"usage":{"input_tokens":2,"output_tokens":1}},"tool_calls":[{"function":{"name":"apply_patch","arguments":"{\"file_path\":\"/tmp/a.txt\",\"old_string\":\"a\",\"new_string\":\"b\"}"}}]}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 1301, + Name: "openai-rewrite", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "model_mapping": map[string]any{ + "custom-original-model": "gpt-5.1", 
+ }, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"custom-original-model","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_model_tool_1", result.RequestID) + require.Equal(t, "custom-original-model", gjson.GetBytes(rec.Body.Bytes(), "model").String(), "响应模型应回写为原始请求模型") + require.Equal(t, "edit", gjson.GetBytes(rec.Body.Bytes(), "tool_calls.0.function.name").String(), "工具名称应被修正为 OpenCode 规范") +} + +func TestOpenAIWSPayloadString_OnlyAcceptsStringValues(t *testing.T) { + payload := map[string]any{ + "type": nil, + "model": 123, + "prompt_cache_key": " cache-key ", + "previous_response_id": []byte(" resp_1 "), + } + + require.Equal(t, "", openAIWSPayloadString(payload, "type")) + require.Equal(t, "", openAIWSPayloadString(payload, "model")) + require.Equal(t, "cache-key", openAIWSPayloadString(payload, "prompt_cache_key")) + require.Equal(t, "resp_1", openAIWSPayloadString(payload, "previous_response_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_PoolReuseNotOneToOne(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + idx := sequence.Add(1) + responseID := "resp_reuse_" + strconv.FormatInt(idx, 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": 
responseID, + "model": "gpt-5.1", + }, + }); err != nil { + return + } + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 2, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 30 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 10 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 19, + Name: "openai-ws", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + for i := 0; i < 2; i++ { + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + groupID := int64(2001) + c.Set("api_key", &APIKey{GroupID: &groupID}) + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_reuse","input":[{"type":"input_text","text":"hello"}]}`) + 
result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, strings.HasPrefix(result.RequestID, "resp_reuse_")) + } + + require.Equal(t, int64(1), upgradeCount.Load(), "多个客户端请求应复用账号连接池而不是 1:1 对等建链") + metrics := svc.SnapshotOpenAIWSPoolMetrics() + require.GreaterOrEqual(t, metrics.AcquireReuseTotal, int64(1)) + require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) +} + +func TestOpenAIGatewayService_Forward_WSv2_OAuthStoreFalseByDefault(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + c.Request.Header.Set("session_id", "sess-oauth-1") + c.Request.Header.Set("conversation_id", "conv-oauth-1") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.AllowStoreRecovery = false + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_oauth_1","model":"gpt-5.1","usage":{"input_tokens":3,"output_tokens":2}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 
29, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_oauth_1", result.RequestID) + + require.NotNil(t, captureConn.lastWrite) + requestJSON := requestToJSONString(captureConn.lastWrite) + require.True(t, gjson.Get(requestJSON, "store").Exists(), "OAuth WSv2 应显式写入 store 字段") + require.False(t, gjson.Get(requestJSON, "store").Bool(), "默认策略应将 OAuth store 置为 false") + require.True(t, gjson.Get(requestJSON, "stream").Exists(), "WSv2 payload 应保留 stream 字段") + require.True(t, gjson.Get(requestJSON, "stream").Bool(), "OAuth Codex 规范化后应强制 stream=true") + require.Equal(t, openAIWSBetaV2Value, captureDialer.lastHeaders.Get("OpenAI-Beta")) + require.Equal(t, "sess-oauth-1", captureDialer.lastHeaders.Get("session_id")) + require.Equal(t, "conv-oauth-1", captureDialer.lastHeaders.Get("conversation_id")) +} + +func TestOpenAIGatewayService_Forward_WSv2_HeaderSessionFallbackFromPromptCacheKey(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + 
cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_prompt_cache_key","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + account := &Account{ + ID: 31, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "access_token": "oauth-token-1", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":true,"prompt_cache_key":"pcache_123","input":[{"type":"input_text","text":"hi"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_prompt_cache_key", result.RequestID) + + require.Equal(t, "pcache_123", captureDialer.lastHeaders.Get("session_id")) + require.Empty(t, captureDialer.lastHeaders.Get("conversation_id")) + require.NotNil(t, captureConn.lastWrite) + require.True(t, gjson.Get(requestToJSONString(captureConn.lastWrite), "stream").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv1_Unsupported(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + 
cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 39, + Name: "openai-ws-v1", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_prev_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestOpenAIGatewayService_Forward_WSv2_TurnStateAndMetadataReplayOnReconnect(t *testing.T) { + gin.SetMode(gin.TestMode) + + var connIndex atomic.Int64 + headersCh := make(chan http.Header, 4) + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idx := connIndex.Add(1) + 
headersCh <- cloneHeader(r.Header) + + respHeader := http.Header{} + if idx == 1 { + respHeader.Set("x-codex-turn-state", "turn_state_first") + } + conn, err := upgrader.Upgrade(w, r, respHeader) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + responseID := "resp_turn_" + strconv.FormatInt(idx, 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 2, + "output_tokens": 1, + }, + }, + }); err != nil { + t.Errorf("write response.completed failed: %v", err) + return + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 49, + Name: "openai-turn-state", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + reqBody := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + rec1 := 
httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_turn_state") + c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_1") + result1, err := svc.Forward(context.Background(), c1, account, reqBody) + require.NoError(t, err) + require.NotNil(t, result1) + + sessionHash := svc.GenerateSessionHash(c1, reqBody) + store := svc.getOpenAIWSStateStore() + turnState, ok := store.GetSessionTurnState(0, sessionHash) + require.True(t, ok) + require.Equal(t, "turn_state_first", turnState) + + // 主动淘汰连接,模拟下一次请求发生重连。 + connID, hasConn := store.GetResponseConn(result1.RequestID) + require.True(t, hasConn) + svc.getOpenAIWSConnPool().evictConn(account.ID, connID) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_turn_state") + c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_2") + result2, err := svc.Forward(context.Background(), c2, account, reqBody) + require.NoError(t, err) + require.NotNil(t, result2) + + firstHandshakeHeaders := <-headersCh + secondHandshakeHeaders := <-headersCh + require.Equal(t, "turn_meta_1", firstHandshakeHeaders.Get("X-Codex-Turn-Metadata")) + require.Equal(t, "turn_meta_2", secondHandshakeHeaders.Get("X-Codex-Turn-Metadata")) + require.Equal(t, "turn_state_first", secondHandshakeHeaders.Get("X-Codex-Turn-State")) +} + +func TestOpenAIGatewayService_Forward_WSv2_GeneratePrewarm(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("session_id", "session-prewarm") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + 
cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_prewarm_1","model":"gpt-5.1","usage":{"input_tokens":0,"output_tokens":0}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_main_1","model":"gpt-5.1","usage":{"input_tokens":4,"output_tokens":2}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 59, + Name: "openai-prewarm", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_main_1", result.RequestID) + + require.Len(t, captureConn.writes, 2, "开启 generate=false 预热后应发送两次 WS 请求") + firstWrite := requestToJSONString(captureConn.writes[0]) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.True(t, gjson.Get(firstWrite, "generate").Exists()) + require.False(t, gjson.Get(firstWrite, 
"generate").Bool()) + require.False(t, gjson.Get(secondWrite, "generate").Exists()) +} + +func TestOpenAIGatewayService_PrewarmReadHonorsParentContext(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.PrewarmGenerateEnabled = true + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 5 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + svc := &OpenAIGatewayService{ + cfg: cfg, + toolCorrector: NewCodexToolCorrector(), + } + account := &Account{ + ID: 601, + Name: "openai-prewarm-timeout", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + } + conn := newOpenAIWSConn("prewarm_ctx_conn", account.ID, &openAIWSBlockingConn{ + readDelay: 200 * time.Millisecond, + }, nil) + lease := &openAIWSConnLease{ + accountID: account.ID, + conn: conn, + } + payload := map[string]any{ + "type": "response.create", + "model": "gpt-5.1", + } + + ctx, cancel := context.WithTimeout(context.Background(), 40*time.Millisecond) + defer cancel() + start := time.Now() + err := svc.performOpenAIWSGeneratePrewarm( + ctx, + lease, + OpenAIWSProtocolDecision{Transport: OpenAIUpstreamTransportResponsesWebsocketV2}, + payload, + "", + map[string]any{"model": "gpt-5.1"}, + account, + nil, + 0, + ) + elapsed := time.Since(start) + require.Error(t, err) + require.Contains(t, err.Error(), "prewarm_read_event") + require.Less(t, elapsed, 180*time.Millisecond, "预热读取应受父 context 取消控制,不应阻塞到 read_timeout") +} + +func TestOpenAIGatewayService_Forward_WSv2_TurnMetadataInPayloadOnConnReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + 
cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + + captureConn := &openAIWSCaptureConn{ + events: [][]byte{ + []byte(`{"type":"response.completed","response":{"id":"resp_meta_1","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_meta_2","model":"gpt-5.1","usage":{"input_tokens":1,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 69, + Name: "openai-turn-metadata", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session-metadata-reuse") + c1.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_1") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, "resp_meta_1", result1.RequestID) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session-metadata-reuse") + c2.Request.Header.Set("x-codex-turn-metadata", "turn_meta_payload_2") + result2, err := 
svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, "resp_meta_2", result2.RequestID) + + require.Equal(t, 1, captureDialer.DialCount(), "同一账号两轮请求应复用同一 WS 连接") + require.Len(t, captureConn.writes, 2) + + firstWrite := requestToJSONString(captureConn.writes[0]) + secondWrite := requestToJSONString(captureConn.writes[1]) + require.Equal(t, "turn_meta_payload_1", gjson.Get(firstWrite, "client_metadata.x-codex-turn-metadata").String()) + require.Equal(t, "turn_meta_payload_2", gjson.Get(secondWrite, "client_metadata.x-codex-turn-metadata").String()) +} + +func TestOpenAIGatewayService_Forward_WSv2StoreFalseSessionConnIsolation(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + responseID := "resp_store_false_" + strconv.FormatInt(sequence.Add(1), 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + 
cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 + cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 79, + Name: "openai-store-false", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_store_false_a") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, int64(1), upgradeCount.Load()) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_store_false_a") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, int64(1), upgradeCount.Load(), "同一 session(store=false) 应复用同一 WS 连接") + + rec3 := httptest.NewRecorder() + c3, _ := gin.CreateTestContext(rec3) + c3.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c3.Request.Header.Set("session_id", "session_store_false_b") + result3, err := svc.Forward(context.Background(), c3, account, body) + 
require.NoError(t, err) + require.NotNil(t, result3) + require.Equal(t, int64(2), upgradeCount.Load(), "不同 session(store=false) 应隔离连接,避免续链状态互相覆盖") +} + +func TestOpenAIGatewayService_Forward_WSv2StoreFalseDisableForceNewConnAllowsReuse(t *testing.T) { + gin.SetMode(gin.TestMode) + + var upgradeCount atomic.Int64 + var sequence atomic.Int64 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + upgradeCount.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + for { + var request map[string]any + if err := conn.ReadJSON(&request); err != nil { + return + } + responseID := "resp_store_false_reuse_" + strconv.FormatInt(sequence.Add(1), 10) + if err := conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": responseID, + "model": "gpt-5.1", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + }, + }, + }); err != nil { + return + } + } + })) + defer wsServer.Close() + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: &httpUpstreamRecorder{}, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 80, + Name: "openai-store-false-reuse", + Platform: 
PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 2, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"store":false,"input":[{"type":"input_text","text":"hello"}]}`) + + rec1 := httptest.NewRecorder() + c1, _ := gin.CreateTestContext(rec1) + c1.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c1.Request.Header.Set("session_id", "session_store_false_reuse_a") + result1, err := svc.Forward(context.Background(), c1, account, body) + require.NoError(t, err) + require.NotNil(t, result1) + require.Equal(t, int64(1), upgradeCount.Load()) + + rec2 := httptest.NewRecorder() + c2, _ := gin.CreateTestContext(rec2) + c2.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c2.Request.Header.Set("session_id", "session_store_false_reuse_b") + result2, err := svc.Forward(context.Background(), c2, account, body) + require.NoError(t, err) + require.NotNil(t, result2) + require.Equal(t, int64(1), upgradeCount.Load(), "关闭强制新连后,不同 session(store=false) 可复用连接") +} + +func TestOpenAIGatewayService_Forward_WSv2ReadTimeoutAppliesPerRead(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "codex_cli_rs/0.98.0") + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + 
cfg.Gateway.OpenAIWS.QueueLimitPerConn = 8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 3 + cfg.Gateway.OpenAIWS.ReadTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.WriteTimeoutSeconds = 3 + + captureConn := &openAIWSCaptureConn{ + readDelays: []time.Duration{ + 700 * time.Millisecond, + 700 * time.Millisecond, + }, + events: [][]byte{ + []byte(`{"type":"response.created","response":{"id":"resp_timeout_ok","model":"gpt-5.1"}}`), + []byte(`{"type":"response.completed","response":{"id":"resp_timeout_ok","model":"gpt-5.1","usage":{"input_tokens":2,"output_tokens":1}}}`), + }, + } + captureDialer := &openAIWSCaptureDialer{conn: captureConn} + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(captureDialer) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + cache: &stubGatewayCache{}, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + openaiWSPool: pool, + } + + account := &Account{ + ID: 81, + Name: "openai-read-timeout", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Status: StatusActive, + Schedulable: true, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_timeout_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "每次 Read 都应独立应用超时;总时长超过 read_timeout 不应误回退 HTTP") +} + +type openAIWSCaptureDialer struct { + mu sync.Mutex + conn 
*openAIWSCaptureConn + lastHeaders http.Header + handshake http.Header + dialCount int +} + +func (d *openAIWSCaptureDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = proxyURL + d.mu.Lock() + d.lastHeaders = cloneHeader(headers) + d.dialCount++ + respHeaders := cloneHeader(d.handshake) + d.mu.Unlock() + return d.conn, 0, respHeaders, nil +} + +func (d *openAIWSCaptureDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSCaptureConn struct { + mu sync.Mutex + readDelays []time.Duration + events [][]byte + lastWrite map[string]any + writes []map[string]any + closed bool +} + +func (c *openAIWSCaptureConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errOpenAIWSConnClosed + } + switch payload := value.(type) { + case map[string]any: + c.lastWrite = cloneMapStringAny(payload) + c.writes = append(c.writes, cloneMapStringAny(payload)) + case json.RawMessage: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.lastWrite = cloneMapStringAny(parsed) + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + case []byte: + var parsed map[string]any + if err := json.Unmarshal(payload, &parsed); err == nil { + c.lastWrite = cloneMapStringAny(parsed) + c.writes = append(c.writes, cloneMapStringAny(parsed)) + } + } + return nil +} + +func (c *openAIWSCaptureConn) ReadMessage(ctx context.Context) ([]byte, error) { + if ctx == nil { + ctx = context.Background() + } + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return nil, errOpenAIWSConnClosed + } + if len(c.events) == 0 { + c.mu.Unlock() + return nil, io.EOF + } + delay := time.Duration(0) + if len(c.readDelays) > 0 { + delay = c.readDelays[0] + c.readDelays = c.readDelays[1:] + } + event := c.events[0] + c.events = c.events[1:] + 
c.mu.Unlock() + if delay > 0 { + timer := time.NewTimer(delay) + defer timer.Stop() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + } + } + return event, nil +} + +func (c *openAIWSCaptureConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSCaptureConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +func cloneMapStringAny(src map[string]any) map[string]any { + if src == nil { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} diff --git a/backend/internal/service/openai_ws_pool.go b/backend/internal/service/openai_ws_pool.go new file mode 100644 index 00000000..db6a96a7 --- /dev/null +++ b/backend/internal/service/openai_ws_pool.go @@ -0,0 +1,1706 @@ +package service + +import ( + "context" + "errors" + "fmt" + "math" + "net/http" + "sort" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "golang.org/x/sync/errgroup" +) + +const ( + openAIWSConnMaxAge = 60 * time.Minute + openAIWSConnHealthCheckIdle = 90 * time.Second + openAIWSConnHealthCheckTO = 2 * time.Second + openAIWSConnPrewarmExtraDelay = 2 * time.Second + openAIWSAcquireCleanupInterval = 3 * time.Second + openAIWSBackgroundPingInterval = 30 * time.Second + openAIWSBackgroundSweepTicker = 30 * time.Second + + openAIWSPrewarmFailureWindow = 30 * time.Second + openAIWSPrewarmFailureSuppress = 2 +) + +var ( + errOpenAIWSConnClosed = errors.New("openai ws connection closed") + errOpenAIWSConnQueueFull = errors.New("openai ws connection queue full") + errOpenAIWSPreferredConnUnavailable = errors.New("openai ws preferred connection unavailable") +) + +type openAIWSDialError struct { + StatusCode int + ResponseHeaders http.Header + Err error +} + +func (e *openAIWSDialError) Error() string { + if e == nil { + return "" + } + if e.StatusCode > 0 { + return fmt.Sprintf("openai ws dial failed: 
status=%d err=%v", e.StatusCode, e.Err) + } + return fmt.Sprintf("openai ws dial failed: %v", e.Err) +} + +func (e *openAIWSDialError) Unwrap() error { + if e == nil { + return nil + } + return e.Err +} + +type openAIWSAcquireRequest struct { + Account *Account + WSURL string + Headers http.Header + ProxyURL string + PreferredConnID string + // ForceNewConn: 强制本次获取新连接(避免复用导致连接内续链状态互相污染)。 + ForceNewConn bool + // ForcePreferredConn: 强制本次只使用 PreferredConnID,禁止漂移到其它连接。 + ForcePreferredConn bool +} + +type openAIWSConnLease struct { + pool *openAIWSConnPool + accountID int64 + conn *openAIWSConn + queueWait time.Duration + connPick time.Duration + reused bool + released atomic.Bool +} + +func (l *openAIWSConnLease) activeConn() (*openAIWSConn, error) { + if l == nil || l.conn == nil { + return nil, errOpenAIWSConnClosed + } + if l.released.Load() { + return nil, errOpenAIWSConnClosed + } + return l.conn, nil +} + +func (l *openAIWSConnLease) ConnID() string { + if l == nil || l.conn == nil { + return "" + } + return l.conn.id +} + +func (l *openAIWSConnLease) QueueWaitDuration() time.Duration { + if l == nil { + return 0 + } + return l.queueWait +} + +func (l *openAIWSConnLease) ConnPickDuration() time.Duration { + if l == nil { + return 0 + } + return l.connPick +} + +func (l *openAIWSConnLease) Reused() bool { + if l == nil { + return false + } + return l.reused +} + +func (l *openAIWSConnLease) HandshakeHeader(name string) string { + if l == nil || l.conn == nil { + return "" + } + return l.conn.handshakeHeader(name) +} + +func (l *openAIWSConnLease) IsPrewarmed() bool { + if l == nil || l.conn == nil { + return false + } + return l.conn.isPrewarmed() +} + +func (l *openAIWSConnLease) MarkPrewarmed() { + if l == nil || l.conn == nil { + return + } + l.conn.markPrewarmed() +} + +func (l *openAIWSConnLease) WriteJSON(value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return 
conn.writeJSONWithTimeout(context.Background(), value, timeout) +} + +func (l *openAIWSConnLease) WriteJSONWithContextTimeout(ctx context.Context, value any, timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSONWithTimeout(ctx, value, timeout) +} + +func (l *openAIWSConnLease) WriteJSONContext(ctx context.Context, value any) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.writeJSON(value, ctx) +} + +func (l *openAIWSConnLease) ReadMessage(timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessageWithTimeout(timeout) +} + +func (l *openAIWSConnLease) ReadMessageContext(ctx context.Context) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessage(ctx) +} + +func (l *openAIWSConnLease) ReadMessageWithContextTimeout(ctx context.Context, timeout time.Duration) ([]byte, error) { + conn, err := l.activeConn() + if err != nil { + return nil, err + } + return conn.readMessageWithContextTimeout(ctx, timeout) +} + +func (l *openAIWSConnLease) PingWithTimeout(timeout time.Duration) error { + conn, err := l.activeConn() + if err != nil { + return err + } + return conn.pingWithTimeout(timeout) +} + +func (l *openAIWSConnLease) MarkBroken() { + if l == nil || l.pool == nil || l.conn == nil || l.released.Load() { + return + } + l.pool.evictConn(l.accountID, l.conn.id) +} + +func (l *openAIWSConnLease) Release() { + if l == nil || l.conn == nil { + return + } + if !l.released.CompareAndSwap(false, true) { + return + } + l.conn.release() +} + +type openAIWSConn struct { + id string + ws openAIWSClientConn + + handshakeHeaders http.Header + + leaseCh chan struct{} + closedCh chan struct{} + closeOnce sync.Once + + readMu sync.Mutex + writeMu sync.Mutex + + waiters atomic.Int32 + createdAtNano atomic.Int64 + lastUsedNano atomic.Int64 + prewarmed 
atomic.Bool +} + +func newOpenAIWSConn(id string, _ int64, ws openAIWSClientConn, handshakeHeaders http.Header) *openAIWSConn { + now := time.Now() + conn := &openAIWSConn{ + id: id, + ws: ws, + handshakeHeaders: cloneHeader(handshakeHeaders), + leaseCh: make(chan struct{}, 1), + closedCh: make(chan struct{}), + } + conn.leaseCh <- struct{}{} + conn.createdAtNano.Store(now.UnixNano()) + conn.lastUsedNano.Store(now.UnixNano()) + return conn +} + +func (c *openAIWSConn) tryAcquire() bool { + if c == nil { + return false + } + select { + case <-c.closedCh: + return false + default: + } + select { + case <-c.leaseCh: + select { + case <-c.closedCh: + c.release() + return false + default: + } + return true + default: + return false + } +} + +func (c *openAIWSConn) acquire(ctx context.Context) error { + if c == nil { + return errOpenAIWSConnClosed + } + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.closedCh: + return errOpenAIWSConnClosed + case <-c.leaseCh: + select { + case <-c.closedCh: + c.release() + return errOpenAIWSConnClosed + default: + } + return nil + } + } +} + +func (c *openAIWSConn) release() { + if c == nil { + return + } + select { + case c.leaseCh <- struct{}{}: + default: + } + c.touch() +} + +func (c *openAIWSConn) close() { + if c == nil { + return + } + c.closeOnce.Do(func() { + close(c.closedCh) + if c.ws != nil { + _ = c.ws.Close() + } + select { + case c.leaseCh <- struct{}{}: + default: + } + }) +} + +func (c *openAIWSConn) writeJSONWithTimeout(parent context.Context, value any, timeout time.Duration) error { + if c == nil { + return errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return errOpenAIWSConnClosed + default: + } + + writeCtx := parent + if writeCtx == nil { + writeCtx = context.Background() + } + if timeout <= 0 { + return c.writeJSON(value, writeCtx) + } + var cancel context.CancelFunc + writeCtx, cancel = context.WithTimeout(writeCtx, timeout) + defer cancel() + return c.writeJSON(value, writeCtx) 
+} + +func (c *openAIWSConn) writeJSON(value any, writeCtx context.Context) error { + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.ws == nil { + return errOpenAIWSConnClosed + } + if writeCtx == nil { + writeCtx = context.Background() + } + if err := c.ws.WriteJSON(writeCtx, value); err != nil { + return err + } + c.touch() + return nil +} + +func (c *openAIWSConn) readMessageWithTimeout(timeout time.Duration) ([]byte, error) { + return c.readMessageWithContextTimeout(context.Background(), timeout) +} + +func (c *openAIWSConn) readMessageWithContextTimeout(parent context.Context, timeout time.Duration) ([]byte, error) { + if c == nil { + return nil, errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return nil, errOpenAIWSConnClosed + default: + } + + if parent == nil { + parent = context.Background() + } + if timeout <= 0 { + return c.readMessage(parent) + } + readCtx, cancel := context.WithTimeout(parent, timeout) + defer cancel() + return c.readMessage(readCtx) +} + +func (c *openAIWSConn) readMessage(readCtx context.Context) ([]byte, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + if c.ws == nil { + return nil, errOpenAIWSConnClosed + } + if readCtx == nil { + readCtx = context.Background() + } + payload, err := c.ws.ReadMessage(readCtx) + if err != nil { + return nil, err + } + c.touch() + return payload, nil +} + +func (c *openAIWSConn) pingWithTimeout(timeout time.Duration) error { + if c == nil { + return errOpenAIWSConnClosed + } + select { + case <-c.closedCh: + return errOpenAIWSConnClosed + default: + } + + c.writeMu.Lock() + defer c.writeMu.Unlock() + if c.ws == nil { + return errOpenAIWSConnClosed + } + if timeout <= 0 { + timeout = openAIWSConnHealthCheckTO + } + pingCtx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + if err := c.ws.Ping(pingCtx); err != nil { + return err + } + return nil +} + +func (c *openAIWSConn) touch() { + if c == nil { + return + } + 
c.lastUsedNano.Store(time.Now().UnixNano()) +} + +func (c *openAIWSConn) createdAt() time.Time { + if c == nil { + return time.Time{} + } + nano := c.createdAtNano.Load() + if nano <= 0 { + return time.Time{} + } + return time.Unix(0, nano) +} + +func (c *openAIWSConn) lastUsedAt() time.Time { + if c == nil { + return time.Time{} + } + nano := c.lastUsedNano.Load() + if nano <= 0 { + return time.Time{} + } + return time.Unix(0, nano) +} + +func (c *openAIWSConn) idleDuration(now time.Time) time.Duration { + if c == nil { + return 0 + } + last := c.lastUsedAt() + if last.IsZero() { + return 0 + } + return now.Sub(last) +} + +func (c *openAIWSConn) age(now time.Time) time.Duration { + if c == nil { + return 0 + } + created := c.createdAt() + if created.IsZero() { + return 0 + } + return now.Sub(created) +} + +func (c *openAIWSConn) isLeased() bool { + if c == nil { + return false + } + return len(c.leaseCh) == 0 +} + +func (c *openAIWSConn) handshakeHeader(name string) string { + if c == nil || c.handshakeHeaders == nil { + return "" + } + return strings.TrimSpace(c.handshakeHeaders.Get(strings.TrimSpace(name))) +} + +func (c *openAIWSConn) isPrewarmed() bool { + if c == nil { + return false + } + return c.prewarmed.Load() +} + +func (c *openAIWSConn) markPrewarmed() { + if c == nil { + return + } + c.prewarmed.Store(true) +} + +type openAIWSAccountPool struct { + mu sync.Mutex + conns map[string]*openAIWSConn + pinnedConns map[string]int + creating int + lastCleanupAt time.Time + lastAcquire *openAIWSAcquireRequest + prewarmActive bool + prewarmUntil time.Time + prewarmFails int + prewarmFailAt time.Time +} + +type OpenAIWSPoolMetricsSnapshot struct { + AcquireTotal int64 + AcquireReuseTotal int64 + AcquireCreateTotal int64 + AcquireQueueWaitTotal int64 + AcquireQueueWaitMsTotal int64 + ConnPickTotal int64 + ConnPickMsTotal int64 + ScaleUpTotal int64 + ScaleDownTotal int64 +} + +type openAIWSPoolMetrics struct { + acquireTotal atomic.Int64 + acquireReuseTotal 
atomic.Int64 + acquireCreateTotal atomic.Int64 + acquireQueueWaitTotal atomic.Int64 + acquireQueueWaitMs atomic.Int64 + connPickTotal atomic.Int64 + connPickMs atomic.Int64 + scaleUpTotal atomic.Int64 + scaleDownTotal atomic.Int64 +} + +type openAIWSConnPool struct { + cfg *config.Config + // 通过接口解耦底层 WS 客户端实现,默认使用 coder/websocket。 + clientDialer openAIWSClientDialer + + accounts sync.Map // key: int64(accountID), value: *openAIWSAccountPool + seq atomic.Uint64 + + metrics openAIWSPoolMetrics + + workerStopCh chan struct{} + workerWg sync.WaitGroup + closeOnce sync.Once +} + +func newOpenAIWSConnPool(cfg *config.Config) *openAIWSConnPool { + pool := &openAIWSConnPool{ + cfg: cfg, + clientDialer: newDefaultOpenAIWSClientDialer(), + workerStopCh: make(chan struct{}), + } + pool.startBackgroundWorkers() + return pool +} + +func (p *openAIWSConnPool) SnapshotMetrics() OpenAIWSPoolMetricsSnapshot { + if p == nil { + return OpenAIWSPoolMetricsSnapshot{} + } + return OpenAIWSPoolMetricsSnapshot{ + AcquireTotal: p.metrics.acquireTotal.Load(), + AcquireReuseTotal: p.metrics.acquireReuseTotal.Load(), + AcquireCreateTotal: p.metrics.acquireCreateTotal.Load(), + AcquireQueueWaitTotal: p.metrics.acquireQueueWaitTotal.Load(), + AcquireQueueWaitMsTotal: p.metrics.acquireQueueWaitMs.Load(), + ConnPickTotal: p.metrics.connPickTotal.Load(), + ConnPickMsTotal: p.metrics.connPickMs.Load(), + ScaleUpTotal: p.metrics.scaleUpTotal.Load(), + ScaleDownTotal: p.metrics.scaleDownTotal.Load(), + } +} + +func (p *openAIWSConnPool) SnapshotTransportMetrics() OpenAIWSTransportMetricsSnapshot { + if p == nil { + return OpenAIWSTransportMetricsSnapshot{} + } + if dialer, ok := p.clientDialer.(openAIWSTransportMetricsDialer); ok { + return dialer.SnapshotTransportMetrics() + } + return OpenAIWSTransportMetricsSnapshot{} +} + +func (p *openAIWSConnPool) setClientDialerForTest(dialer openAIWSClientDialer) { + if p == nil || dialer == nil { + return + } + p.clientDialer = dialer +} + +// Close 停止后台 
worker 并关闭所有空闲连接,应在优雅关闭时调用。 +func (p *openAIWSConnPool) Close() { + if p == nil { + return + } + p.closeOnce.Do(func() { + if p.workerStopCh != nil { + close(p.workerStopCh) + } + p.workerWg.Wait() + // 遍历所有账户池,关闭全部空闲连接。 + p.accounts.Range(func(key, value any) bool { + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + ap.mu.Lock() + for _, conn := range ap.conns { + if conn != nil && !conn.isLeased() { + conn.close() + } + } + ap.mu.Unlock() + return true + }) + }) +} + +func (p *openAIWSConnPool) startBackgroundWorkers() { + if p == nil || p.workerStopCh == nil { + return + } + p.workerWg.Add(2) + go func() { + defer p.workerWg.Done() + p.runBackgroundPingWorker() + }() + go func() { + defer p.workerWg.Done() + p.runBackgroundCleanupWorker() + }() +} + +type openAIWSIdlePingCandidate struct { + accountID int64 + conn *openAIWSConn +} + +func (p *openAIWSConnPool) runBackgroundPingWorker() { + if p == nil { + return + } + ticker := time.NewTicker(openAIWSBackgroundPingInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.runBackgroundPingSweep() + case <-p.workerStopCh: + return + } + } +} + +func (p *openAIWSConnPool) runBackgroundPingSweep() { + if p == nil { + return + } + candidates := p.snapshotIdleConnsForPing() + var g errgroup.Group + g.SetLimit(10) + for _, item := range candidates { + item := item + if item.conn == nil || item.conn.isLeased() || item.conn.waiters.Load() > 0 { + continue + } + g.Go(func() error { + if err := item.conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + p.evictConn(item.accountID, item.conn.id) + } + return nil + }) + } + _ = g.Wait() +} + +func (p *openAIWSConnPool) snapshotIdleConnsForPing() []openAIWSIdlePingCandidate { + if p == nil { + return nil + } + candidates := make([]openAIWSIdlePingCandidate, 0) + p.accounts.Range(func(key, value any) bool { + accountID, ok := key.(int64) + if !ok || accountID <= 0 { + return true + } + ap, ok := 
value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + ap.mu.Lock() + for _, conn := range ap.conns { + if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 { + continue + } + candidates = append(candidates, openAIWSIdlePingCandidate{ + accountID: accountID, + conn: conn, + }) + } + ap.mu.Unlock() + return true + }) + return candidates +} + +func (p *openAIWSConnPool) runBackgroundCleanupWorker() { + if p == nil { + return + } + ticker := time.NewTicker(openAIWSBackgroundSweepTicker) + defer ticker.Stop() + for { + select { + case <-ticker.C: + p.runBackgroundCleanupSweep(time.Now()) + case <-p.workerStopCh: + return + } + } +} + +func (p *openAIWSConnPool) runBackgroundCleanupSweep(now time.Time) { + if p == nil { + return + } + type cleanupResult struct { + evicted []*openAIWSConn + } + results := make([]cleanupResult, 0) + p.accounts.Range(func(_ any, value any) bool { + ap, ok := value.(*openAIWSAccountPool) + if !ok || ap == nil { + return true + } + maxConns := p.maxConnsHardCap() + ap.mu.Lock() + if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { + maxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) + } + evicted := p.cleanupAccountLocked(ap, now, maxConns) + ap.lastCleanupAt = now + ap.mu.Unlock() + if len(evicted) > 0 { + results = append(results, cleanupResult{evicted: evicted}) + } + return true + }) + for _, result := range results { + closeOpenAIWSConns(result.evicted) + } +} + +func (p *openAIWSConnPool) Acquire(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConnLease, error) { + if p != nil { + p.metrics.acquireTotal.Add(1) + } + return p.acquire(ctx, cloneOpenAIWSAcquireRequest(req), 0) +} + +func (p *openAIWSConnPool) acquire(ctx context.Context, req openAIWSAcquireRequest, retry int) (*openAIWSConnLease, error) { + if p == nil || req.Account == nil || req.Account.ID <= 0 { + return nil, errors.New("invalid ws acquire request") + } + if stringsTrim(req.WSURL) == "" { + return nil, 
errors.New("ws url is empty") + } + + accountID := req.Account.ID + effectiveMaxConns := p.effectiveMaxConnsByAccount(req.Account) + if effectiveMaxConns <= 0 { + return nil, errOpenAIWSConnQueueFull + } + var evicted []*openAIWSConn + ap := p.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = cloneOpenAIWSAcquireRequestPtr(&req) + now := time.Now() + if ap.lastCleanupAt.IsZero() || now.Sub(ap.lastCleanupAt) >= openAIWSAcquireCleanupInterval { + evicted = p.cleanupAccountLocked(ap, now, effectiveMaxConns) + ap.lastCleanupAt = now + } + pickStartedAt := time.Now() + allowReuse := !req.ForceNewConn + preferredConnID := stringsTrim(req.PreferredConnID) + forcePreferredConn := allowReuse && req.ForcePreferredConn + + if allowReuse { + if forcePreferredConn { + if preferredConnID == "" { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSPreferredConnUnavailable + } + preferredConn, ok := ap.conns[preferredConnID] + if !ok || preferredConn == nil { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSPreferredConnUnavailable + } + if preferredConn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(preferredConn) { + if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + preferredConn.close() + p.evictConn(accountID, preferredConn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{ + pool: p, + accountID: accountID, + conn: preferredConn, + connPick: connPick, + reused: true, + } + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + if int(preferredConn.waiters.Load()) >= 
p.queueLimitPerConn() { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + preferredConn.waiters.Add(1) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + defer preferredConn.waiters.Add(-1) + waitStart := time.Now() + p.metrics.acquireQueueWaitTotal.Add(1) + + if err := preferredConn.acquire(ctx); err != nil { + if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + if p.shouldHealthCheckConn(preferredConn) { + if err := preferredConn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + preferredConn.release() + preferredConn.close() + p.evictConn(accountID, preferredConn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + + queueWait := time.Since(waitStart) + p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) + lease := &openAIWSConnLease{ + pool: p, + accountID: accountID, + conn: preferredConn, + queueWait: queueWait, + connPick: connPick, + reused: true, + } + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + if preferredConnID != "" { + if conn, ok := ap.conns[preferredConnID]; ok && conn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(conn) { + if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + } + + best := p.pickLeastBusyConnLocked(ap, "") + if best != nil && best.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + 
closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(best) { + if err := best.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + best.close() + p.evictConn(accountID, best.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: best, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + for _, conn := range ap.conns { + if conn == nil || conn == best { + continue + } + if conn.tryAcquire() { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + if p.shouldHealthCheckConn(conn) { + if err := conn.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + } + } + + if req.ForceNewConn && len(ap.conns)+ap.creating >= effectiveMaxConns { + if idle := p.pickOldestIdleConnLocked(ap); idle != nil { + delete(ap.conns, idle.id) + evicted = append(evicted, idle) + p.metrics.scaleDownTotal.Add(1) + } + } + + if len(ap.conns)+ap.creating < effectiveMaxConns { + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + ap.creating++ + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + + conn, dialErr := p.dialConn(ctx, req) + + ap = p.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.creating-- + if dialErr != nil { + ap.prewarmFails++ + ap.prewarmFailAt = time.Now() + ap.mu.Unlock() + return nil, dialErr + } + ap.conns[conn.id] = conn + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + ap.mu.Unlock() + p.metrics.acquireCreateTotal.Add(1) + + if 
!conn.tryAcquire() { + if err := conn.acquire(ctx); err != nil { + conn.close() + p.evictConn(accountID, conn.id) + return nil, err + } + } + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: conn, connPick: connPick} + p.ensureTargetIdleAsync(accountID) + return lease, nil + } + + if req.ForceNewConn { + p.recordConnPickDuration(time.Since(pickStartedAt)) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + + target := p.pickLeastBusyConnLocked(ap, req.PreferredConnID) + connPick := time.Since(pickStartedAt) + p.recordConnPickDuration(connPick) + if target == nil { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnClosed + } + if int(target.waiters.Load()) >= p.queueLimitPerConn() { + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + return nil, errOpenAIWSConnQueueFull + } + target.waiters.Add(1) + ap.mu.Unlock() + closeOpenAIWSConns(evicted) + defer target.waiters.Add(-1) + waitStart := time.Now() + p.metrics.acquireQueueWaitTotal.Add(1) + + if err := target.acquire(ctx); err != nil { + if errors.Is(err, errOpenAIWSConnClosed) && retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + if p.shouldHealthCheckConn(target) { + if err := target.pingWithTimeout(openAIWSConnHealthCheckTO); err != nil { + target.release() + target.close() + p.evictConn(accountID, target.id) + if retry < 1 { + return p.acquire(ctx, req, retry+1) + } + return nil, err + } + } + + queueWait := time.Since(waitStart) + p.metrics.acquireQueueWaitMs.Add(queueWait.Milliseconds()) + lease := &openAIWSConnLease{pool: p, accountID: accountID, conn: target, queueWait: queueWait, connPick: connPick, reused: true} + p.metrics.acquireReuseTotal.Add(1) + p.ensureTargetIdleAsync(accountID) + return lease, nil +} + +func (p *openAIWSConnPool) recordConnPickDuration(duration time.Duration) { + if p == nil { + return + } + if duration < 0 { + duration = 0 + } + p.metrics.connPickTotal.Add(1) + 
p.metrics.connPickMs.Add(duration.Milliseconds()) +} + +func (p *openAIWSConnPool) pickOldestIdleConnLocked(ap *openAIWSAccountPool) *openAIWSConn { + if ap == nil || len(ap.conns) == 0 { + return nil + } + var oldest *openAIWSConn + for _, conn := range ap.conns { + if conn == nil || conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { + continue + } + if oldest == nil || conn.lastUsedAt().Before(oldest.lastUsedAt()) { + oldest = conn + } + } + return oldest +} + +func (p *openAIWSConnPool) getOrCreateAccountPool(accountID int64) *openAIWSAccountPool { + if p == nil || accountID <= 0 { + return nil + } + if existing, ok := p.accounts.Load(accountID); ok { + if ap, typed := existing.(*openAIWSAccountPool); typed && ap != nil { + return ap + } + } + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + pinnedConns: make(map[string]int), + } + actual, _ := p.accounts.LoadOrStore(accountID, ap) + if typed, ok := actual.(*openAIWSAccountPool); ok && typed != nil { + return typed + } + return ap +} + +// ensureAccountPoolLocked 兼容旧调用。 +func (p *openAIWSConnPool) ensureAccountPoolLocked(accountID int64) *openAIWSAccountPool { + return p.getOrCreateAccountPool(accountID) +} + +func (p *openAIWSConnPool) getAccountPool(accountID int64) (*openAIWSAccountPool, bool) { + if p == nil || accountID <= 0 { + return nil, false + } + value, ok := p.accounts.Load(accountID) + if !ok || value == nil { + return nil, false + } + ap, typed := value.(*openAIWSAccountPool) + return ap, typed && ap != nil +} + +func (p *openAIWSConnPool) isConnPinnedLocked(ap *openAIWSAccountPool, connID string) bool { + if ap == nil || connID == "" || len(ap.pinnedConns) == 0 { + return false + } + return ap.pinnedConns[connID] > 0 +} + +func (p *openAIWSConnPool) cleanupAccountLocked(ap *openAIWSAccountPool, now time.Time, maxConns int) []*openAIWSConn { + if ap == nil { + return nil + } + maxAge := p.maxConnAge() + + evicted := make([]*openAIWSConn, 0) + for 
id, conn := range ap.conns { + if conn == nil { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + continue + } + select { + case <-conn.closedCh: + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + evicted = append(evicted, conn) + continue + default: + } + if p.isConnPinnedLocked(ap, id) { + continue + } + if maxAge > 0 && !conn.isLeased() && conn.age(now) > maxAge { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + evicted = append(evicted, conn) + } + } + + if maxConns <= 0 { + maxConns = p.maxConnsHardCap() + } + maxIdle := p.maxIdlePerAccount() + if maxIdle < 0 || maxIdle > maxConns { + maxIdle = maxConns + } + if maxIdle >= 0 && len(ap.conns) > maxIdle { + idleConns := make([]*openAIWSConn, 0, len(ap.conns)) + for id, conn := range ap.conns { + if conn == nil { + delete(ap.conns, id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, id) + } + continue + } + // 有等待者的连接不能在清理阶段被淘汰,否则等待中的 acquire 会收到 closed 错误。 + if conn.isLeased() || conn.waiters.Load() > 0 || p.isConnPinnedLocked(ap, conn.id) { + continue + } + idleConns = append(idleConns, conn) + } + sort.SliceStable(idleConns, func(i, j int) bool { + return idleConns[i].lastUsedAt().Before(idleConns[j].lastUsedAt()) + }) + redundant := len(ap.conns) - maxIdle + if redundant > len(idleConns) { + redundant = len(idleConns) + } + for i := 0; i < redundant; i++ { + conn := idleConns[i] + delete(ap.conns, conn.id) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, conn.id) + } + evicted = append(evicted, conn) + } + if redundant > 0 { + p.metrics.scaleDownTotal.Add(int64(redundant)) + } + } + + return evicted +} + +func (p *openAIWSConnPool) pickLeastBusyConnLocked(ap *openAIWSAccountPool, preferredConnID string) *openAIWSConn { + if ap == nil || len(ap.conns) == 0 { + return nil + } + preferredConnID = stringsTrim(preferredConnID) + if preferredConnID != "" { + if conn, ok := 
ap.conns[preferredConnID]; ok { + return conn + } + } + var best *openAIWSConn + var bestWaiters int32 + var bestLastUsed time.Time + for _, conn := range ap.conns { + if conn == nil { + continue + } + waiters := conn.waiters.Load() + lastUsed := conn.lastUsedAt() + if best == nil || + waiters < bestWaiters || + (waiters == bestWaiters && lastUsed.Before(bestLastUsed)) { + best = conn + bestWaiters = waiters + bestLastUsed = lastUsed + } + } + return best +} + +func accountPoolLoadLocked(ap *openAIWSAccountPool) (inflight int, waiters int) { + if ap == nil { + return 0, 0 + } + for _, conn := range ap.conns { + if conn == nil { + continue + } + if conn.isLeased() { + inflight++ + } + waiters += int(conn.waiters.Load()) + } + return inflight, waiters +} + +// AccountPoolLoad 返回指定账号连接池的并发与排队快照。 +func (p *openAIWSConnPool) AccountPoolLoad(accountID int64) (inflight int, waiters int, conns int) { + if p == nil || accountID <= 0 { + return 0, 0, 0 + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return 0, 0, 0 + } + ap.mu.Lock() + defer ap.mu.Unlock() + inflight, waiters = accountPoolLoadLocked(ap) + return inflight, waiters, len(ap.conns) +} + +func (p *openAIWSConnPool) ensureTargetIdleAsync(accountID int64) { + if p == nil || accountID <= 0 { + return + } + + var req openAIWSAcquireRequest + need := 0 + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return + } + ap.mu.Lock() + defer ap.mu.Unlock() + if ap.lastAcquire == nil { + return + } + if ap.prewarmActive { + return + } + now := time.Now() + if !ap.prewarmUntil.IsZero() && now.Before(ap.prewarmUntil) { + return + } + if p.shouldSuppressPrewarmLocked(ap, now) { + return + } + effectiveMaxConns := p.maxConnsHardCap() + if ap.lastAcquire != nil && ap.lastAcquire.Account != nil { + effectiveMaxConns = p.effectiveMaxConnsByAccount(ap.lastAcquire.Account) + } + target := p.targetConnCountLocked(ap, effectiveMaxConns) + current := len(ap.conns) + ap.creating + if current >= target 
{ + return + } + need = target - current + if need <= 0 { + return + } + req = cloneOpenAIWSAcquireRequest(*ap.lastAcquire) + ap.prewarmActive = true + if cooldown := p.prewarmCooldown(); cooldown > 0 { + ap.prewarmUntil = now.Add(cooldown) + } + ap.creating += need + p.metrics.scaleUpTotal.Add(int64(need)) + + go p.prewarmConns(accountID, req, need) +} + +func (p *openAIWSConnPool) targetConnCountLocked(ap *openAIWSAccountPool, maxConns int) int { + if ap == nil { + return 0 + } + + if maxConns <= 0 { + return 0 + } + + minIdle := p.minIdlePerAccount() + if minIdle < 0 { + minIdle = 0 + } + if minIdle > maxConns { + minIdle = maxConns + } + + inflight, waiters := accountPoolLoadLocked(ap) + utilization := p.targetUtilization() + demand := inflight + waiters + if demand <= 0 { + return minIdle + } + + target := 1 + if demand > 1 { + target = int(math.Ceil(float64(demand) / utilization)) + } + if waiters > 0 && target < len(ap.conns)+1 { + target = len(ap.conns) + 1 + } + if target < minIdle { + target = minIdle + } + if target > maxConns { + target = maxConns + } + return target +} + +func (p *openAIWSConnPool) prewarmConns(accountID int64, req openAIWSAcquireRequest, total int) { + defer func() { + if ap, ok := p.getAccountPool(accountID); ok && ap != nil { + ap.mu.Lock() + ap.prewarmActive = false + ap.mu.Unlock() + } + }() + + for i := 0; i < total; i++ { + ctx, cancel := context.WithTimeout(context.Background(), p.dialTimeout()+openAIWSConnPrewarmExtraDelay) + conn, err := p.dialConn(ctx, req) + cancel() + + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + if conn != nil { + conn.close() + } + return + } + ap.mu.Lock() + if ap.creating > 0 { + ap.creating-- + } + if err != nil { + ap.prewarmFails++ + ap.prewarmFailAt = time.Now() + ap.mu.Unlock() + continue + } + if len(ap.conns) >= p.effectiveMaxConnsByAccount(req.Account) { + ap.mu.Unlock() + conn.close() + continue + } + ap.conns[conn.id] = conn + ap.prewarmFails = 0 + ap.prewarmFailAt = 
time.Time{} + ap.mu.Unlock() + } +} + +func (p *openAIWSConnPool) evictConn(accountID int64, connID string) { + if p == nil || accountID <= 0 || stringsTrim(connID) == "" { + return + } + var conn *openAIWSConn + ap, ok := p.getAccountPool(accountID) + if ok && ap != nil { + ap.mu.Lock() + if c, exists := ap.conns[connID]; exists { + conn = c + delete(ap.conns, connID) + if len(ap.pinnedConns) > 0 { + delete(ap.pinnedConns, connID) + } + } + ap.mu.Unlock() + } + if conn != nil { + conn.close() + } +} + +func (p *openAIWSConnPool) PinConn(accountID int64, connID string) bool { + if p == nil || accountID <= 0 { + return false + } + connID = stringsTrim(connID) + if connID == "" { + return false + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + if _, exists := ap.conns[connID]; !exists { + return false + } + if ap.pinnedConns == nil { + ap.pinnedConns = make(map[string]int) + } + ap.pinnedConns[connID]++ + return true +} + +func (p *openAIWSConnPool) UnpinConn(accountID int64, connID string) { + if p == nil || accountID <= 0 { + return + } + connID = stringsTrim(connID) + if connID == "" { + return + } + ap, ok := p.getAccountPool(accountID) + if !ok || ap == nil { + return + } + ap.mu.Lock() + defer ap.mu.Unlock() + if len(ap.pinnedConns) == 0 { + return + } + count := ap.pinnedConns[connID] + if count <= 1 { + delete(ap.pinnedConns, connID) + return + } + ap.pinnedConns[connID] = count - 1 +} + +func (p *openAIWSConnPool) dialConn(ctx context.Context, req openAIWSAcquireRequest) (*openAIWSConn, error) { + if p == nil || p.clientDialer == nil { + return nil, errors.New("openai ws client dialer is nil") + } + conn, status, handshakeHeaders, err := p.clientDialer.Dial(ctx, req.WSURL, req.Headers, req.ProxyURL) + if err != nil { + return nil, &openAIWSDialError{ + StatusCode: status, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: err, + } + } + if conn == nil { + return nil, 
&openAIWSDialError{ + StatusCode: status, + ResponseHeaders: cloneHeader(handshakeHeaders), + Err: errors.New("openai ws dialer returned nil connection"), + } + } + id := p.nextConnID(req.Account.ID) + return newOpenAIWSConn(id, req.Account.ID, conn, handshakeHeaders), nil +} + +func (p *openAIWSConnPool) nextConnID(accountID int64) string { + seq := p.seq.Add(1) + buf := make([]byte, 0, 32) + buf = append(buf, "oa_ws_"...) + buf = strconv.AppendInt(buf, accountID, 10) + buf = append(buf, '_') + buf = strconv.AppendUint(buf, seq, 10) + return string(buf) +} + +func (p *openAIWSConnPool) shouldHealthCheckConn(conn *openAIWSConn) bool { + if conn == nil { + return false + } + return conn.idleDuration(time.Now()) >= openAIWSConnHealthCheckIdle +} + +func (p *openAIWSConnPool) maxConnsHardCap() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount > 0 { + return p.cfg.Gateway.OpenAIWS.MaxConnsPerAccount + } + return 8 +} + +func (p *openAIWSConnPool) dynamicMaxConnsEnabled() bool { + if p != nil && p.cfg != nil { + return p.cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled + } + return false +} + +func (p *openAIWSConnPool) modeRouterV2Enabled() bool { + if p != nil && p.cfg != nil { + return p.cfg.Gateway.OpenAIWS.ModeRouterV2Enabled + } + return false +} + +func (p *openAIWSConnPool) maxConnsFactorByAccount(account *Account) float64 { + if p == nil || p.cfg == nil || account == nil { + return 1.0 + } + switch account.Type { + case AccountTypeOAuth: + if p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor > 0 { + return p.cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor + } + case AccountTypeAPIKey: + if p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor > 0 { + return p.cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor + } + } + return 1.0 +} + +func (p *openAIWSConnPool) effectiveMaxConnsByAccount(account *Account) int { + hardCap := p.maxConnsHardCap() + if hardCap <= 0 { + return 0 + } + if p.modeRouterV2Enabled() { + if account == nil { + return hardCap 
+ } + if account.Concurrency <= 0 { + return 0 + } + return account.Concurrency + } + if account == nil || !p.dynamicMaxConnsEnabled() { + return hardCap + } + if account.Concurrency <= 0 { + // 0/-1 等“无限制”并发场景下,仍由全局硬上限兜底。 + return hardCap + } + factor := p.maxConnsFactorByAccount(account) + if factor <= 0 { + factor = 1.0 + } + effective := int(math.Ceil(float64(account.Concurrency) * factor)) + if effective < 1 { + effective = 1 + } + if effective > hardCap { + effective = hardCap + } + return effective +} + +func (p *openAIWSConnPool) minIdlePerAccount() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MinIdlePerAccount >= 0 { + return p.cfg.Gateway.OpenAIWS.MinIdlePerAccount + } + return 0 +} + +func (p *openAIWSConnPool) maxIdlePerAccount() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount >= 0 { + return p.cfg.Gateway.OpenAIWS.MaxIdlePerAccount + } + return 4 +} + +func (p *openAIWSConnPool) maxConnAge() time.Duration { + return openAIWSConnMaxAge +} + +func (p *openAIWSConnPool) queueLimitPerConn() int { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.QueueLimitPerConn > 0 { + return p.cfg.Gateway.OpenAIWS.QueueLimitPerConn + } + return 256 +} + +func (p *openAIWSConnPool) targetUtilization() float64 { + if p != nil && p.cfg != nil { + ratio := p.cfg.Gateway.OpenAIWS.PoolTargetUtilization + if ratio > 0 && ratio <= 1 { + return ratio + } + } + return 0.7 +} + +func (p *openAIWSConnPool) prewarmCooldown() time.Duration { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS > 0 { + return time.Duration(p.cfg.Gateway.OpenAIWS.PrewarmCooldownMS) * time.Millisecond + } + return 0 +} + +func (p *openAIWSConnPool) shouldSuppressPrewarmLocked(ap *openAIWSAccountPool, now time.Time) bool { + if ap == nil { + return true + } + if ap.prewarmFails <= 0 { + return false + } + if ap.prewarmFailAt.IsZero() { + ap.prewarmFails = 0 + return false + } + if now.Sub(ap.prewarmFailAt) > 
openAIWSPrewarmFailureWindow { + ap.prewarmFails = 0 + ap.prewarmFailAt = time.Time{} + return false + } + return ap.prewarmFails >= openAIWSPrewarmFailureSuppress +} + +func (p *openAIWSConnPool) dialTimeout() time.Duration { + if p != nil && p.cfg != nil && p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds > 0 { + return time.Duration(p.cfg.Gateway.OpenAIWS.DialTimeoutSeconds) * time.Second + } + return 10 * time.Second +} + +func cloneOpenAIWSAcquireRequest(req openAIWSAcquireRequest) openAIWSAcquireRequest { + copied := req + copied.Headers = cloneHeader(req.Headers) + copied.WSURL = stringsTrim(req.WSURL) + copied.ProxyURL = stringsTrim(req.ProxyURL) + copied.PreferredConnID = stringsTrim(req.PreferredConnID) + return copied +} + +func cloneOpenAIWSAcquireRequestPtr(req *openAIWSAcquireRequest) *openAIWSAcquireRequest { + if req == nil { + return nil + } + copied := cloneOpenAIWSAcquireRequest(*req) + return &copied +} + +func cloneHeader(src http.Header) http.Header { + if src == nil { + return nil + } + dst := make(http.Header, len(src)) + for k, vals := range src { + if len(vals) == 0 { + dst[k] = nil + continue + } + copied := make([]string, len(vals)) + copy(copied, vals) + dst[k] = copied + } + return dst +} + +func closeOpenAIWSConns(conns []*openAIWSConn) { + if len(conns) == 0 { + return + } + for _, conn := range conns { + if conn == nil { + continue + } + conn.close() + } +} + +func stringsTrim(value string) string { + return strings.TrimSpace(value) +} diff --git a/backend/internal/service/openai_ws_pool_benchmark_test.go b/backend/internal/service/openai_ws_pool_benchmark_test.go new file mode 100644 index 00000000..bff74b62 --- /dev/null +++ b/backend/internal/service/openai_ws_pool_benchmark_test.go @@ -0,0 +1,58 @@ +package service + +import ( + "context" + "errors" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func BenchmarkOpenAIWSPoolAcquire(b *testing.B) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 
+ cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 4 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 256 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSCountingDialer{}) + + account := &Account{ID: 1001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + req := openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ctx := context.Background() + + lease, err := pool.Acquire(ctx, req) + if err != nil { + b.Fatalf("warm acquire failed: %v", err) + } + lease.Release() + + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + var ( + got *openAIWSConnLease + acquireErr error + ) + for retry := 0; retry < 3; retry++ { + got, acquireErr = pool.Acquire(ctx, req) + if acquireErr == nil { + break + } + if !errors.Is(acquireErr, errOpenAIWSConnClosed) { + break + } + } + if acquireErr != nil { + b.Fatalf("acquire failed: %v", acquireErr) + } + got.Release() + } + }) +} diff --git a/backend/internal/service/openai_ws_pool_test.go b/backend/internal/service/openai_ws_pool_test.go new file mode 100644 index 00000000..b2683ee0 --- /dev/null +++ b/backend/internal/service/openai_ws_pool_test.go @@ -0,0 +1,1709 @@ +package service + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSConnPool_CleanupStaleAndTrimIdle(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 1 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(10) + ap := pool.getOrCreateAccountPool(accountID) + + stale := newOpenAIWSConn("stale", accountID, nil, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + + idleOld := 
newOpenAIWSConn("idle_old", accountID, nil, nil) + idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) + + idleNew := newOpenAIWSConn("idle_new", accountID, nil, nil) + idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + + ap.conns[stale.id] = stale + ap.conns[idleOld.id] = idleOld + ap.conns[idleNew.id] = idleNew + + evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + + require.Nil(t, ap.conns["stale"], "stale connection should be rotated") + require.Nil(t, ap.conns["idle_old"], "old idle should be trimmed by max_idle") + require.NotNil(t, ap.conns["idle_new"], "newer idle should be kept") +} + +func TestOpenAIWSConnPool_NextConnIDFormat(t *testing.T) { + pool := newOpenAIWSConnPool(&config.Config{}) + id1 := pool.nextConnID(42) + id2 := pool.nextConnID(42) + + require.True(t, strings.HasPrefix(id1, "oa_ws_42_")) + require.True(t, strings.HasPrefix(id2, "oa_ws_42_")) + require.NotEqual(t, id1, id2) + require.Equal(t, "oa_ws_42_1", id1) + require.Equal(t, "oa_ws_42_2", id2) +} + +func TestOpenAIWSConnPool_AcquireCleanupInterval(t *testing.T) { + require.Equal(t, 3*time.Second, openAIWSAcquireCleanupInterval) + require.Less(t, openAIWSAcquireCleanupInterval, openAIWSBackgroundSweepTicker) +} + +func TestOpenAIWSConnLease_WriteJSONAndGuards(t *testing.T) { + conn := newOpenAIWSConn("lease_write", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + require.NoError(t, lease.WriteJSON(map[string]any{"type": "response.create"}, 0)) + + var nilLease *openAIWSConnLease + err := nilLease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + err = (&openAIWSConnLease{}).WriteJSONWithContextTimeout(context.Background(), map[string]any{"type": "response.create"}, time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func 
TestOpenAIWSConn_WriteJSONWithTimeout_NilParentContextUsesBackground(t *testing.T) { + probe := &openAIWSContextProbeConn{} + conn := newOpenAIWSConn("ctx_probe", 1, probe, nil) + require.NoError(t, conn.writeJSONWithTimeout(context.Background(), map[string]any{"type": "response.create"}, 0)) + require.NotNil(t, probe.lastWriteCtx) +} + +func TestOpenAIWSConnPool_TargetConnCountAdaptive(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 6 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.5 + + pool := newOpenAIWSConnPool(cfg) + ap := pool.getOrCreateAccountPool(88) + + conn1 := newOpenAIWSConn("c1", 88, nil, nil) + conn2 := newOpenAIWSConn("c2", 88, nil, nil) + require.True(t, conn1.tryAcquire()) + require.True(t, conn2.tryAcquire()) + conn1.waiters.Store(1) + conn2.waiters.Store(1) + + ap.conns[conn1.id] = conn1 + ap.conns[conn2.id] = conn2 + + target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 6, target, "应按 inflight+waiters 与 target_utilization 自适应扩容到上限") + + conn1.release() + conn2.release() + conn1.waiters.Store(0) + conn2.waiters.Store(0) + target = pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 1, target, "低负载时应缩回到最小空闲连接") +} + +func TestOpenAIWSConnPool_TargetConnCountMinIdleZero(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + + pool := newOpenAIWSConnPool(cfg) + ap := pool.getOrCreateAccountPool(66) + + target := pool.targetConnCountLocked(ap, pool.maxConnsHardCap()) + require.Equal(t, 0, target, "min_idle=0 且无负载时应允许缩容到 0") +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsync(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + 
cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSFakeDialer{}) + + accountID := int64(77) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.conns) >= 2 + }, 2*time.Second, 20*time.Millisecond) + + metrics := pool.SnapshotMetrics() + require.GreaterOrEqual(t, metrics.ScaleUpTotal, int64(2)) +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsyncCooldown(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 500 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSCountingDialer{} + pool.setClientDialerForTest(dialer) + + accountID := int64(178) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return len(ap.conns) >= 2 && !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + firstDialCount := dialer.DialCount() + require.GreaterOrEqual(t, firstDialCount, 2) + + // 人工制造缺口触发新一轮预热需求。 + ap, ok := 
pool.getAccountPool(accountID) + require.True(t, ok) + require.NotNil(t, ap) + ap.mu.Lock() + for id := range ap.conns { + delete(ap.conns, id) + break + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + time.Sleep(120 * time.Millisecond) + require.Equal(t, firstDialCount, dialer.DialCount(), "cooldown 窗口内不应再次触发预热") + + time.Sleep(450 * time.Millisecond) + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + return dialer.DialCount() > firstDialCount + }, 2*time.Second, 20*time.Millisecond) +} + +func TestOpenAIWSConnPool_EnsureTargetIdleAsyncFailureSuppress(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 1 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.8 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + cfg.Gateway.OpenAIWS.PrewarmCooldownMS = 0 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSAlwaysFailDialer{} + pool.setClientDialerForTest(dialer) + + accountID := int64(279) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + + pool.ensureTargetIdleAsync(accountID) + require.Eventually(t, func() bool { + ap, ok := pool.getAccountPool(accountID) + if !ok || ap == nil { + return false + } + ap.mu.Lock() + defer ap.mu.Unlock() + return !ap.prewarmActive + }, 2*time.Second, 20*time.Millisecond) + require.Equal(t, 2, dialer.DialCount()) + + // 连续失败达到阈值后,新的预热触发应被抑制,不再继续拨号。 + pool.ensureTargetIdleAsync(accountID) + time.Sleep(120 * time.Millisecond) + 
require.Equal(t, 2, dialer.DialCount()) +} + +func TestOpenAIWSConnPool_AcquireQueueWaitMetrics(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 + + pool := newOpenAIWSConnPool(cfg) + accountID := int64(99) + account := &Account{ID: accountID, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + conn := newOpenAIWSConn("busy", accountID, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) // 占用连接,触发后续排队 + + ap := pool.ensureAccountPoolLocked(accountID) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + } + ap.mu.Unlock() + + go func() { + time.Sleep(60 * time.Millisecond) + conn.release() + }() + + lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.NoError(t, err) + require.NotNil(t, lease) + require.True(t, lease.Reused()) + require.GreaterOrEqual(t, lease.QueueWaitDuration(), 50*time.Millisecond) + lease.Release() + + metrics := pool.SnapshotMetrics() + require.GreaterOrEqual(t, metrics.AcquireQueueWaitTotal, int64(1)) + require.Greater(t, metrics.AcquireQueueWaitMsTotal, int64(0)) + require.GreaterOrEqual(t, metrics.ConnPickTotal, int64(1)) +} + +func TestOpenAIWSConnPool_ForceNewConnSkipsReuse(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + + pool := newOpenAIWSConnPool(cfg) + dialer := &openAIWSCountingDialer{} + pool.setClientDialerForTest(dialer) + + account := &Account{ID: 123, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + lease1, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.NoError(t, 
err) + require.NotNil(t, lease1) + lease1.Release() + + lease2, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + ForceNewConn: true, + }) + require.NoError(t, err) + require.NotNil(t, lease2) + lease2.Release() + + require.Equal(t, 2, dialer.DialCount(), "ForceNewConn=true 时应跳过空闲连接复用并新建连接") +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnUnavailable(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 124, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + otherConn := newOpenAIWSConn("other_conn", account.ID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[otherConn.id] = otherConn + ap.mu.Unlock() + + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) + + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: "missing_conn", + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSPreferredConnUnavailable) +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnQueuesOnPreferredOnly(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 4 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 125, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + preferredConn := newOpenAIWSConn("preferred_conn", account.ID, &openAIWSFakeConn{}, nil) + 
otherConn := newOpenAIWSConn("other_conn_idle", account.ID, &openAIWSFakeConn{}, nil) + require.True(t, preferredConn.tryAcquire(), "先占用 preferred 连接,触发排队获取") + ap.mu.Lock() + ap.conns[preferredConn.id] = preferredConn + ap.conns[otherConn.id] = otherConn + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + + go func() { + time.Sleep(60 * time.Millisecond) + preferredConn.release() + }() + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + lease, err := pool.Acquire(ctx, openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.NoError(t, err) + require.NotNil(t, lease) + require.Equal(t, preferredConn.id, lease.ConnID(), "严格模式应只等待并复用 preferred 连接,不可漂移") + require.GreaterOrEqual(t, lease.QueueWaitDuration(), 40*time.Millisecond) + lease.Release() + require.True(t, otherConn.tryAcquire(), "other 连接不应被严格模式抢占") + otherConn.release() +} + +func TestOpenAIWSConnPool_AcquireForcePreferredConnDirectAndQueueFull(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 127, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := pool.getOrCreateAccountPool(account.ID) + preferredConn := newOpenAIWSConn("preferred_conn_direct", account.ID, &openAIWSFakeConn{}, nil) + otherConn := newOpenAIWSConn("other_conn_direct", account.ID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[preferredConn.id] = preferredConn + ap.conns[otherConn.id] = otherConn + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + + lease, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) 
+ require.NoError(t, err) + require.Equal(t, preferredConn.id, lease.ConnID(), "preferred 空闲时应直接命中") + lease.Release() + + require.True(t, preferredConn.tryAcquire()) + preferredConn.waiters.Store(1) + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + PreferredConnID: preferredConn.id, + ForcePreferredConn: true, + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull, "严格模式下队列满应直接失败,不得漂移") + preferredConn.waiters.Store(0) + preferredConn.release() +} + +func TestOpenAIWSConnPool_CleanupSkipsPinnedConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 0 + + pool := newOpenAIWSConnPool(cfg) + accountID := int64(126) + ap := pool.getOrCreateAccountPool(accountID) + pinnedConn := newOpenAIWSConn("pinned_conn", accountID, &openAIWSFakeConn{}, nil) + idleConn := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[pinnedConn.id] = pinnedConn + ap.conns[idleConn.id] = idleConn + ap.mu.Unlock() + + require.True(t, pool.PinConn(accountID, pinnedConn.id)) + evicted := pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + + ap.mu.Lock() + _, pinnedExists := ap.conns[pinnedConn.id] + _, idleExists := ap.conns[idleConn.id] + ap.mu.Unlock() + require.True(t, pinnedExists, "被 active ingress 绑定的连接不应被 cleanup 回收") + require.False(t, idleExists, "非绑定的空闲连接应被回收") + + pool.UnpinConn(accountID, pinnedConn.id) + evicted = pool.cleanupAccountLocked(ap, time.Now(), pool.maxConnsHardCap()) + closeOpenAIWSConns(evicted) + ap.mu.Lock() + _, pinnedExists = ap.conns[pinnedConn.id] + ap.mu.Unlock() + require.False(t, pinnedExists, "解绑后连接应可被正常回收") +} + +func TestOpenAIWSConnPool_PinUnpinConnBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.False(t, nilPool.PinConn(1, "x")) + nilPool.UnpinConn(1, "x") + + cfg := &config.Config{} + pool 
:= newOpenAIWSConnPool(cfg) + accountID := int64(128) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{}, + } + pool.accounts.Store(accountID, ap) + + require.False(t, pool.PinConn(0, "x")) + require.False(t, pool.PinConn(999, "x")) + require.False(t, pool.PinConn(accountID, "")) + require.False(t, pool.PinConn(accountID, "missing")) + + conn := newOpenAIWSConn("pin_refcount", accountID, &openAIWSFakeConn{}, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + require.True(t, pool.PinConn(accountID, conn.id)) + require.True(t, pool.PinConn(accountID, conn.id)) + + ap.mu.Lock() + require.Equal(t, 2, ap.pinnedConns[conn.id]) + ap.mu.Unlock() + + pool.UnpinConn(accountID, conn.id) + ap.mu.Lock() + require.Equal(t, 1, ap.pinnedConns[conn.id]) + ap.mu.Unlock() + + pool.UnpinConn(accountID, conn.id) + ap.mu.Lock() + _, exists := ap.pinnedConns[conn.id] + ap.mu.Unlock() + require.False(t, exists) + + pool.UnpinConn(accountID, conn.id) + pool.UnpinConn(accountID, "") + pool.UnpinConn(0, conn.id) + pool.UnpinConn(999, conn.id) +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 + + pool := newOpenAIWSConnPool(cfg) + + oauthHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 10} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(oauthHigh), "应受全局硬上限约束") + + oauthLow := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 3} + require.Equal(t, 3, pool.effectiveMaxConnsByAccount(oauthLow)) + + apiKeyHigh := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 10} + require.Equal(t, 6, pool.effectiveMaxConnsByAccount(apiKeyHigh), "API Key 应按系数缩放") + + apiKeyLow := &Account{Platform: PlatformOpenAI, Type: 
AccountTypeAPIKey, Concurrency: 1} + require.Equal(t, 1, pool.effectiveMaxConnsByAccount(apiKeyLow), "最小值应保持为 1") + + unlimited := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(unlimited), "无限并发应回退到全局硬上限") + + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(nil), "缺少账号上下文应回退到全局硬上限") +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsDisabledFallbackHardCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = false + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 1.0 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 1.0 + + pool := newOpenAIWSConnPool(cfg) + account := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 2} + require.Equal(t, 8, pool.effectiveMaxConnsByAccount(account), "关闭动态模式后应保持旧行为") +} + +func TestOpenAIWSConnPool_EffectiveMaxConnsByAccount_ModeRouterV2UsesAccountConcurrency(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + cfg.Gateway.OpenAIWS.OAuthMaxConnsFactor = 0.3 + cfg.Gateway.OpenAIWS.APIKeyMaxConnsFactor = 0.6 + + pool := newOpenAIWSConnPool(cfg) + + high := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 20} + require.Equal(t, 20, pool.effectiveMaxConnsByAccount(high), "v2 路径应直接使用账号并发数作为池上限") + + nonPositive := &Account{Platform: PlatformOpenAI, Type: AccountTypeAPIKey, Concurrency: 0} + require.Equal(t, 0, pool.effectiveMaxConnsByAccount(nonPositive), "并发数<=0 时应不可调度") +} + +func TestOpenAIWSConnPool_AcquireRejectsWhenEffectiveMaxConnsIsZero(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 8 + pool := newOpenAIWSConnPool(cfg) + + account := &Account{ID: 901, Platform: 
PlatformOpenAI, Type: AccountTypeOAuth, Concurrency: 0} + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull) +} + +func TestOpenAIWSConnLease_ReadMessageWithContextTimeout_PerRead(t *testing.T) { + conn := newOpenAIWSConn("timeout", 1, &openAIWSBlockingConn{readDelay: 80 * time.Millisecond}, nil) + lease := &openAIWSConnLease{conn: conn} + + _, err := lease.ReadMessageWithContextTimeout(context.Background(), 20*time.Millisecond) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) + + payload, err := lease.ReadMessageWithContextTimeout(context.Background(), 150*time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + parentCtx, cancel := context.WithCancel(context.Background()) + cancel() + _, err = lease.ReadMessageWithContextTimeout(parentCtx, 150*time.Millisecond) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestOpenAIWSConnLease_WriteJSONWithContextTimeout_RespectsParentContext(t *testing.T) { + conn := newOpenAIWSConn("write_timeout_ctx", 1, &openAIWSWriteBlockingConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + + parentCtx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + + start := time.Now() + err := lease.WriteJSONWithContextTimeout(parentCtx, map[string]any{"type": "response.create"}, 2*time.Minute) + elapsed := time.Since(start) + + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Less(t, elapsed, 200*time.Millisecond) +} + +func TestOpenAIWSConnLease_PingWithTimeout(t *testing.T) { + conn := newOpenAIWSConn("ping_ok", 1, &openAIWSFakeConn{}, nil) + lease := &openAIWSConnLease{conn: conn} + require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) + + var nilLease *openAIWSConnLease + err := 
nilLease.PingWithTimeout(50 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConn_ReadAndWriteCanProceedConcurrently(t *testing.T) { + conn := newOpenAIWSConn("full_duplex", 1, &openAIWSBlockingConn{readDelay: 120 * time.Millisecond}, nil) + + readDone := make(chan error, 1) + go func() { + _, err := conn.readMessageWithContextTimeout(context.Background(), 200*time.Millisecond) + readDone <- err + }() + + // 让读取先占用 readMu。 + time.Sleep(20 * time.Millisecond) + + start := time.Now() + err := conn.pingWithTimeout(50 * time.Millisecond) + elapsed := time.Since(start) + + require.NoError(t, err) + require.Less(t, elapsed, 80*time.Millisecond, "写路径不应被读锁长期阻塞") + require.NoError(t, <-readDone) +} + +func TestOpenAIWSConnPool_BackgroundPingSweep_EvictsDeadIdleConn(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(301) + ap := pool.getOrCreateAccountPool(accountID) + conn := newOpenAIWSConn("dead_idle", accountID, &openAIWSPingFailConn{}, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + + pool.runBackgroundPingSweep() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.False(t, exists, "后台 ping 失败的空闲连接应被回收") +} + +func TestOpenAIWSConnPool_BackgroundCleanupSweep_WithoutAcquire(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.MaxIdlePerAccount = 2 + pool := newOpenAIWSConnPool(cfg) + + accountID := int64(302) + ap := pool.getOrCreateAccountPool(accountID) + stale := newOpenAIWSConn("stale_bg", accountID, &openAIWSFakeConn{}, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + ap.mu.Lock() + ap.conns[stale.id] = stale + ap.mu.Unlock() + + pool.runBackgroundCleanupSweep(time.Now()) + + ap.mu.Lock() + _, exists := ap.conns[stale.id] + 
ap.mu.Unlock() + require.False(t, exists, "后台清理应在无新 acquire 时也回收过期连接") +} + +func TestOpenAIWSConnPool_BackgroundWorkerGuardBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.NotPanics(t, func() { + nilPool.startBackgroundWorkers() + nilPool.runBackgroundPingWorker() + nilPool.runBackgroundPingSweep() + _ = nilPool.snapshotIdleConnsForPing() + nilPool.runBackgroundCleanupWorker() + nilPool.runBackgroundCleanupSweep(time.Now()) + }) + + poolNoStop := &openAIWSConnPool{} + require.NotPanics(t, func() { + poolNoStop.startBackgroundWorkers() + }) + + poolStopPing := &openAIWSConnPool{workerStopCh: make(chan struct{})} + pingDone := make(chan struct{}) + go func() { + poolStopPing.runBackgroundPingWorker() + close(pingDone) + }() + close(poolStopPing.workerStopCh) + select { + case <-pingDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("runBackgroundPingWorker 未在 stop 信号后退出") + } + + poolStopCleanup := &openAIWSConnPool{workerStopCh: make(chan struct{})} + cleanupDone := make(chan struct{}) + go func() { + poolStopCleanup.runBackgroundCleanupWorker() + close(cleanupDone) + }() + close(poolStopCleanup.workerStopCh) + select { + case <-cleanupDone: + case <-time.After(500 * time.Millisecond): + t.Fatal("runBackgroundCleanupWorker 未在 stop 信号后退出") + } +} + +func TestOpenAIWSConnPool_SnapshotIdleConnsForPing_SkipsInvalidEntries(t *testing.T) { + pool := &openAIWSConnPool{} + pool.accounts.Store("invalid-key", &openAIWSAccountPool{}) + pool.accounts.Store(int64(123), "invalid-value") + + accountID := int64(123) + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + } + ap.conns["nil_conn"] = nil + + leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + ap.conns[leased.id] = leased + + waiting := newOpenAIWSConn("waiting", accountID, &openAIWSFakeConn{}, nil) + waiting.waiters.Store(1) + ap.conns[waiting.id] = waiting + + idle := newOpenAIWSConn("idle", accountID, 
&openAIWSFakeConn{}, nil) + ap.conns[idle.id] = idle + + pool.accounts.Store(accountID, ap) + candidates := pool.snapshotIdleConnsForPing() + require.Len(t, candidates, 1) + require.Equal(t, idle.id, candidates[0].conn.id) +} + +func TestOpenAIWSConnPool_RunBackgroundCleanupSweep_SkipsInvalidAndUsesAccountCap(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 4 + cfg.Gateway.OpenAIWS.DynamicMaxConnsByAccountConcurrencyEnabled = true + + pool := &openAIWSConnPool{cfg: cfg} + pool.accounts.Store("bad-key", "bad-value") + + accountID := int64(2026) + ap := &openAIWSAccountPool{ + conns: make(map[string]*openAIWSConn), + } + ap.conns["nil_conn"] = nil + stale := newOpenAIWSConn("stale_bg_cleanup", accountID, &openAIWSFakeConn{}, nil) + stale.createdAtNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + stale.lastUsedNano.Store(time.Now().Add(-2 * time.Hour).UnixNano()) + ap.conns[stale.id] = stale + ap.lastAcquire = &openAIWSAcquireRequest{ + Account: &Account{ + ID: accountID, + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + }, + } + pool.accounts.Store(accountID, ap) + + now := time.Now() + require.NotPanics(t, func() { + pool.runBackgroundCleanupSweep(now) + }) + + ap.mu.Lock() + _, nilConnExists := ap.conns["nil_conn"] + _, exists := ap.conns[stale.id] + lastCleanupAt := ap.lastCleanupAt + ap.mu.Unlock() + + require.False(t, nilConnExists, "后台清理应移除无效 nil 连接条目") + require.False(t, exists, "后台清理应清理过期连接") + require.Equal(t, now, lastCleanupAt) +} + +func TestOpenAIWSConnPool_QueueLimitPerConn_DefaultAndConfigured(t *testing.T) { + var nilPool *openAIWSConnPool + require.Equal(t, 256, nilPool.queueLimitPerConn()) + + pool := &openAIWSConnPool{cfg: &config.Config{}} + require.Equal(t, 256, pool.queueLimitPerConn()) + + pool.cfg.Gateway.OpenAIWS.QueueLimitPerConn = 9 + require.Equal(t, 9, pool.queueLimitPerConn()) +} + +func TestOpenAIWSConnPool_Close(t *testing.T) { + cfg := &config.Config{} + pool := 
newOpenAIWSConnPool(cfg) + + // Close 应该可以安全调用 + pool.Close() + + // workerStopCh 应已关闭 + select { + case <-pool.workerStopCh: + // 预期:channel 已关闭 + default: + t.Fatal("Close 后 workerStopCh 应已关闭") + } + + // 多次调用 Close 不应 panic + pool.Close() + + // nil pool 调用 Close 不应 panic + var nilPool *openAIWSConnPool + nilPool.Close() +} + +func TestOpenAIWSDialError_ErrorAndUnwrap(t *testing.T) { + baseErr := errors.New("boom") + dialErr := &openAIWSDialError{StatusCode: 502, Err: baseErr} + require.Contains(t, dialErr.Error(), "status=502") + require.ErrorIs(t, dialErr.Unwrap(), baseErr) + + noStatus := &openAIWSDialError{Err: baseErr} + require.Contains(t, noStatus.Error(), "boom") + + var nilDialErr *openAIWSDialError + require.Equal(t, "", nilDialErr.Error()) + require.NoError(t, nilDialErr.Unwrap()) +} + +func TestOpenAIWSConnLease_ReadWriteHelpersAndConnStats(t *testing.T) { + conn := newOpenAIWSConn("helper_conn", 1, &openAIWSFakeConn{}, http.Header{ + "X-Test": []string{" value "}, + }) + lease := &openAIWSConnLease{conn: conn} + + require.NoError(t, lease.WriteJSONContext(context.Background(), map[string]any{"type": "response.create"})) + payload, err := lease.ReadMessage(100 * time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + payload, err = lease.ReadMessageContext(context.Background()) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + payload, err = conn.readMessageWithTimeout(100 * time.Millisecond) + require.NoError(t, err) + require.Contains(t, string(payload), "response.completed") + + require.Equal(t, "value", conn.handshakeHeader(" X-Test ")) + require.NotZero(t, conn.createdAt()) + require.NotZero(t, conn.lastUsedAt()) + require.GreaterOrEqual(t, conn.age(time.Now()), time.Duration(0)) + require.GreaterOrEqual(t, conn.idleDuration(time.Now()), time.Duration(0)) + require.False(t, conn.isLeased()) + + // 覆盖空上下文路径 + _, err = 
conn.readMessage(context.Background()) + require.NoError(t, err) + + // 覆盖 nil 保护分支 + var nilConn *openAIWSConn + require.ErrorIs(t, nilConn.writeJSONWithTimeout(context.Background(), map[string]any{}, time.Second), errOpenAIWSConnClosed) + _, err = nilConn.readMessageWithTimeout(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = nilConn.readMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnPool_PickOldestIdleAndAccountPoolLoad(t *testing.T) { + pool := &openAIWSConnPool{} + accountID := int64(404) + ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} + + idleOld := newOpenAIWSConn("idle_old", accountID, &openAIWSFakeConn{}, nil) + idleOld.lastUsedNano.Store(time.Now().Add(-10 * time.Minute).UnixNano()) + idleNew := newOpenAIWSConn("idle_new", accountID, &openAIWSFakeConn{}, nil) + idleNew.lastUsedNano.Store(time.Now().Add(-1 * time.Minute).UnixNano()) + leased := newOpenAIWSConn("leased", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + leased.waiters.Store(2) + + ap.conns[idleOld.id] = idleOld + ap.conns[idleNew.id] = idleNew + ap.conns[leased.id] = leased + + oldest := pool.pickOldestIdleConnLocked(ap) + require.NotNil(t, oldest) + require.Equal(t, idleOld.id, oldest.id) + + inflight, waiters := accountPoolLoadLocked(ap) + require.Equal(t, 1, inflight) + require.Equal(t, 2, waiters) + + pool.accounts.Store(accountID, ap) + loadInflight, loadWaiters, conns := pool.AccountPoolLoad(accountID) + require.Equal(t, 1, loadInflight) + require.Equal(t, 2, loadWaiters) + require.Equal(t, 3, conns) + + zeroInflight, zeroWaiters, zeroConns := pool.AccountPoolLoad(0) + require.Equal(t, 0, zeroInflight) + require.Equal(t, 0, zeroWaiters) + require.Equal(t, 0, zeroConns) +} + +func TestOpenAIWSConnPool_Close_WaitsWorkerGroupAndNilStopChannel(t *testing.T) { + pool := &openAIWSConnPool{} + release := make(chan 
struct{}) + pool.workerWg.Add(1) + go func() { + defer pool.workerWg.Done() + <-release + }() + + closed := make(chan struct{}) + go func() { + pool.Close() + close(closed) + }() + + select { + case <-closed: + t.Fatal("Close 不应在 WaitGroup 未完成时提前返回") + case <-time.After(30 * time.Millisecond): + } + + close(release) + select { + case <-closed: + case <-time.After(time.Second): + t.Fatal("Close 未等待 workerWg 完成") + } +} + +func TestOpenAIWSConnPool_Close_ClosesOnlyIdleConnections(t *testing.T) { + pool := &openAIWSConnPool{ + workerStopCh: make(chan struct{}), + } + + accountID := int64(606) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{}, + } + idle := newOpenAIWSConn("idle_conn", accountID, &openAIWSFakeConn{}, nil) + leased := newOpenAIWSConn("leased_conn", accountID, &openAIWSFakeConn{}, nil) + require.True(t, leased.tryAcquire()) + + ap.conns[idle.id] = idle + ap.conns[leased.id] = leased + pool.accounts.Store(accountID, ap) + pool.accounts.Store("invalid-key", "invalid-value") + + pool.Close() + + select { + case <-idle.closedCh: + // idle should be closed + default: + t.Fatal("空闲连接应在 Close 时被关闭") + } + + select { + case <-leased.closedCh: + t.Fatal("已租赁连接不应在 Close 时被关闭") + default: + } + + leased.release() + pool.Close() +} + +func TestOpenAIWSConnPool_RunBackgroundPingSweep_ConcurrencyLimit(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + accountID := int64(505) + ap := pool.getOrCreateAccountPool(accountID) + + var current atomic.Int32 + var maxConcurrent atomic.Int32 + release := make(chan struct{}) + for i := 0; i < 25; i++ { + conn := newOpenAIWSConn(pool.nextConnID(accountID), accountID, &openAIWSPingBlockingConn{ + current: &current, + maxConcurrent: &maxConcurrent, + release: release, + }, nil) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + } + + done := make(chan struct{}) + go func() { + pool.runBackgroundPingSweep() + close(done) + }() + + require.Eventually(t, func() bool { + return
maxConcurrent.Load() >= 10 + }, time.Second, 10*time.Millisecond) + + close(release) + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("runBackgroundPingSweep 未在释放后完成") + } + + require.LessOrEqual(t, maxConcurrent.Load(), int32(10)) +} + +func TestOpenAIWSConnLease_BasicGetterBranches(t *testing.T) { + var nilLease *openAIWSConnLease + require.Equal(t, "", nilLease.ConnID()) + require.Equal(t, time.Duration(0), nilLease.QueueWaitDuration()) + require.Equal(t, time.Duration(0), nilLease.ConnPickDuration()) + require.False(t, nilLease.Reused()) + require.Equal(t, "", nilLease.HandshakeHeader("x-test")) + require.False(t, nilLease.IsPrewarmed()) + nilLease.MarkPrewarmed() + nilLease.Release() + + conn := newOpenAIWSConn("getter_conn", 1, &openAIWSFakeConn{}, http.Header{"X-Test": []string{"ok"}}) + lease := &openAIWSConnLease{ + conn: conn, + queueWait: 3 * time.Millisecond, + connPick: 4 * time.Millisecond, + reused: true, + } + require.Equal(t, "getter_conn", lease.ConnID()) + require.Equal(t, 3*time.Millisecond, lease.QueueWaitDuration()) + require.Equal(t, 4*time.Millisecond, lease.ConnPickDuration()) + require.True(t, lease.Reused()) + require.Equal(t, "ok", lease.HandshakeHeader("x-test")) + require.False(t, lease.IsPrewarmed()) + lease.MarkPrewarmed() + require.True(t, lease.IsPrewarmed()) + lease.Release() +} + +func TestOpenAIWSConnPool_UtilityBranches(t *testing.T) { + var nilPool *openAIWSConnPool + require.Equal(t, OpenAIWSPoolMetricsSnapshot{}, nilPool.SnapshotMetrics()) + require.Equal(t, OpenAIWSTransportMetricsSnapshot{}, nilPool.SnapshotTransportMetrics()) + + pool := &openAIWSConnPool{cfg: &config.Config{}} + pool.metrics.acquireTotal.Store(7) + pool.metrics.acquireReuseTotal.Store(3) + metrics := pool.SnapshotMetrics() + require.Equal(t, int64(7), metrics.AcquireTotal) + require.Equal(t, int64(3), metrics.AcquireReuseTotal) + + // 非 transport metrics dialer 路径 + pool.clientDialer = &openAIWSFakeDialer{} + require.Equal(t, 
OpenAIWSTransportMetricsSnapshot{}, pool.SnapshotTransportMetrics()) + pool.setClientDialerForTest(nil) + require.NotNil(t, pool.clientDialer) + + require.Equal(t, 8, nilPool.maxConnsHardCap()) + require.False(t, nilPool.dynamicMaxConnsEnabled()) + require.Equal(t, 1.0, nilPool.maxConnsFactorByAccount(nil)) + require.Equal(t, 0, nilPool.minIdlePerAccount()) + require.Equal(t, 4, nilPool.maxIdlePerAccount()) + require.Equal(t, 256, nilPool.queueLimitPerConn()) + require.Equal(t, 0.7, nilPool.targetUtilization()) + require.Equal(t, time.Duration(0), nilPool.prewarmCooldown()) + require.Equal(t, 10*time.Second, nilPool.dialTimeout()) + + // shouldSuppressPrewarmLocked 覆盖 3 条分支 + now := time.Now() + apNilFail := &openAIWSAccountPool{prewarmFails: 1} + require.False(t, pool.shouldSuppressPrewarmLocked(apNilFail, now)) + apZeroTime := &openAIWSAccountPool{prewarmFails: 2} + require.False(t, pool.shouldSuppressPrewarmLocked(apZeroTime, now)) + require.Equal(t, 0, apZeroTime.prewarmFails) + apOldFail := &openAIWSAccountPool{prewarmFails: 2, prewarmFailAt: now.Add(-openAIWSPrewarmFailureWindow - time.Second)} + require.False(t, pool.shouldSuppressPrewarmLocked(apOldFail, now)) + apRecentFail := &openAIWSAccountPool{prewarmFails: openAIWSPrewarmFailureSuppress, prewarmFailAt: now} + require.True(t, pool.shouldSuppressPrewarmLocked(apRecentFail, now)) + + // recordConnPickDuration 的保护分支 + nilPool.recordConnPickDuration(10 * time.Millisecond) + pool.recordConnPickDuration(-10 * time.Millisecond) + require.Equal(t, int64(1), pool.metrics.connPickTotal.Load()) + + // account pool 读写分支 + require.Nil(t, nilPool.getOrCreateAccountPool(1)) + require.Nil(t, pool.getOrCreateAccountPool(0)) + pool.accounts.Store(int64(7), "invalid") + ap := pool.getOrCreateAccountPool(7) + require.NotNil(t, ap) + _, ok := pool.getAccountPool(0) + require.False(t, ok) + _, ok = pool.getAccountPool(12345) + require.False(t, ok) + pool.accounts.Store(int64(8), "bad-type") + _, ok = pool.getAccountPool(8) 
+ require.False(t, ok) + + // health check 条件 + require.False(t, pool.shouldHealthCheckConn(nil)) + conn := newOpenAIWSConn("health", 1, &openAIWSFakeConn{}, nil) + conn.lastUsedNano.Store(time.Now().Add(-openAIWSConnHealthCheckIdle - time.Second).UnixNano()) + require.True(t, pool.shouldHealthCheckConn(conn)) +} + +func TestOpenAIWSConn_LeaseAndTimeHelpers_NilAndClosedBranches(t *testing.T) { + var nilConn *openAIWSConn + nilConn.touch() + require.Equal(t, time.Time{}, nilConn.createdAt()) + require.Equal(t, time.Time{}, nilConn.lastUsedAt()) + require.Equal(t, time.Duration(0), nilConn.idleDuration(time.Now())) + require.Equal(t, time.Duration(0), nilConn.age(time.Now())) + require.False(t, nilConn.isLeased()) + require.False(t, nilConn.isPrewarmed()) + nilConn.markPrewarmed() + + conn := newOpenAIWSConn("lease_state", 1, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) + require.True(t, conn.isLeased()) + conn.release() + require.False(t, conn.isLeased()) + conn.close() + require.False(t, conn.tryAcquire()) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + err := conn.acquire(ctx) + require.Error(t, err) +} + +func TestOpenAIWSConnLease_ReadWriteNilConnBranches(t *testing.T) { + lease := &openAIWSConnLease{} + require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) + _, err := lease.ReadMessage(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageContext(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnLease_ReleasedLeaseGuards(t *testing.T) { + conn := newOpenAIWSConn("released_guard", 1, &openAIWSFakeConn{}, nil) + lease := 
&openAIWSConnLease{conn: conn} + + require.NoError(t, lease.PingWithTimeout(50*time.Millisecond)) + + lease.Release() + lease.Release() // idempotent + + require.ErrorIs(t, lease.WriteJSON(map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONContext(context.Background(), map[string]any{"k": "v"}), errOpenAIWSConnClosed) + require.ErrorIs(t, lease.WriteJSONWithContextTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), errOpenAIWSConnClosed) + + _, err := lease.ReadMessage(10 * time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageContext(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + _, err = lease.ReadMessageWithContextTimeout(context.Background(), 10*time.Millisecond) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + require.ErrorIs(t, lease.PingWithTimeout(50*time.Millisecond), errOpenAIWSConnClosed) +} + +func TestOpenAIWSConnLease_MarkBrokenAfterRelease_NoEviction(t *testing.T) { + conn := newOpenAIWSConn("released_markbroken", 7, &openAIWSFakeConn{}, nil) + ap := &openAIWSAccountPool{ + conns: map[string]*openAIWSConn{ + conn.id: conn, + }, + } + pool := &openAIWSConnPool{} + pool.accounts.Store(int64(7), ap) + + lease := &openAIWSConnLease{ + pool: pool, + accountID: 7, + conn: conn, + } + + lease.Release() + lease.MarkBroken() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.True(t, exists, "released lease should not evict active pool connection") +} + +func TestOpenAIWSConn_AdditionalGuardBranches(t *testing.T) { + var nilConn *openAIWSConn + require.False(t, nilConn.tryAcquire()) + require.ErrorIs(t, nilConn.acquire(context.Background()), errOpenAIWSConnClosed) + nilConn.release() + nilConn.close() + require.Equal(t, "", nilConn.handshakeHeader("x-test")) + + connBusy := newOpenAIWSConn("busy_ctx", 1, &openAIWSFakeConn{}, nil) + require.True(t, connBusy.tryAcquire()) + ctx, cancel := 
context.WithCancel(context.Background()) + cancel() + require.ErrorIs(t, connBusy.acquire(ctx), context.Canceled) + connBusy.release() + + connClosed := newOpenAIWSConn("closed_guard", 1, &openAIWSFakeConn{}, nil) + connClosed.close() + require.ErrorIs( + t, + connClosed.writeJSONWithTimeout(context.Background(), map[string]any{"k": "v"}, time.Second), + errOpenAIWSConnClosed, + ) + _, err := connClosed.readMessageWithContextTimeout(context.Background(), time.Second) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, connClosed.pingWithTimeout(time.Second), errOpenAIWSConnClosed) + + connNoWS := newOpenAIWSConn("no_ws", 1, nil, nil) + require.ErrorIs(t, connNoWS.writeJSON(map[string]any{"k": "v"}, context.Background()), errOpenAIWSConnClosed) + _, err = connNoWS.readMessage(context.Background()) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + require.ErrorIs(t, connNoWS.pingWithTimeout(time.Second), errOpenAIWSConnClosed) + require.Equal(t, "", connNoWS.handshakeHeader("x-test")) + + connOK := newOpenAIWSConn("ok", 1, &openAIWSFakeConn{}, nil) + require.NoError(t, connOK.writeJSON(map[string]any{"k": "v"}, nil)) + _, err = connOK.readMessageWithContextTimeout(context.Background(), 0) + require.NoError(t, err) + require.NoError(t, connOK.pingWithTimeout(0)) + + connZero := newOpenAIWSConn("zero_ts", 1, &openAIWSFakeConn{}, nil) + connZero.createdAtNano.Store(0) + connZero.lastUsedNano.Store(0) + require.True(t, connZero.createdAt().IsZero()) + require.True(t, connZero.lastUsedAt().IsZero()) + require.Equal(t, time.Duration(0), connZero.idleDuration(time.Now())) + require.Equal(t, time.Duration(0), connZero.age(time.Now())) + + require.Nil(t, cloneOpenAIWSAcquireRequestPtr(nil)) + copied := cloneHeader(http.Header{ + "X-Empty": []string{}, + "X-Test": []string{"v1"}, + }) + require.Contains(t, copied, "X-Empty") + require.Nil(t, copied["X-Empty"]) + require.Equal(t, "v1", copied.Get("X-Test")) + + closeOpenAIWSConns([]*openAIWSConn{nil, connOK}) 
+} + +func TestOpenAIWSConnLease_MarkBrokenEvictsConn(t *testing.T) { + pool := newOpenAIWSConnPool(&config.Config{}) + accountID := int64(5001) + conn := newOpenAIWSConn("broken_me", accountID, &openAIWSFakeConn{}, nil) + ap := pool.getOrCreateAccountPool(accountID) + ap.mu.Lock() + ap.conns[conn.id] = conn + ap.mu.Unlock() + + lease := &openAIWSConnLease{ + pool: pool, + accountID: accountID, + conn: conn, + } + lease.MarkBroken() + + ap.mu.Lock() + _, exists := ap.conns[conn.id] + ap.mu.Unlock() + require.False(t, exists) + require.False(t, conn.tryAcquire(), "被标记为 broken 的连接应被关闭") +} + +func TestOpenAIWSConnPool_TargetConnCountAndPrewarmBranches(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + pool := newOpenAIWSConnPool(cfg) + + require.Equal(t, 0, pool.targetConnCountLocked(nil, 1)) + ap := &openAIWSAccountPool{conns: map[string]*openAIWSConn{}} + require.Equal(t, 0, pool.targetConnCountLocked(ap, 0)) + + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 3 + require.Equal(t, 1, pool.targetConnCountLocked(ap, 1), "minIdle 应被 maxConns 截断") + + // 覆盖 waiters>0 且 target 需要至少 len(conns)+1 的分支 + cfg.Gateway.OpenAIWS.MinIdlePerAccount = 0 + cfg.Gateway.OpenAIWS.PoolTargetUtilization = 0.9 + busy := newOpenAIWSConn("busy_target", 2, &openAIWSFakeConn{}, nil) + require.True(t, busy.tryAcquire()) + busy.waiters.Store(1) + ap.conns[busy.id] = busy + target := pool.targetConnCountLocked(ap, 4) + require.GreaterOrEqual(t, target, len(ap.conns)+1) + + // prewarm: account pool 缺失时,拨号后的连接应被关闭并提前返回 + req := openAIWSAcquireRequest{ + Account: &Account{ID: 999, Platform: PlatformOpenAI, Type: AccountTypeAPIKey}, + WSURL: "wss://example.com/v1/responses", + } + pool.prewarmConns(999, req, 1) + + // prewarm: 拨号失败分支(prewarmFails 累加) + accountID := int64(1000) + failPool := newOpenAIWSConnPool(cfg) + failPool.setClientDialerForTest(&openAIWSAlwaysFailDialer{}) + apFail := failPool.getOrCreateAccountPool(accountID) + apFail.mu.Lock() + apFail.creating 
= 1 + apFail.mu.Unlock() + req.Account.ID = accountID + failPool.prewarmConns(accountID, req, 1) + apFail.mu.Lock() + require.GreaterOrEqual(t, apFail.prewarmFails, 1) + apFail.mu.Unlock() +} + +func TestOpenAIWSConnPool_Acquire_ErrorBranches(t *testing.T) { + var nilPool *openAIWSConnPool + _, err := nilPool.Acquire(context.Background(), openAIWSAcquireRequest{}) + require.Error(t, err) + + pool := newOpenAIWSConnPool(&config.Config{}) + _, err = pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: &Account{ID: 1}, + WSURL: " ", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "ws url is empty") + + // target=nil 分支:池满且仅有 nil 连接 + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 1 + cfg.Gateway.OpenAIWS.QueueLimitPerConn = 1 + fullPool := newOpenAIWSConnPool(cfg) + account := &Account{ID: 2001, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap := fullPool.getOrCreateAccountPool(account.ID) + ap.mu.Lock() + ap.conns["nil"] = nil + ap.lastCleanupAt = time.Now() + ap.mu.Unlock() + _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnClosed) + + // queue full 分支:waiters 达上限 + account2 := &Account{ID: 2002, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + ap2 := fullPool.getOrCreateAccountPool(account2.ID) + conn := newOpenAIWSConn("queue_full", account2.ID, &openAIWSFakeConn{}, nil) + require.True(t, conn.tryAcquire()) + conn.waiters.Store(1) + ap2.mu.Lock() + ap2.conns[conn.id] = conn + ap2.lastCleanupAt = time.Now() + ap2.mu.Unlock() + _, err = fullPool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account2, + WSURL: "wss://example.com/v1/responses", + }) + require.ErrorIs(t, err, errOpenAIWSConnQueueFull) +} + +type openAIWSFakeDialer struct{} + +func (d *openAIWSFakeDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, 
+) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + return &openAIWSFakeConn{}, 0, nil, nil +} + +type openAIWSCountingDialer struct { + mu sync.Mutex + dialCount int +} + +type openAIWSAlwaysFailDialer struct { + mu sync.Mutex + dialCount int +} + +type openAIWSPingBlockingConn struct { + current *atomic.Int32 + maxConcurrent *atomic.Int32 + release <-chan struct{} +} + +func (c *openAIWSPingBlockingConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingBlockingConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_blocking_ping"}}`), nil +} + +func (c *openAIWSPingBlockingConn) Ping(ctx context.Context) error { + if c.current == nil || c.maxConcurrent == nil { + return nil + } + + now := c.current.Add(1) + for { + prev := c.maxConcurrent.Load() + if now <= prev || c.maxConcurrent.CompareAndSwap(prev, now) { + break + } + } + defer c.current.Add(-1) + + select { + case <-ctx.Done(): + return ctx.Err() + case <-c.release: + return nil + } +} + +func (c *openAIWSPingBlockingConn) Close() error { + return nil +} + +func (d *openAIWSCountingDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return &openAIWSFakeConn{}, 0, nil, nil +} + +func (d *openAIWSCountingDialer) DialCount() int { + d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +func (d *openAIWSAlwaysFailDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + d.mu.Lock() + d.dialCount++ + d.mu.Unlock() + return nil, 503, nil, errors.New("dial failed") +} + +func (d *openAIWSAlwaysFailDialer) DialCount() int { + 
d.mu.Lock() + defer d.mu.Unlock() + return d.dialCount +} + +type openAIWSFakeConn struct { + mu sync.Mutex + closed bool + payload [][]byte +} + +func (c *openAIWSFakeConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return errors.New("closed") + } + c.payload = append(c.payload, []byte("ok")) + _ = value + return nil +} + +func (c *openAIWSFakeConn) ReadMessage(ctx context.Context) ([]byte, error) { + _ = ctx + c.mu.Lock() + defer c.mu.Unlock() + if c.closed { + return nil, errors.New("closed") + } + return []byte(`{"type":"response.completed","response":{"id":"resp_fake"}}`), nil +} + +func (c *openAIWSFakeConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSFakeConn) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + c.closed = true + return nil +} + +type openAIWSBlockingConn struct { + readDelay time.Duration +} + +func (c *openAIWSBlockingConn) WriteJSON(ctx context.Context, value any) error { + _ = ctx + _ = value + return nil +} + +func (c *openAIWSBlockingConn) ReadMessage(ctx context.Context) ([]byte, error) { + delay := c.readDelay + if delay <= 0 { + delay = 10 * time.Millisecond + } + timer := time.NewTimer(delay) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return []byte(`{"type":"response.completed","response":{"id":"resp_blocking"}}`), nil + } +} + +func (c *openAIWSBlockingConn) Ping(ctx context.Context) error { + _ = ctx + return nil +} + +func (c *openAIWSBlockingConn) Close() error { + return nil +} + +type openAIWSWriteBlockingConn struct{} + +func (c *openAIWSWriteBlockingConn) WriteJSON(ctx context.Context, _ any) error { + <-ctx.Done() + return ctx.Err() +} + +func (c *openAIWSWriteBlockingConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_write_block"}}`), nil +} + +func (c *openAIWSWriteBlockingConn) 
Ping(context.Context) error { + return nil +} + +func (c *openAIWSWriteBlockingConn) Close() error { + return nil +} + +type openAIWSPingFailConn struct{} + +func (c *openAIWSPingFailConn) WriteJSON(context.Context, any) error { + return nil +} + +func (c *openAIWSPingFailConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ping_fail"}}`), nil +} + +func (c *openAIWSPingFailConn) Ping(context.Context) error { + return errors.New("ping failed") +} + +func (c *openAIWSPingFailConn) Close() error { + return nil +} + +type openAIWSContextProbeConn struct { + lastWriteCtx context.Context +} + +func (c *openAIWSContextProbeConn) WriteJSON(ctx context.Context, _ any) error { + c.lastWriteCtx = ctx + return nil +} + +func (c *openAIWSContextProbeConn) ReadMessage(context.Context) ([]byte, error) { + return []byte(`{"type":"response.completed","response":{"id":"resp_ctx_probe"}}`), nil +} + +func (c *openAIWSContextProbeConn) Ping(context.Context) error { + return nil +} + +func (c *openAIWSContextProbeConn) Close() error { + return nil +} + +type openAIWSNilConnDialer struct{} + +func (d *openAIWSNilConnDialer) Dial( + ctx context.Context, + wsURL string, + headers http.Header, + proxyURL string, +) (openAIWSClientConn, int, http.Header, error) { + _ = ctx + _ = wsURL + _ = headers + _ = proxyURL + return nil, 200, nil, nil +} + +func TestOpenAIWSConnPool_DialConnNilConnection(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.MaxConnsPerAccount = 2 + cfg.Gateway.OpenAIWS.DialTimeoutSeconds = 1 + + pool := newOpenAIWSConnPool(cfg) + pool.setClientDialerForTest(&openAIWSNilConnDialer{}) + account := &Account{ID: 91, Platform: PlatformOpenAI, Type: AccountTypeAPIKey} + + _, err := pool.Acquire(context.Background(), openAIWSAcquireRequest{ + Account: account, + WSURL: "wss://example.com/v1/responses", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "nil connection") +} + +func 
TestOpenAIWSConnPool_SnapshotTransportMetrics(t *testing.T) { + cfg := &config.Config{} + pool := newOpenAIWSConnPool(cfg) + + dialer, ok := pool.clientDialer.(*coderOpenAIWSClientDialer) + require.True(t, ok) + + _, err := dialer.proxyHTTPClient("http://127.0.0.1:28080") + require.NoError(t, err) + _, err = dialer.proxyHTTPClient("http://127.0.0.1:28080") + require.NoError(t, err) + _, err = dialer.proxyHTTPClient("http://127.0.0.1:28081") + require.NoError(t, err) + + snapshot := pool.SnapshotTransportMetrics() + require.Equal(t, int64(1), snapshot.ProxyClientCacheHits) + require.Equal(t, int64(2), snapshot.ProxyClientCacheMisses) + require.InDelta(t, 1.0/3.0, snapshot.TransportReuseRatio, 0.0001) +} diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go new file mode 100644 index 00000000..df4d4871 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -0,0 +1,1218 @@ +package service + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" +) + +func TestOpenAIGatewayService_Forward_PreservePreviousResponseIDWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + 
Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下失败时不应回退 HTTP") +} + +func TestOpenAIGatewayService_Forward_HTTPIngressStaysHTTPWhenWSEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + 
`{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 101, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_http_keep","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, result.OpenAIWSMode, "HTTP 入站应保持 HTTP 转发") + require.NotNil(t, upstream.lastReq, "HTTP 入站应命中 HTTP 上游") + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists(), "HTTP 路径应沿用原逻辑移除 previous_response_id") + + decision, _ := c.Get("openai_ws_transport_decision") + reason, _ := c.Get("openai_ws_transport_reason") + require.Equal(t, string(OpenAIUpstreamTransportHTTPSSE), decision) + require.Equal(t, "client_protocol_http", reason) +} + +func TestOpenAIGatewayService_Forward_RemovePreviousResponseIDWhenWSDisabled(t *testing.T) { + gin.SetMode(gin.TestMode) + wsFallbackServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsFallbackServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = 
httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = false + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 1, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsFallbackServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_123","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.False(t, gjson.GetBytes(upstream.lastBody, "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2Dial426FallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := 
&httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":8,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 12, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_426","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "upgrade_required") + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + require.Equal(t, http.StatusUpgradeRequired, rec.Code) + require.Contains(t, rec.Body.String(), "426") +} + +func TestOpenAIGatewayService_Forward_WSv2FallbackCoolingSkipWS(t *testing.T) { + gin.SetMode(gin.TestMode) + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.NotFound(w, r) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + 
c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":2,"output_tokens":3,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 30 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 21, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + svc.markOpenAIWSFallbackCooling(account.ID, "upgrade_required") + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_cooling","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下不应再回退 HTTP") + + _, ok := c.Get("openai_ws_fallback_cooling") + require.False(t, ok, "已移除 fallback cooling 快捷回退路径") +} + +func TestOpenAIGatewayService_Forward_ReturnErrorWhenOnlyWSv1Enabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ 
+ resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader( + `{"usage":{"input_tokens":1,"output_tokens":2,"input_tokens_details":{"cached_tokens":0}}}`, + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 31, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": "https://api.openai.com/v1/responses", + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"previous_response_id":"resp_v1","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws v1") + require.Equal(t, http.StatusBadRequest, rec.Code) + require.Contains(t, rec.Body.String(), "WSv1") + require.Nil(t, upstream.lastReq, "WSv1 不支持时不应触发 HTTP 上游请求") +} + +func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { + cfg := &config.Config{} + svc := NewOpenAIGatewayService( + nil, + nil, + nil, + nil, + nil, + cfg, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + nil, + ) + + decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_missing", decision.Reason) +} + +func 
TestOpenAIGatewayService_Forward_WSv2FallbackWhenResponseAlreadyWrittenReturnsWSError(t *testing.T) { + gin.SetMode(gin.TestMode) + ws426Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUpgradeRequired) + _, _ = w.Write([]byte(`upgrade required`)) + })) + defer ws426Server.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + c.String(http.StatusAccepted, "already-written") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + } + + account := &Account{ + ID: 41, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": ws426Server.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.1","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Contains(t, err.Error(), "ws fallback") + require.Nil(t, upstream.lastReq, "已写下游响应时,不应再回退 HTTP") +} + +func 
TestOpenAIGatewayService_Forward_WSv2StreamEarlyCloseFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + + // 仅发送 response.created(非 token 事件)后立即关闭, + // 模拟线上“上游早期内部错误断连”的场景。 + if err := conn.WriteJSON(map[string]any{ + "type": "response.created", + "response": map[string]any{ + "id": "resp_ws_created_only", + "model": "gpt-5.3-codex", + }, + }); err != nil { + t.Errorf("write response.created failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = 
true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 88, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 早期断连后不应再回退 HTTP") + require.Empty(t, rec.Body.String(), "未产出 token 前上游断连时不应写入下游半截流") +} + +func TestOpenAIGatewayService_Forward_WSv2RetryFiveTimesThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + 
c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_text.delta\",\"delta\":\"ok\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_retry_http_fallback\",\"usage\":{\"input_tokens\":2,\"output_tokens\":1}}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 89, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":true,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 重连耗尽后不应再回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PolicyViolationFastFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := 
httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + closePayload := websocket.FormatCloseMessage(websocket.ClosePolicyViolation, "") + _ = conn.WriteControl(websocket.CloseMessage, closePayload, time.Now().Add(time.Second)) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_policy_fallback","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 1 + cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 2 + cfg.Gateway.OpenAIWS.RetryJitterRatio = 0 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 8901, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: 
map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "策略违规关闭后不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "策略违规不应进行 WS 重试") +} + +func TestOpenAIGatewayService_Forward_WSv2ConnectionLimitReachedRetryThenFallbackHTTP(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "websocket_connection_limit_reached", + "type": "server_error", + "message": "websocket connection limit reached", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_retry_limit","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true 
+ cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 90, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "触发 websocket_connection_limit_reached 后不应回退 HTTP") + require.Equal(t, int32(openAIWSReconnectRetryLimit+1), wsAttempts.Load()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundRecoversByDroppingPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + attempt := wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + if attempt == 1 { + _ = 
conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + return + } + _ = conn.WriteJSON(map[string]any{ + "type": "response.completed", + "response": map[string]any{ + "id": "resp_ws_prev_recover_ok", + "model": "gpt-5.3-codex", + "usage": map[string]any{ + "input_tokens": 1, + "output_tokens": 1, + "input_tokens_details": map[string]any{ + "cached_tokens": 0, + }, + }, + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 91, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := 
[]byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "resp_ws_prev_recover_ok", result.RequestID) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "previous_response_not_found 应触发一次去掉 previous_response_id 的恢复重试") + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "resp_ws_prev_recover_ok", gjson.Get(rec.Body.String(), "id").String()) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应保留 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryForFunctionCallOutput(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": 
"previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 92, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"function_call_output","call_id":"call_1","output":"ok"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "function_call_output 场景应跳过 previous_response_not_found 自动恢复") + require.Equal(t, http.StatusBadRequest, 
rec.Code) + require.Contains(t, strings.ToLower(rec.Body.String()), "previous response not found") + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) + wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundSkipsRecoveryWithoutPreviousResponseID(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + 
cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 93, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(1), wsAttempts.Load(), "缺少 previous_response_id 时应跳过自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 1) + require.False(t, gjson.GetBytes(requests[0], "previous_response_id").Exists()) +} + +func TestOpenAIGatewayService_Forward_WSv2PreviousResponseNotFoundOnlyRecoversOnce(t *testing.T) { + gin.SetMode(gin.TestMode) + + var wsAttempts atomic.Int32 + var wsRequestPayloads [][]byte + var wsRequestMu sync.Mutex + upgrader := websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }} + wsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + wsAttempts.Add(1) + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + t.Errorf("upgrade websocket failed: %v", err) + return + } + defer func() { + _ = conn.Close() + }() + + var req map[string]any + if err := conn.ReadJSON(&req); err != nil { + t.Errorf("read ws request failed: %v", err) + return + } + reqRaw, _ := json.Marshal(req) + wsRequestMu.Lock() + wsRequestPayloads = append(wsRequestPayloads, reqRaw) + wsRequestMu.Unlock() + _ = conn.WriteJSON(map[string]any{ + "type": "error", + "error": map[string]any{ + "code": "previous_response_not_found", + "type": "invalid_request_error", + "message": "previous response not found", + }, + }) + })) + defer wsServer.Close() + + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + c.Request.Header.Set("User-Agent", "custom-client/1.0") + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(`{"id":"resp_http_drop_prev","usage":{"input_tokens":1,"output_tokens":1}}`)), + }, + } + + cfg := &config.Config{} + cfg.Security.URLAllowlist.Enabled = false + cfg.Security.URLAllowlist.AllowInsecureHTTP = true + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + 
cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1 + + svc := &OpenAIGatewayService{ + cfg: cfg, + httpUpstream: upstream, + openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), + toolCorrector: NewCodexToolCorrector(), + } + + account := &Account{ + ID: 94, + Name: "openai-apikey", + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Credentials: map[string]any{ + "api_key": "sk-test", + "base_url": wsServer.URL, + }, + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + + body := []byte(`{"model":"gpt-5.3-codex","stream":false,"previous_response_id":"resp_prev_missing","input":[{"type":"input_text","text":"hello"}]}`) + result, err := svc.Forward(context.Background(), c, account, body) + require.Error(t, err) + require.Nil(t, result) + require.Nil(t, upstream.lastReq, "WS 模式下 previous_response_not_found 不应回退 HTTP") + require.Equal(t, int32(2), wsAttempts.Load(), "应只允许一次自动恢复重试") + require.Equal(t, http.StatusBadRequest, rec.Code) + + wsRequestMu.Lock() + requests := append([][]byte(nil), wsRequestPayloads...) 
+ wsRequestMu.Unlock() + require.Len(t, requests, 2) + require.True(t, gjson.GetBytes(requests[0], "previous_response_id").Exists(), "首轮请求应包含 previous_response_id") + require.False(t, gjson.GetBytes(requests[1], "previous_response_id").Exists(), "恢复重试应移除 previous_response_id") +} diff --git a/backend/internal/service/openai_ws_protocol_resolver.go b/backend/internal/service/openai_ws_protocol_resolver.go new file mode 100644 index 00000000..368643be --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_resolver.go @@ -0,0 +1,117 @@ +package service + +import "github.com/Wei-Shaw/sub2api/internal/config" + +// OpenAIUpstreamTransport 表示 OpenAI 上游传输协议。 +type OpenAIUpstreamTransport string + +const ( + OpenAIUpstreamTransportAny OpenAIUpstreamTransport = "" + OpenAIUpstreamTransportHTTPSSE OpenAIUpstreamTransport = "http_sse" + OpenAIUpstreamTransportResponsesWebsocket OpenAIUpstreamTransport = "responses_websockets" + OpenAIUpstreamTransportResponsesWebsocketV2 OpenAIUpstreamTransport = "responses_websockets_v2" +) + +// OpenAIWSProtocolDecision 表示协议决策结果。 +type OpenAIWSProtocolDecision struct { + Transport OpenAIUpstreamTransport + Reason string +} + +// OpenAIWSProtocolResolver 定义 OpenAI 上游协议决策。 +type OpenAIWSProtocolResolver interface { + Resolve(account *Account) OpenAIWSProtocolDecision +} + +type defaultOpenAIWSProtocolResolver struct { + cfg *config.Config +} + +// NewOpenAIWSProtocolResolver 创建默认协议决策器。 +func NewOpenAIWSProtocolResolver(cfg *config.Config) OpenAIWSProtocolResolver { + return &defaultOpenAIWSProtocolResolver{cfg: cfg} +} + +func (r *defaultOpenAIWSProtocolResolver) Resolve(account *Account) OpenAIWSProtocolDecision { + if account == nil { + return openAIWSHTTPDecision("account_missing") + } + if !account.IsOpenAI() { + return openAIWSHTTPDecision("platform_not_openai") + } + if account.IsOpenAIWSForceHTTPEnabled() { + return openAIWSHTTPDecision("account_force_http") + } + if r == nil || r.cfg == nil { + return 
openAIWSHTTPDecision("config_missing") + } + + wsCfg := r.cfg.Gateway.OpenAIWS + if wsCfg.ForceHTTP { + return openAIWSHTTPDecision("global_force_http") + } + if !wsCfg.Enabled { + return openAIWSHTTPDecision("global_disabled") + } + if account.IsOpenAIOAuth() { + if !wsCfg.OAuthEnabled { + return openAIWSHTTPDecision("oauth_disabled") + } + } else if account.IsOpenAIApiKey() { + if !wsCfg.APIKeyEnabled { + return openAIWSHTTPDecision("apikey_disabled") + } + } else { + return openAIWSHTTPDecision("unknown_auth_type") + } + if wsCfg.ModeRouterV2Enabled { + mode := account.ResolveOpenAIResponsesWebSocketV2Mode(wsCfg.IngressModeDefault) + switch mode { + case OpenAIWSIngressModeOff: + return openAIWSHTTPDecision("account_mode_off") + case OpenAIWSIngressModeShared, OpenAIWSIngressModeDedicated: + // continue + default: + return openAIWSHTTPDecision("account_mode_off") + } + if account.Concurrency <= 0 { + return openAIWSHTTPDecision("account_concurrency_invalid") + } + if wsCfg.ResponsesWebsocketsV2 { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_mode_" + mode, + } + } + if wsCfg.ResponsesWebsockets { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_mode_" + mode, + } + } + return openAIWSHTTPDecision("feature_disabled") + } + if !account.IsOpenAIResponsesWebSocketV2Enabled() { + return openAIWSHTTPDecision("account_disabled") + } + if wsCfg.ResponsesWebsocketsV2 { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocketV2, + Reason: "ws_v2_enabled", + } + } + if wsCfg.ResponsesWebsockets { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportResponsesWebsocket, + Reason: "ws_v1_enabled", + } + } + return openAIWSHTTPDecision("feature_disabled") +} + +func openAIWSHTTPDecision(reason string) OpenAIWSProtocolDecision { + return OpenAIWSProtocolDecision{ + Transport: OpenAIUpstreamTransportHTTPSSE, 
+ Reason: reason, + } +} diff --git a/backend/internal/service/openai_ws_protocol_resolver_test.go b/backend/internal/service/openai_ws_protocol_resolver_test.go new file mode 100644 index 00000000..5be76e28 --- /dev/null +++ b/backend/internal/service/openai_ws_protocol_resolver_test.go @@ -0,0 +1,203 @@ +package service + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSProtocolResolver_Resolve(t *testing.T) { + baseCfg := &config.Config{} + baseCfg.Gateway.OpenAIWS.Enabled = true + baseCfg.Gateway.OpenAIWS.OAuthEnabled = true + baseCfg.Gateway.OpenAIWS.APIKeyEnabled = true + baseCfg.Gateway.OpenAIWS.ResponsesWebsockets = false + baseCfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + + openAIOAuthEnabled := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + }, + } + + t.Run("v2优先", func(t *testing.T) { + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("v2关闭时回退v1", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = false + cfg.Gateway.OpenAIWS.ResponsesWebsockets = true + + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocket, decision.Transport) + require.Equal(t, "ws_v1_enabled", decision.Reason) + }) + + t.Run("透传开关不影响WS协议判定", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_passthrough": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", 
decision.Reason) + }) + + t.Run("账号级强制HTTP", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": true, + "openai_ws_force_http": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_force_http", decision.Reason) + }) + + t.Run("全局关闭保持HTTP", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.Enabled = false + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "global_disabled", decision.Reason) + }) + + t.Run("账号开关关闭保持HTTP", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_oauth_responses_websockets_v2_enabled": false, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_disabled", decision.Reason) + }) + + t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_disabled", decision.Reason) + }) + + t.Run("兼容旧键openai_ws_enabled", func(t *testing.T) { + account := *openAIOAuthEnabled + account.Extra = map[string]any{ + "openai_ws_enabled": true, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(&account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_enabled", decision.Reason) + }) + + t.Run("按账号类型开关控制", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.OAuthEnabled = false + decision := 
NewOpenAIWSProtocolResolver(&cfg).Resolve(openAIOAuthEnabled) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "oauth_disabled", decision.Reason) + }) + + t.Run("API Key 账号关闭开关时回退HTTP", func(t *testing.T) { + cfg := *baseCfg + cfg.Gateway.OpenAIWS.APIKeyEnabled = false + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(&cfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "apikey_disabled", decision.Reason) + }) + + t.Run("未知认证类型回退HTTP", func(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: "unknown_type", + Extra: map[string]any{ + "responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(baseCfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "unknown_auth_type", decision.Reason) + }) +} + +func TestOpenAIWSProtocolResolver_Resolve_ModeRouterV2(t *testing.T) { + cfg := &config.Config{} + cfg.Gateway.OpenAIWS.Enabled = true + cfg.Gateway.OpenAIWS.OAuthEnabled = true + cfg.Gateway.OpenAIWS.APIKeyEnabled = true + cfg.Gateway.OpenAIWS.ResponsesWebsocketsV2 = true + cfg.Gateway.OpenAIWS.ModeRouterV2Enabled = true + cfg.Gateway.OpenAIWS.IngressModeDefault = OpenAIWSIngressModeShared + + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated, + }, + } + + t.Run("dedicated mode routes to ws v2", func(t *testing.T) { + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(account) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_dedicated", decision.Reason) + }) + + t.Run("off mode routes to http", 
func(t *testing.T) { + offAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Concurrency: 1, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeOff, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(offAccount) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_mode_off", decision.Reason) + }) + + t.Run("legacy boolean maps to shared in v2 router", func(t *testing.T) { + legacyAccount := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeAPIKey, + Concurrency: 1, + Extra: map[string]any{ + "openai_apikey_responses_websockets_v2_enabled": true, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(legacyAccount) + require.Equal(t, OpenAIUpstreamTransportResponsesWebsocketV2, decision.Transport) + require.Equal(t, "ws_v2_mode_shared", decision.Reason) + }) + + t.Run("non-positive concurrency is rejected in v2 router", func(t *testing.T) { + invalidConcurrency := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Extra: map[string]any{ + "openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared, + }, + } + decision := NewOpenAIWSProtocolResolver(cfg).Resolve(invalidConcurrency) + require.Equal(t, OpenAIUpstreamTransportHTTPSSE, decision.Transport) + require.Equal(t, "account_concurrency_invalid", decision.Reason) + }) +} diff --git a/backend/internal/service/openai_ws_state_store.go b/backend/internal/service/openai_ws_state_store.go new file mode 100644 index 00000000..b606baa1 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store.go @@ -0,0 +1,440 @@ +package service + +import ( + "context" + "crypto/sha256" + "encoding/hex" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" +) + +const ( + openAIWSResponseAccountCachePrefix = "openai:response:" + openAIWSStateStoreCleanupInterval = time.Minute + openAIWSStateStoreCleanupMaxPerMap = 512 + openAIWSStateStoreMaxEntriesPerMap = 65536 
+ openAIWSStateStoreRedisTimeout = 3 * time.Second +) + +type openAIWSAccountBinding struct { + accountID int64 + expiresAt time.Time +} + +type openAIWSConnBinding struct { + connID string + expiresAt time.Time +} + +type openAIWSTurnStateBinding struct { + turnState string + expiresAt time.Time +} + +type openAIWSSessionConnBinding struct { + connID string + expiresAt time.Time +} + +// OpenAIWSStateStore 管理 WSv2 的粘连状态。 +// - response_id -> account_id 用于续链路由 +// - response_id -> conn_id 用于连接内上下文复用 +// +// response_id -> account_id 优先走 GatewayCache(Redis),同时维护本地热缓存。 +// response_id -> conn_id 仅在本进程内有效。 +type OpenAIWSStateStore interface { + BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error + GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) + DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error + + BindResponseConn(responseID, connID string, ttl time.Duration) + GetResponseConn(responseID string) (string, bool) + DeleteResponseConn(responseID string) + + BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) + GetSessionTurnState(groupID int64, sessionHash string) (string, bool) + DeleteSessionTurnState(groupID int64, sessionHash string) + + BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) + GetSessionConn(groupID int64, sessionHash string) (string, bool) + DeleteSessionConn(groupID int64, sessionHash string) +} + +type defaultOpenAIWSStateStore struct { + cache GatewayCache + + responseToAccountMu sync.RWMutex + responseToAccount map[string]openAIWSAccountBinding + responseToConnMu sync.RWMutex + responseToConn map[string]openAIWSConnBinding + sessionToTurnStateMu sync.RWMutex + sessionToTurnState map[string]openAIWSTurnStateBinding + sessionToConnMu sync.RWMutex + sessionToConn map[string]openAIWSSessionConnBinding + + lastCleanupUnixNano atomic.Int64 +} + +// 
NewOpenAIWSStateStore 创建默认 WS 状态存储。 +func NewOpenAIWSStateStore(cache GatewayCache) OpenAIWSStateStore { + store := &defaultOpenAIWSStateStore{ + cache: cache, + responseToAccount: make(map[string]openAIWSAccountBinding, 256), + responseToConn: make(map[string]openAIWSConnBinding, 256), + sessionToTurnState: make(map[string]openAIWSTurnStateBinding, 256), + sessionToConn: make(map[string]openAIWSSessionConnBinding, 256), + } + store.lastCleanupUnixNano.Store(time.Now().UnixNano()) + return store +} + +func (s *defaultOpenAIWSStateStore) BindResponseAccount(ctx context.Context, groupID int64, responseID string, accountID int64, ttl time.Duration) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" || accountID <= 0 { + return nil + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + expiresAt := time.Now().Add(ttl) + s.responseToAccountMu.Lock() + ensureBindingCapacity(s.responseToAccount, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToAccount[id] = openAIWSAccountBinding{accountID: accountID, expiresAt: expiresAt} + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + return s.cache.SetSessionAccountID(cacheCtx, groupID, cacheKey, accountID, ttl) +} + +func (s *defaultOpenAIWSStateStore) GetResponseAccount(ctx context.Context, groupID int64, responseID string) (int64, error) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return 0, nil + } + s.maybeCleanup() + + now := time.Now() + s.responseToAccountMu.RLock() + if binding, ok := s.responseToAccount[id]; ok { + if now.Before(binding.expiresAt) { + accountID := binding.accountID + s.responseToAccountMu.RUnlock() + return accountID, nil + } + } + s.responseToAccountMu.RUnlock() + + if s.cache == nil { + return 0, nil + } + + cacheKey := openAIWSResponseAccountCacheKey(id) + cacheCtx, cancel := 
withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + accountID, err := s.cache.GetSessionAccountID(cacheCtx, groupID, cacheKey) + if err != nil || accountID <= 0 { + // 缓存读取失败不阻断主流程,按未命中降级。 + return 0, nil + } + return accountID, nil +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseAccount(ctx context.Context, groupID int64, responseID string) error { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return nil + } + s.responseToAccountMu.Lock() + delete(s.responseToAccount, id) + s.responseToAccountMu.Unlock() + + if s.cache == nil { + return nil + } + cacheCtx, cancel := withOpenAIWSStateStoreRedisTimeout(ctx) + defer cancel() + return s.cache.DeleteSessionAccountID(cacheCtx, groupID, openAIWSResponseAccountCacheKey(id)) +} + +func (s *defaultOpenAIWSStateStore) BindResponseConn(responseID, connID string, ttl time.Duration) { + id := normalizeOpenAIWSResponseID(responseID) + conn := strings.TrimSpace(connID) + if id == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.responseToConnMu.Lock() + ensureBindingCapacity(s.responseToConn, id, openAIWSStateStoreMaxEntriesPerMap) + s.responseToConn[id] = openAIWSConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.responseToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetResponseConn(responseID string) (string, bool) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.responseToConnMu.RLock() + binding, ok := s.responseToConn[id] + s.responseToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteResponseConn(responseID string) { + id := normalizeOpenAIWSResponseID(responseID) + if id == "" { + return + } + s.responseToConnMu.Lock() + delete(s.responseToConn, id) + s.responseToConnMu.Unlock() 
+} + +func (s *defaultOpenAIWSStateStore) BindSessionTurnState(groupID int64, sessionHash, turnState string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + state := strings.TrimSpace(turnState) + if key == "" || state == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToTurnStateMu.Lock() + ensureBindingCapacity(s.sessionToTurnState, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToTurnState[key] = openAIWSTurnStateBinding{ + turnState: state, + expiresAt: time.Now().Add(ttl), + } + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) GetSessionTurnState(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.sessionToTurnStateMu.RLock() + binding, ok := s.sessionToTurnState[key] + s.sessionToTurnStateMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.turnState) == "" { + return "", false + } + return binding.turnState, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionTurnState(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToTurnStateMu.Lock() + delete(s.sessionToTurnState, key) + s.sessionToTurnStateMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) BindSessionConn(groupID int64, sessionHash, connID string, ttl time.Duration) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + conn := strings.TrimSpace(connID) + if key == "" || conn == "" { + return + } + ttl = normalizeOpenAIWSTTL(ttl) + s.maybeCleanup() + + s.sessionToConnMu.Lock() + ensureBindingCapacity(s.sessionToConn, key, openAIWSStateStoreMaxEntriesPerMap) + s.sessionToConn[key] = openAIWSSessionConnBinding{ + connID: conn, + expiresAt: time.Now().Add(ttl), + } + s.sessionToConnMu.Unlock() +} + +func (s 
*defaultOpenAIWSStateStore) GetSessionConn(groupID int64, sessionHash string) (string, bool) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return "", false + } + s.maybeCleanup() + + now := time.Now() + s.sessionToConnMu.RLock() + binding, ok := s.sessionToConn[key] + s.sessionToConnMu.RUnlock() + if !ok || now.After(binding.expiresAt) || strings.TrimSpace(binding.connID) == "" { + return "", false + } + return binding.connID, true +} + +func (s *defaultOpenAIWSStateStore) DeleteSessionConn(groupID int64, sessionHash string) { + key := openAIWSSessionTurnStateKey(groupID, sessionHash) + if key == "" { + return + } + s.sessionToConnMu.Lock() + delete(s.sessionToConn, key) + s.sessionToConnMu.Unlock() +} + +func (s *defaultOpenAIWSStateStore) maybeCleanup() { + if s == nil { + return + } + now := time.Now() + last := time.Unix(0, s.lastCleanupUnixNano.Load()) + if now.Sub(last) < openAIWSStateStoreCleanupInterval { + return + } + if !s.lastCleanupUnixNano.CompareAndSwap(last.UnixNano(), now.UnixNano()) { + return + } + + // 增量限额清理,避免高规模下一次性全量扫描导致长时间阻塞。 + s.responseToAccountMu.Lock() + cleanupExpiredAccountBindings(s.responseToAccount, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToAccountMu.Unlock() + + s.responseToConnMu.Lock() + cleanupExpiredConnBindings(s.responseToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.responseToConnMu.Unlock() + + s.sessionToTurnStateMu.Lock() + cleanupExpiredTurnStateBindings(s.sessionToTurnState, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToTurnStateMu.Unlock() + + s.sessionToConnMu.Lock() + cleanupExpiredSessionConnBindings(s.sessionToConn, now, openAIWSStateStoreCleanupMaxPerMap) + s.sessionToConnMu.Unlock() +} + +func cleanupExpiredAccountBindings(bindings map[string]openAIWSAccountBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + 
delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredConnBindings(bindings map[string]openAIWSConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredTurnStateBindings(bindings map[string]openAIWSTurnStateBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func cleanupExpiredSessionConnBindings(bindings map[string]openAIWSSessionConnBinding, now time.Time, maxScan int) { + if len(bindings) == 0 || maxScan <= 0 { + return + } + scanned := 0 + for key, binding := range bindings { + if now.After(binding.expiresAt) { + delete(bindings, key) + } + scanned++ + if scanned >= maxScan { + break + } + } +} + +func ensureBindingCapacity[T any](bindings map[string]T, incomingKey string, maxEntries int) { + if len(bindings) < maxEntries || maxEntries <= 0 { + return + } + if _, exists := bindings[incomingKey]; exists { + return + } + // 固定上限保护:淘汰任意一项,优先保证内存有界。 + for key := range bindings { + delete(bindings, key) + return + } +} + +func normalizeOpenAIWSResponseID(responseID string) string { + return strings.TrimSpace(responseID) +} + +func openAIWSResponseAccountCacheKey(responseID string) string { + sum := sha256.Sum256([]byte(responseID)) + return openAIWSResponseAccountCachePrefix + hex.EncodeToString(sum[:]) +} + +func normalizeOpenAIWSTTL(ttl time.Duration) time.Duration { + if ttl <= 0 { + return time.Hour + } + return ttl +} + +func openAIWSSessionTurnStateKey(groupID int64, sessionHash string) string { + hash := strings.TrimSpace(sessionHash) + if hash == "" { + 
return "" + } + return fmt.Sprintf("%d:%s", groupID, hash) +} + +func withOpenAIWSStateStoreRedisTimeout(ctx context.Context) (context.Context, context.CancelFunc) { + if ctx == nil { + ctx = context.Background() + } + return context.WithTimeout(ctx, openAIWSStateStoreRedisTimeout) +} diff --git a/backend/internal/service/openai_ws_state_store_test.go b/backend/internal/service/openai_ws_state_store_test.go new file mode 100644 index 00000000..235d4233 --- /dev/null +++ b/backend/internal/service/openai_ws_state_store_test.go @@ -0,0 +1,235 @@ +package service + +import ( + "context" + "errors" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestOpenAIWSStateStore_BindGetDeleteResponseAccount(t *testing.T) { + cache := &stubGatewayCache{} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(7) + + require.NoError(t, store.BindResponseAccount(ctx, groupID, "resp_abc", 101, time.Minute)) + + accountID, err := store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Equal(t, int64(101), accountID) + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_abc")) + accountID, err = store.GetResponseAccount(ctx, groupID, "resp_abc") + require.NoError(t, err) + require.Zero(t, accountID) +} + +func TestOpenAIWSStateStore_ResponseConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindResponseConn("resp_conn", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetResponseConn("resp_conn") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetResponseConn("resp_conn") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionTurnStateTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionTurnState(9, "session_hash_1", "turn_state_1", 30*time.Millisecond) + + state, ok := store.GetSessionTurnState(9, "session_hash_1") + require.True(t, ok) + 
require.Equal(t, "turn_state_1", state) + + // group 隔离 + _, ok = store.GetSessionTurnState(10, "session_hash_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionTurnState(9, "session_hash_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_SessionConnTTL(t *testing.T) { + store := NewOpenAIWSStateStore(nil) + store.BindSessionConn(9, "session_hash_conn_1", "conn_1", 30*time.Millisecond) + + connID, ok := store.GetSessionConn(9, "session_hash_conn_1") + require.True(t, ok) + require.Equal(t, "conn_1", connID) + + // group 隔离 + _, ok = store.GetSessionConn(10, "session_hash_conn_1") + require.False(t, ok) + + time.Sleep(60 * time.Millisecond) + _, ok = store.GetSessionConn(9, "session_hash_conn_1") + require.False(t, ok) +} + +func TestOpenAIWSStateStore_GetResponseAccount_NoStaleAfterCacheMiss(t *testing.T) { + cache := &stubGatewayCache{sessionBindings: map[string]int64{}} + store := NewOpenAIWSStateStore(cache) + ctx := context.Background() + groupID := int64(17) + responseID := "resp_cache_stale" + cacheKey := openAIWSResponseAccountCacheKey(responseID) + + cache.sessionBindings[cacheKey] = 501 + accountID, err := store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Equal(t, int64(501), accountID) + + delete(cache.sessionBindings, cacheKey) + accountID, err = store.GetResponseAccount(ctx, groupID, responseID) + require.NoError(t, err) + require.Zero(t, accountID, "上游缓存失效后不应继续命中本地陈旧映射") +} + +func TestOpenAIWSStateStore_MaybeCleanupRemovesExpiredIncrementally(t *testing.T) { + raw := NewOpenAIWSStateStore(nil) + store, ok := raw.(*defaultOpenAIWSStateStore) + require.True(t, ok) + + expiredAt := time.Now().Add(-time.Minute) + total := 2048 + store.responseToConnMu.Lock() + for i := 0; i < total; i++ { + store.responseToConn[fmt.Sprintf("resp_%d", i)] = openAIWSConnBinding{ + connID: "conn_incremental", + expiresAt: expiredAt, + } + } + store.responseToConnMu.Unlock() + + 
store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + + store.responseToConnMu.RLock() + remainingAfterFirst := len(store.responseToConn) + store.responseToConnMu.RUnlock() + require.Less(t, remainingAfterFirst, total, "单轮 cleanup 应至少有进展") + require.Greater(t, remainingAfterFirst, 0, "增量清理不要求单轮清空全部键") + + for i := 0; i < 8; i++ { + store.lastCleanupUnixNano.Store(time.Now().Add(-2 * openAIWSStateStoreCleanupInterval).UnixNano()) + store.maybeCleanup() + } + + store.responseToConnMu.RLock() + remaining := len(store.responseToConn) + store.responseToConnMu.RUnlock() + require.Zero(t, remaining, "多轮 cleanup 后应逐步清空全部过期键") +} + +func TestEnsureBindingCapacity_EvictsOneWhenMapIsFull(t *testing.T) { + bindings := map[string]int{ + "a": 1, + "b": 2, + } + + ensureBindingCapacity(bindings, "c", 2) + bindings["c"] = 3 + + require.Len(t, bindings, 2) + require.Equal(t, 3, bindings["c"]) +} + +func TestEnsureBindingCapacity_DoesNotEvictWhenUpdatingExistingKey(t *testing.T) { + bindings := map[string]int{ + "a": 1, + "b": 2, + } + + ensureBindingCapacity(bindings, "a", 2) + bindings["a"] = 9 + + require.Len(t, bindings, 2) + require.Equal(t, 9, bindings["a"]) +} + +type openAIWSStateStoreTimeoutProbeCache struct { + setHasDeadline bool + getHasDeadline bool + deleteHasDeadline bool + setDeadlineDelta time.Duration + getDeadlineDelta time.Duration + delDeadlineDelta time.Duration +} + +func (c *openAIWSStateStoreTimeoutProbeCache) GetSessionAccountID(ctx context.Context, _ int64, _ string) (int64, error) { + if deadline, ok := ctx.Deadline(); ok { + c.getHasDeadline = true + c.getDeadlineDelta = time.Until(deadline) + } + return 123, nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) SetSessionAccountID(ctx context.Context, _ int64, _ string, _ int64, _ time.Duration) error { + if deadline, ok := ctx.Deadline(); ok { + c.setHasDeadline = true + c.setDeadlineDelta = time.Until(deadline) + } + return 
errors.New("set failed") +} + +func (c *openAIWSStateStoreTimeoutProbeCache) RefreshSessionTTL(context.Context, int64, string, time.Duration) error { + return nil +} + +func (c *openAIWSStateStoreTimeoutProbeCache) DeleteSessionAccountID(ctx context.Context, _ int64, _ string) error { + if deadline, ok := ctx.Deadline(); ok { + c.deleteHasDeadline = true + c.delDeadlineDelta = time.Until(deadline) + } + return nil +} + +func TestOpenAIWSStateStore_RedisOpsUseShortTimeout(t *testing.T) { + probe := &openAIWSStateStoreTimeoutProbeCache{} + store := NewOpenAIWSStateStore(probe) + ctx := context.Background() + groupID := int64(5) + + err := store.BindResponseAccount(ctx, groupID, "resp_timeout_probe", 11, time.Minute) + require.Error(t, err) + + accountID, getErr := store.GetResponseAccount(ctx, groupID, "resp_timeout_probe") + require.NoError(t, getErr) + require.Equal(t, int64(11), accountID, "本地缓存命中应优先返回已绑定账号") + + require.NoError(t, store.DeleteResponseAccount(ctx, groupID, "resp_timeout_probe")) + + require.True(t, probe.setHasDeadline, "SetSessionAccountID 应携带独立超时上下文") + require.True(t, probe.deleteHasDeadline, "DeleteSessionAccountID 应携带独立超时上下文") + require.False(t, probe.getHasDeadline, "GetSessionAccountID 本用例应由本地缓存命中,不触发 Redis 读取") + require.Greater(t, probe.setDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.setDeadlineDelta, 3*time.Second) + require.Greater(t, probe.delDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe.delDeadlineDelta, 3*time.Second) + + probe2 := &openAIWSStateStoreTimeoutProbeCache{} + store2 := NewOpenAIWSStateStore(probe2) + accountID2, err2 := store2.GetResponseAccount(ctx, groupID, "resp_cache_only") + require.NoError(t, err2) + require.Equal(t, int64(123), accountID2) + require.True(t, probe2.getHasDeadline, "GetSessionAccountID 在缓存未命中时应携带独立超时上下文") + require.Greater(t, probe2.getDeadlineDelta, 2*time.Second) + require.LessOrEqual(t, probe2.getDeadlineDelta, 3*time.Second) +} + +func 
TestWithOpenAIWSStateStoreRedisTimeout_WithParentContext(t *testing.T) { + ctx, cancel := withOpenAIWSStateStoreRedisTimeout(context.Background()) + defer cancel() + require.NotNil(t, ctx) + _, ok := ctx.Deadline() + require.True(t, ok, "应附加短超时") +} diff --git a/backend/internal/service/ops_retry.go b/backend/internal/service/ops_retry.go index 23a524ad..f0daa3e2 100644 --- a/backend/internal/service/ops_retry.go +++ b/backend/internal/service/ops_retry.go @@ -13,7 +13,6 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/domain" - "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/gin-gonic/gin" "github.com/lib/pq" @@ -480,7 +479,7 @@ func (s *OpsService) executeClientRetry(ctx context.Context, reqType opsRetryReq attemptCtx := ctx if switches > 0 { - attemptCtx = context.WithValue(attemptCtx, ctxkey.AccountSwitchCount, switches) + attemptCtx = WithAccountSwitchCount(attemptCtx, switches, false) } exec := func() *opsRetryExecution { defer selection.ReleaseFunc() @@ -675,6 +674,7 @@ func newOpsRetryContext(ctx context.Context, errorLog *OpsErrorLogDetail) (*gin. 
} c.Request = req + SetOpenAIClientTransport(c, OpenAIClientTransportHTTP) return c, w } diff --git a/backend/internal/service/ops_retry_context_test.go b/backend/internal/service/ops_retry_context_test.go new file mode 100644 index 00000000..a8c26ee4 --- /dev/null +++ b/backend/internal/service/ops_retry_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewOpsRetryContext_SetsHTTPTransportAndRequestHeaders(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + OpsErrorLog: OpsErrorLog{ + RequestPath: "/openai/v1/responses", + }, + UserAgent: "ops-retry-agent/1.0", + RequestHeaders: `{ + "anthropic-beta":"beta-v1", + "ANTHROPIC-VERSION":"2023-06-01", + "authorization":"Bearer should-not-forward" + }`, + } + + c, w := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, w) + require.NotNil(t, c.Request) + + require.Equal(t, "/openai/v1/responses", c.Request.URL.Path) + require.Equal(t, "application/json", c.Request.Header.Get("Content-Type")) + require.Equal(t, "ops-retry-agent/1.0", c.Request.Header.Get("User-Agent")) + require.Equal(t, "beta-v1", c.Request.Header.Get("anthropic-beta")) + require.Equal(t, "2023-06-01", c.Request.Header.Get("anthropic-version")) + require.Empty(t, c.Request.Header.Get("authorization"), "未在白名单内的敏感头不应被重放") + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} + +func TestNewOpsRetryContext_InvalidHeadersJSONStillSetsHTTPTransport(t *testing.T) { + errorLog := &OpsErrorLogDetail{ + RequestHeaders: "{invalid-json", + } + + c, _ := newOpsRetryContext(context.Background(), errorLog) + require.NotNil(t, c) + require.NotNil(t, c.Request) + require.Equal(t, "/", c.Request.URL.Path) + require.Equal(t, OpenAIClientTransportHTTP, GetOpenAIClientTransport(c)) +} diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 
23c154ce..21e09c43 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -27,6 +27,11 @@ const ( OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms" OpsResponseLatencyMsKey = "ops_response_latency_ms" OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms" + // OpenAI WS 关键观测字段 + OpsOpenAIWSQueueWaitMsKey = "ops_openai_ws_queue_wait_ms" + OpsOpenAIWSConnPickMsKey = "ops_openai_ws_conn_pick_ms" + OpsOpenAIWSConnReusedKey = "ops_openai_ws_conn_reused" + OpsOpenAIWSConnIDKey = "ops_openai_ws_conn_id" // OpsSkipPassthroughKey 由 applyErrorPassthroughRule 在命中 skip_monitoring=true 的规则时设置。 // ops_error_logger 中间件检查此 key,为 true 时跳过错误记录。 diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index fcc7c4a0..d4d70536 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -11,6 +11,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" ) // RateLimitService 处理限流和过载状态管理 @@ -33,6 +34,10 @@ type geminiUsageCacheEntry struct { totals GeminiUsageTotals } +type geminiUsageTotalsBatchProvider interface { + GetGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, startTime, endTime time.Time) (map[int64]GeminiUsageTotals, error) +} + const geminiPrecheckCacheTTL = time.Minute // NewRateLimitService 创建RateLimitService实例 @@ -162,6 +167,17 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc if upstreamMsg != "" { msg = "Access forbidden (403): " + upstreamMsg } + logger.LegacyPrintf( + "service.ratelimit", + "[HandleUpstreamErrorRaw] account_id=%d platform=%s type=%s status=403 request_id=%s cf_ray=%s upstream_msg=%s raw_body=%s", + account.ID, + account.Platform, + account.Type, + strings.TrimSpace(headers.Get("x-request-id")), + strings.TrimSpace(headers.Get("cf-ray")), + upstreamMsg, + 
truncateForLog(responseBody, 1024), + ) s.handleAuthError(ctx, account, msg) shouldDisable = true case 429: @@ -225,7 +241,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -272,7 +288,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if limit > 0 { start := now.Truncate(time.Minute) - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil, nil) if err != nil { return true, err } @@ -302,6 +318,218 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, return true, nil } +// PreCheckUsageBatch performs quota precheck for multiple accounts in one request. +// Returned map value=false means the account should be skipped. 
+func (s *RateLimitService) PreCheckUsageBatch(ctx context.Context, accounts []*Account, requestedModel string) (map[int64]bool, error) { + result := make(map[int64]bool, len(accounts)) + for _, account := range accounts { + if account == nil { + continue + } + result[account.ID] = true + } + + if len(accounts) == 0 || requestedModel == "" { + return result, nil + } + if s.usageRepo == nil || s.geminiQuotaService == nil { + return result, nil + } + + modelClass := geminiModelClassFromName(requestedModel) + now := time.Now() + dailyStart := geminiDailyWindowStart(now) + minuteStart := now.Truncate(time.Minute) + + type quotaAccount struct { + account *Account + quota GeminiQuota + } + quotaAccounts := make([]quotaAccount, 0, len(accounts)) + for _, account := range accounts { + if account == nil || account.Platform != PlatformGemini { + continue + } + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + continue + } + quotaAccounts = append(quotaAccounts, quotaAccount{ + account: account, + quota: quota, + }) + } + if len(quotaAccounts) == 0 { + return result, nil + } + + // 1) Daily precheck (cached + batch DB fallback) + dailyTotalsByID := make(map[int64]GeminiUsageTotals, len(quotaAccounts)) + dailyMissIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + if totals, ok := s.getGeminiUsageTotals(accountID, dailyStart, now); ok { + dailyTotalsByID[accountID] = totals + continue + } + dailyMissIDs = append(dailyMissIDs, accountID) + } + if len(dailyMissIDs) > 0 { + totalsBatch, err := s.getGeminiUsageTotalsBatch(ctx, dailyMissIDs, dailyStart, now) + if err != nil { + return result, err + } + for _, accountID := range dailyMissIDs { + totals := totalsBatch[accountID] + dailyTotalsByID[accountID] = totals + s.setGeminiUsageTotals(accountID, dailyStart, now, totals) + } + } + for _, item := range 
quotaAccounts { + limit := geminiDailyLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + accountID := item.account.ID + used := geminiUsedRequests(item.quota, modelClass, dailyTotalsByID[accountID], true) + if used >= limit { + resetAt := geminiDailyResetTime(now) + slog.Info("gemini_precheck_daily_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + // 2) Minute precheck (batch DB) + minuteIDs := make([]int64, 0, len(quotaAccounts)) + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + if geminiMinuteLimit(item.quota, modelClass) <= 0 { + continue + } + minuteIDs = append(minuteIDs, accountID) + } + if len(minuteIDs) == 0 { + return result, nil + } + + minuteTotalsByID, err := s.getGeminiUsageTotalsBatch(ctx, minuteIDs, minuteStart, now) + if err != nil { + return result, err + } + for _, item := range quotaAccounts { + accountID := item.account.ID + if !result[accountID] { + continue + } + + limit := geminiMinuteLimit(item.quota, modelClass) + if limit <= 0 { + continue + } + + used := geminiUsedRequests(item.quota, modelClass, minuteTotalsByID[accountID], false) + if used >= limit { + resetAt := minuteStart.Add(time.Minute) + slog.Info("gemini_precheck_minute_quota_reached_batch", "account_id", accountID, "used", used, "limit", limit, "reset_at", resetAt) + result[accountID] = false + } + } + + return result, nil +} + +func (s *RateLimitService) getGeminiUsageTotalsBatch(ctx context.Context, accountIDs []int64, start, end time.Time) (map[int64]GeminiUsageTotals, error) { + result := make(map[int64]GeminiUsageTotals, len(accountIDs)) + if len(accountIDs) == 0 { + return result, nil + } + + ids := make([]int64, 0, len(accountIDs)) + seen := make(map[int64]struct{}, len(accountIDs)) + for _, accountID := range accountIDs { + if accountID <= 0 { + continue + } + if _, ok := seen[accountID]; ok { + continue + } 
+ seen[accountID] = struct{}{} + ids = append(ids, accountID) + } + if len(ids) == 0 { + return result, nil + } + + if batchReader, ok := s.usageRepo.(geminiUsageTotalsBatchProvider); ok { + stats, err := batchReader.GetGeminiUsageTotalsBatch(ctx, ids, start, end) + if err != nil { + return nil, err + } + for _, accountID := range ids { + result[accountID] = stats[accountID] + } + return result, nil + } + + for _, accountID := range ids { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, end, 0, 0, accountID, 0, nil, nil, nil) + if err != nil { + return nil, err + } + result[accountID] = geminiAggregateUsage(stats) + } + return result, nil +} + +func geminiDailyLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPD > 0 { + return quota.SharedRPD + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPD + default: + return quota.ProRPD + } +} + +func geminiMinuteLimit(quota GeminiQuota, modelClass geminiModelClass) int64 { + if quota.SharedRPM > 0 { + return quota.SharedRPM + } + switch modelClass { + case geminiModelFlash: + return quota.FlashRPM + default: + return quota.ProRPM + } +} + +func geminiUsedRequests(quota GeminiQuota, modelClass geminiModelClass, totals GeminiUsageTotals, daily bool) int64 { + if daily { + if quota.SharedRPD > 0 { + return totals.ProRequests + totals.FlashRequests + } + } else { + if quota.SharedRPM > 0 { + return totals.ProRequests + totals.FlashRequests + } + } + switch modelClass { + case geminiModelFlash: + return totals.FlashRequests + default: + return totals.ProRequests + } +} + func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) { s.usageCacheMu.RLock() defer s.usageCacheMu.RUnlock() diff --git a/backend/internal/service/request_metadata.go b/backend/internal/service/request_metadata.go new file mode 100644 index 00000000..5c81bbf1 --- /dev/null +++ b/backend/internal/service/request_metadata.go @@ -0,0 
+1,216 @@ +package service + +import ( + "context" + "sync/atomic" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" +) + +type requestMetadataContextKey struct{} + +var requestMetadataKey = requestMetadataContextKey{} + +type RequestMetadata struct { + IsMaxTokensOneHaikuRequest *bool + ThinkingEnabled *bool + PrefetchedStickyAccountID *int64 + PrefetchedStickyGroupID *int64 + SingleAccountRetry *bool + AccountSwitchCount *int +} + +var ( + requestMetadataFallbackIsMaxTokensOneHaikuTotal atomic.Int64 + requestMetadataFallbackThinkingEnabledTotal atomic.Int64 + requestMetadataFallbackPrefetchedStickyAccount atomic.Int64 + requestMetadataFallbackPrefetchedStickyGroup atomic.Int64 + requestMetadataFallbackSingleAccountRetryTotal atomic.Int64 + requestMetadataFallbackAccountSwitchCountTotal atomic.Int64 +) + +func RequestMetadataFallbackStats() (isMaxTokensOneHaiku, thinkingEnabled, prefetchedStickyAccount, prefetchedStickyGroup, singleAccountRetry, accountSwitchCount int64) { + return requestMetadataFallbackIsMaxTokensOneHaikuTotal.Load(), + requestMetadataFallbackThinkingEnabledTotal.Load(), + requestMetadataFallbackPrefetchedStickyAccount.Load(), + requestMetadataFallbackPrefetchedStickyGroup.Load(), + requestMetadataFallbackSingleAccountRetryTotal.Load(), + requestMetadataFallbackAccountSwitchCountTotal.Load() +} + +func metadataFromContext(ctx context.Context) *RequestMetadata { + if ctx == nil { + return nil + } + md, _ := ctx.Value(requestMetadataKey).(*RequestMetadata) + return md +} + +func updateRequestMetadata( + ctx context.Context, + bridgeOldKeys bool, + update func(md *RequestMetadata), + legacyBridge func(ctx context.Context) context.Context, +) context.Context { + if ctx == nil { + return nil + } + current := metadataFromContext(ctx) + next := &RequestMetadata{} + if current != nil { + *next = *current + } + update(next) + ctx = context.WithValue(ctx, requestMetadataKey, next) + if bridgeOldKeys && legacyBridge != nil { + ctx = legacyBridge(ctx) + 
} + return ctx +} + +func WithIsMaxTokensOneHaikuRequest(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.IsMaxTokensOneHaikuRequest = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.IsMaxTokensOneHaikuRequest, value) + }) +} + +func WithThinkingEnabled(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.ThinkingEnabled = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.ThinkingEnabled, value) + }) +} + +func WithPrefetchedStickySession(ctx context.Context, accountID, groupID int64, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + account := accountID + group := groupID + md.PrefetchedStickyAccountID = &account + md.PrefetchedStickyGroupID = &group + }, func(base context.Context) context.Context { + bridged := context.WithValue(base, ctxkey.PrefetchedStickyAccountID, accountID) + return context.WithValue(bridged, ctxkey.PrefetchedStickyGroupID, groupID) + }) +} + +func WithSingleAccountRetry(ctx context.Context, value bool, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.SingleAccountRetry = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.SingleAccountRetry, value) + }) +} + +func WithAccountSwitchCount(ctx context.Context, value int, bridgeOldKeys bool) context.Context { + return updateRequestMetadata(ctx, bridgeOldKeys, func(md *RequestMetadata) { + v := value + md.AccountSwitchCount = &v + }, func(base context.Context) context.Context { + return context.WithValue(base, ctxkey.AccountSwitchCount, value) + }) +} + +func 
IsMaxTokensOneHaikuRequestFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.IsMaxTokensOneHaikuRequest != nil { + return *md.IsMaxTokensOneHaikuRequest, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest).(bool); ok { + requestMetadataFallbackIsMaxTokensOneHaikuTotal.Add(1) + return value, true + } + return false, false +} + +func ThinkingEnabledFromContext(ctx context.Context) (bool, bool) { + if md := metadataFromContext(ctx); md != nil && md.ThinkingEnabled != nil { + return *md.ThinkingEnabled, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.ThinkingEnabled).(bool); ok { + requestMetadataFallbackThinkingEnabledTotal.Add(1) + return value, true + } + return false, false +} + +func PrefetchedStickyGroupIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyGroupID != nil { + return *md.PrefetchedStickyGroupID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyGroupID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyGroup.Add(1) + return int64(t), true + } + return 0, false +} + +func PrefetchedStickyAccountIDFromContext(ctx context.Context) (int64, bool) { + if md := metadataFromContext(ctx); md != nil && md.PrefetchedStickyAccountID != nil { + return *md.PrefetchedStickyAccountID, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.PrefetchedStickyAccountID) + switch t := v.(type) { + case int64: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return t, true + case int: + requestMetadataFallbackPrefetchedStickyAccount.Add(1) + return int64(t), true + } + return 0, false +} + +func SingleAccountRetryFromContext(ctx context.Context) (bool, bool) { + if md := 
metadataFromContext(ctx); md != nil && md.SingleAccountRetry != nil { + return *md.SingleAccountRetry, true + } + if ctx == nil { + return false, false + } + if value, ok := ctx.Value(ctxkey.SingleAccountRetry).(bool); ok { + requestMetadataFallbackSingleAccountRetryTotal.Add(1) + return value, true + } + return false, false +} + +func AccountSwitchCountFromContext(ctx context.Context) (int, bool) { + if md := metadataFromContext(ctx); md != nil && md.AccountSwitchCount != nil { + return *md.AccountSwitchCount, true + } + if ctx == nil { + return 0, false + } + v := ctx.Value(ctxkey.AccountSwitchCount) + switch t := v.(type) { + case int: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return t, true + case int64: + requestMetadataFallbackAccountSwitchCountTotal.Add(1) + return int(t), true + } + return 0, false +} diff --git a/backend/internal/service/request_metadata_test.go b/backend/internal/service/request_metadata_test.go new file mode 100644 index 00000000..7d192699 --- /dev/null +++ b/backend/internal/service/request_metadata_test.go @@ -0,0 +1,119 @@ +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey" + "github.com/stretchr/testify/require" +) + +func TestRequestMetadataWriteAndRead_NoBridge(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, false) + ctx = WithThinkingEnabled(ctx, true, false) + ctx = WithPrefetchedStickySession(ctx, 123, 456, false) + ctx = WithSingleAccountRetry(ctx, true, false) + ctx = WithAccountSwitchCount(ctx, 2, false) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(123), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + 
require.True(t, ok) + require.Equal(t, int64(456), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 2, switchCount) + + require.Nil(t, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Nil(t, ctx.Value(ctxkey.ThinkingEnabled)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Nil(t, ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Nil(t, ctx.Value(ctxkey.SingleAccountRetry)) + require.Nil(t, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataWrite_BridgeLegacyKeys(t *testing.T) { + ctx := context.Background() + ctx = WithIsMaxTokensOneHaikuRequest(ctx, true, true) + ctx = WithThinkingEnabled(ctx, true, true) + ctx = WithPrefetchedStickySession(ctx, 123, 456, true) + ctx = WithSingleAccountRetry(ctx, true, true) + ctx = WithAccountSwitchCount(ctx, 2, true) + + require.Equal(t, true, ctx.Value(ctxkey.IsMaxTokensOneHaikuRequest)) + require.Equal(t, true, ctx.Value(ctxkey.ThinkingEnabled)) + require.Equal(t, int64(123), ctx.Value(ctxkey.PrefetchedStickyAccountID)) + require.Equal(t, int64(456), ctx.Value(ctxkey.PrefetchedStickyGroupID)) + require.Equal(t, true, ctx.Value(ctxkey.SingleAccountRetry)) + require.Equal(t, 2, ctx.Value(ctxkey.AccountSwitchCount)) +} + +func TestRequestMetadataRead_LegacyFallbackAndStats(t *testing.T) { + beforeHaiku, beforeThinking, beforeAccount, beforeGroup, beforeSingleRetry, beforeSwitchCount := RequestMetadataFallbackStats() + + ctx := context.Background() + ctx = context.WithValue(ctx, ctxkey.IsMaxTokensOneHaikuRequest, true) + ctx = context.WithValue(ctx, ctxkey.ThinkingEnabled, true) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyAccountID, int64(321)) + ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, int64(654)) + ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true) + ctx = 
context.WithValue(ctx, ctxkey.AccountSwitchCount, int64(3)) + + isHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(ctx) + require.True(t, ok) + require.True(t, isHaiku) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + + accountID, ok := PrefetchedStickyAccountIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(321), accountID) + + groupID, ok := PrefetchedStickyGroupIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, int64(654), groupID) + + singleRetry, ok := SingleAccountRetryFromContext(ctx) + require.True(t, ok) + require.True(t, singleRetry) + + switchCount, ok := AccountSwitchCountFromContext(ctx) + require.True(t, ok) + require.Equal(t, 3, switchCount) + + afterHaiku, afterThinking, afterAccount, afterGroup, afterSingleRetry, afterSwitchCount := RequestMetadataFallbackStats() + require.Equal(t, beforeHaiku+1, afterHaiku) + require.Equal(t, beforeThinking+1, afterThinking) + require.Equal(t, beforeAccount+1, afterAccount) + require.Equal(t, beforeGroup+1, afterGroup) + require.Equal(t, beforeSingleRetry+1, afterSingleRetry) + require.Equal(t, beforeSwitchCount+1, afterSwitchCount) +} + +func TestRequestMetadataRead_PreferMetadataOverLegacy(t *testing.T) { + ctx := context.WithValue(context.Background(), ctxkey.ThinkingEnabled, false) + ctx = WithThinkingEnabled(ctx, true, false) + + thinking, ok := ThinkingEnabledFromContext(ctx) + require.True(t, ok) + require.True(t, thinking) + require.Equal(t, false, ctx.Value(ctxkey.ThinkingEnabled)) +} diff --git a/backend/internal/service/response_header_filter.go b/backend/internal/service/response_header_filter.go new file mode 100644 index 00000000..81012b01 --- /dev/null +++ b/backend/internal/service/response_header_filter.go @@ -0,0 +1,13 @@ +package service + +import ( + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" +) + +func compileResponseHeaderFilter(cfg *config.Config) 
*responseheaders.CompiledHeaderFilter { + if cfg == nil { + return nil + } + return responseheaders.CompileHeaderFilter(cfg.Security.ResponseHeaders) +} diff --git a/backend/internal/service/scheduler_snapshot_service.go b/backend/internal/service/scheduler_snapshot_service.go index 4d95743c..9f8fa14a 100644 --- a/backend/internal/service/scheduler_snapshot_service.go +++ b/backend/internal/service/scheduler_snapshot_service.go @@ -305,13 +305,78 @@ func (s *SchedulerSnapshotService) handleBulkAccountEvent(ctx context.Context, p if payload == nil { return nil } - ids := parseInt64Slice(payload["account_ids"]) - for _, id := range ids { - if err := s.handleAccountEvent(ctx, &id, payload); err != nil { - return err + if s.accountRepo == nil { + return nil + } + + rawIDs := parseInt64Slice(payload["account_ids"]) + if len(rawIDs) == 0 { + return nil + } + + ids := make([]int64, 0, len(rawIDs)) + seen := make(map[int64]struct{}, len(rawIDs)) + for _, id := range rawIDs { + if id <= 0 { + continue + } + if _, exists := seen[id]; exists { + continue + } + seen[id] = struct{}{} + ids = append(ids, id) + } + if len(ids) == 0 { + return nil + } + + preloadGroupIDs := parseInt64Slice(payload["group_ids"]) + accounts, err := s.accountRepo.GetByIDs(ctx, ids) + if err != nil { + return err + } + + found := make(map[int64]struct{}, len(accounts)) + rebuildGroupSet := make(map[int64]struct{}, len(preloadGroupIDs)) + for _, gid := range preloadGroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} } } - return nil + + for _, account := range accounts { + if account == nil || account.ID <= 0 { + continue + } + found[account.ID] = struct{}{} + if s.cache != nil { + if err := s.cache.SetAccount(ctx, account); err != nil { + return err + } + } + for _, gid := range account.GroupIDs { + if gid > 0 { + rebuildGroupSet[gid] = struct{}{} + } + } + } + + if s.cache != nil { + for _, id := range ids { + if _, ok := found[id]; ok { + continue + } + if err := s.cache.DeleteAccount(ctx, 
id); err != nil { + return err + } + } + } + + rebuildGroupIDs := make([]int64, 0, len(rebuildGroupSet)) + for gid := range rebuildGroupSet { + rebuildGroupIDs = append(rebuildGroupIDs, gid) + } + return s.rebuildByGroupIDs(ctx, rebuildGroupIDs, "account_bulk_change") } func (s *SchedulerSnapshotService) handleAccountEvent(ctx context.Context, accountID *int64, payload map[string]any) error { diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f5ba9d71..445167b7 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -9,14 +9,17 @@ import ( "fmt" "strconv" "strings" + "time" "github.com/Wei-Shaw/sub2api/internal/config" infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" ) var ( - ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") - ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrRegistrationDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") + ErrSettingNotFound = infraerrors.NotFound("SETTING_NOT_FOUND", "setting not found") + ErrSoraS3ProfileNotFound = infraerrors.NotFound("SORA_S3_PROFILE_NOT_FOUND", "sora s3 profile not found") + ErrSoraS3ProfileExists = infraerrors.Conflict("SORA_S3_PROFILE_EXISTS", "sora s3 profile already exists") ) type SettingRepository interface { @@ -34,6 +37,7 @@ type SettingService struct { settingRepo SettingRepository cfg *config.Config onUpdate func() // Callback when settings are updated (for cache invalidation) + onS3Update func() // Callback when Sora S3 settings are updated version string // Application version } @@ -76,6 +80,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyHideCcsImportButton, SettingKeyPurchaseSubscriptionEnabled, SettingKeyPurchaseSubscriptionURL, + SettingKeySoraClientEnabled, 
SettingKeyLinuxDoConnectEnabled, } @@ -114,6 +119,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -124,6 +130,11 @@ func (s *SettingService) SetOnUpdateCallback(callback func()) { s.onUpdate = callback } +// SetOnS3UpdateCallback 设置 Sora S3 配置变更时的回调函数(用于刷新 S3 客户端缓存)。 +func (s *SettingService) SetOnS3UpdateCallback(callback func()) { + s.onS3Update = callback +} + // SetVersion sets the application version for injection into public settings func (s *SettingService) SetVersion(version string) { s.version = version @@ -157,6 +168,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton bool `json:"hide_ccs_import_button"` PurchaseSubscriptionEnabled bool `json:"purchase_subscription_enabled"` PurchaseSubscriptionURL string `json:"purchase_subscription_url,omitempty"` + SoraClientEnabled bool `json:"sora_client_enabled"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version,omitempty"` }{ @@ -178,6 +190,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any HideCcsImportButton: settings.HideCcsImportButton, PurchaseSubscriptionEnabled: settings.PurchaseSubscriptionEnabled, PurchaseSubscriptionURL: settings.PurchaseSubscriptionURL, + SoraClientEnabled: settings.SoraClientEnabled, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: s.version, }, nil @@ -232,6 +245,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) 
updates[SettingKeyPurchaseSubscriptionEnabled] = strconv.FormatBool(settings.PurchaseSubscriptionEnabled) updates[SettingKeyPurchaseSubscriptionURL] = strings.TrimSpace(settings.PurchaseSubscriptionURL) + updates[SettingKeySoraClientEnabled] = strconv.FormatBool(settings.SoraClientEnabled) // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -383,6 +397,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeySiteLogo: "", SettingKeyPurchaseSubscriptionEnabled: "false", SettingKeyPurchaseSubscriptionURL: "", + SettingKeySoraClientEnabled: "false", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), SettingKeySMTPPort: "587", @@ -436,6 +451,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", PurchaseSubscriptionEnabled: settings[SettingKeyPurchaseSubscriptionEnabled] == "true", PurchaseSubscriptionURL: strings.TrimSpace(settings[SettingKeyPurchaseSubscriptionURL]), + SoraClientEnabled: settings[SettingKeySoraClientEnabled] == "true", } // 解析整数类型 @@ -854,3 +870,607 @@ func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data)) } + +type soraS3ProfilesStore struct { + ActiveProfileID string `json:"active_profile_id"` + Items []soraS3ProfileStoreItem `json:"items"` +} + +type soraS3ProfileStoreItem struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL 
string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// GetSoraS3Settings 获取 Sora S3 存储配置(兼容旧单配置语义:返回当前激活配置) +func (s *SettingService) GetSoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + profiles, err := s.ListSoraS3Profiles(ctx) + if err != nil { + return nil, err + } + + activeProfile := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if activeProfile == nil { + return &SoraS3Settings{}, nil + } + + return &SoraS3Settings{ + Enabled: activeProfile.Enabled, + Endpoint: activeProfile.Endpoint, + Region: activeProfile.Region, + Bucket: activeProfile.Bucket, + AccessKeyID: activeProfile.AccessKeyID, + SecretAccessKey: activeProfile.SecretAccessKey, + SecretAccessKeyConfigured: activeProfile.SecretAccessKeyConfigured, + Prefix: activeProfile.Prefix, + ForcePathStyle: activeProfile.ForcePathStyle, + CDNURL: activeProfile.CDNURL, + DefaultStorageQuotaBytes: activeProfile.DefaultStorageQuotaBytes, + }, nil +} + +// SetSoraS3Settings 更新 Sora S3 存储配置(兼容旧单配置语义:写入当前激活配置) +func (s *SettingService) SetSoraS3Settings(ctx context.Context, settings *SoraS3Settings) error { + if settings == nil { + return fmt.Errorf("settings cannot be nil") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + now := time.Now().UTC().Format(time.RFC3339) + activeIndex := findSoraS3ProfileIndex(store.Items, store.ActiveProfileID) + if activeIndex < 0 { + activeID := "default" + if hasSoraS3ProfileID(store.Items, activeID) { + activeID = fmt.Sprintf("default-%d", time.Now().Unix()) + } + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: activeID, + Name: "Default", + UpdatedAt: now, + }) + store.ActiveProfileID = activeID + activeIndex = len(store.Items) - 1 + } + + active := store.Items[activeIndex] + active.Enabled = settings.Enabled + active.Endpoint = strings.TrimSpace(settings.Endpoint) + active.Region = 
strings.TrimSpace(settings.Region) + active.Bucket = strings.TrimSpace(settings.Bucket) + active.AccessKeyID = strings.TrimSpace(settings.AccessKeyID) + active.Prefix = strings.TrimSpace(settings.Prefix) + active.ForcePathStyle = settings.ForcePathStyle + active.CDNURL = strings.TrimSpace(settings.CDNURL) + active.DefaultStorageQuotaBytes = maxInt64(settings.DefaultStorageQuotaBytes, 0) + if settings.SecretAccessKey != "" { + active.SecretAccessKey = settings.SecretAccessKey + } + active.UpdatedAt = now + store.Items[activeIndex] = active + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// ListSoraS3Profiles 获取 Sora S3 多配置列表 +func (s *SettingService) ListSoraS3Profiles(ctx context.Context) (*SoraS3ProfileList, error) { + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + return convertSoraS3ProfilesStore(store), nil +} + +// CreateSoraS3Profile 创建 Sora S3 配置 +func (s *SettingService) CreateSoraS3Profile(ctx context.Context, profile *SoraS3Profile, setActive bool) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + profileID := strings.TrimSpace(profile.ProfileID) + if profileID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + if hasSoraS3ProfileID(store.Items, profileID) { + return nil, ErrSoraS3ProfileExists + } + + now := time.Now().UTC().Format(time.RFC3339) + store.Items = append(store.Items, soraS3ProfileStoreItem{ + ProfileID: profileID, + Name: name, + Enabled: profile.Enabled, + Endpoint: strings.TrimSpace(profile.Endpoint), + Region: strings.TrimSpace(profile.Region), + Bucket: strings.TrimSpace(profile.Bucket), + AccessKeyID: 
strings.TrimSpace(profile.AccessKeyID), + SecretAccessKey: profile.SecretAccessKey, + Prefix: strings.TrimSpace(profile.Prefix), + ForcePathStyle: profile.ForcePathStyle, + CDNURL: strings.TrimSpace(profile.CDNURL), + DefaultStorageQuotaBytes: maxInt64(profile.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }) + + if setActive || store.ActiveProfileID == "" { + store.ActiveProfileID = profileID + } + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + created := findSoraS3ProfileByID(profiles.Items, profileID) + if created == nil { + return nil, ErrSoraS3ProfileNotFound + } + return created, nil +} + +// UpdateSoraS3Profile 更新 Sora S3 配置 +func (s *SettingService) UpdateSoraS3Profile(ctx context.Context, profileID string, profile *SoraS3Profile) (*SoraS3Profile, error) { + if profile == nil { + return nil, fmt.Errorf("profile cannot be nil") + } + + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + target := store.Items[targetIndex] + name := strings.TrimSpace(profile.Name) + if name == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_NAME_REQUIRED", "name is required") + } + target.Name = name + target.Enabled = profile.Enabled + target.Endpoint = strings.TrimSpace(profile.Endpoint) + target.Region = strings.TrimSpace(profile.Region) + target.Bucket = strings.TrimSpace(profile.Bucket) + target.AccessKeyID = strings.TrimSpace(profile.AccessKeyID) + target.Prefix = strings.TrimSpace(profile.Prefix) + target.ForcePathStyle = profile.ForcePathStyle + target.CDNURL = strings.TrimSpace(profile.CDNURL) + target.DefaultStorageQuotaBytes = 
maxInt64(profile.DefaultStorageQuotaBytes, 0) + if profile.SecretAccessKey != "" { + target.SecretAccessKey = profile.SecretAccessKey + } + target.UpdatedAt = time.Now().UTC().Format(time.RFC3339) + store.Items[targetIndex] = target + + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + updated := findSoraS3ProfileByID(profiles.Items, targetID) + if updated == nil { + return nil, ErrSoraS3ProfileNotFound + } + return updated, nil +} + +// DeleteSoraS3Profile 删除 Sora S3 配置 +func (s *SettingService) DeleteSoraS3Profile(ctx context.Context, profileID string) error { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return ErrSoraS3ProfileNotFound + } + + store.Items = append(store.Items[:targetIndex], store.Items[targetIndex+1:]...) 
+ if store.ActiveProfileID == targetID { + store.ActiveProfileID = "" + if len(store.Items) > 0 { + store.ActiveProfileID = store.Items[0].ProfileID + } + } + + return s.persistSoraS3ProfilesStore(ctx, store) +} + +// SetActiveSoraS3Profile 设置激活的 Sora S3 配置 +func (s *SettingService) SetActiveSoraS3Profile(ctx context.Context, profileID string) (*SoraS3Profile, error) { + targetID := strings.TrimSpace(profileID) + if targetID == "" { + return nil, infraerrors.BadRequest("SORA_S3_PROFILE_ID_REQUIRED", "profile_id is required") + } + + store, err := s.loadSoraS3ProfilesStore(ctx) + if err != nil { + return nil, err + } + + targetIndex := findSoraS3ProfileIndex(store.Items, targetID) + if targetIndex < 0 { + return nil, ErrSoraS3ProfileNotFound + } + + store.ActiveProfileID = targetID + store.Items[targetIndex].UpdatedAt = time.Now().UTC().Format(time.RFC3339) + if err := s.persistSoraS3ProfilesStore(ctx, store); err != nil { + return nil, err + } + + profiles := convertSoraS3ProfilesStore(store) + active := pickActiveSoraS3Profile(profiles.Items, profiles.ActiveProfileID) + if active == nil { + return nil, ErrSoraS3ProfileNotFound + } + return active, nil +} + +func (s *SettingService) loadSoraS3ProfilesStore(ctx context.Context) (*soraS3ProfilesStore, error) { + raw, err := s.settingRepo.GetValue(ctx, SettingKeySoraS3Profiles) + if err == nil { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return &soraS3ProfilesStore{}, nil + } + var store soraS3ProfilesStore + if unmarshalErr := json.Unmarshal([]byte(trimmed), &store); unmarshalErr != nil { + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, fmt.Errorf("unmarshal sora s3 profiles: %w", unmarshalErr) + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: 
"Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil + } + normalized := normalizeSoraS3ProfilesStore(store) + return &normalized, nil + } + + if !errors.Is(err, ErrSettingNotFound) { + return nil, fmt.Errorf("get sora s3 profiles: %w", err) + } + + legacy, legacyErr := s.getLegacySoraS3Settings(ctx) + if legacyErr != nil { + return nil, legacyErr + } + if isEmptyLegacySoraS3Settings(legacy) { + return &soraS3ProfilesStore{}, nil + } + + now := time.Now().UTC().Format(time.RFC3339) + return &soraS3ProfilesStore{ + ActiveProfileID: "default", + Items: []soraS3ProfileStoreItem{ + { + ProfileID: "default", + Name: "Default", + Enabled: legacy.Enabled, + Endpoint: strings.TrimSpace(legacy.Endpoint), + Region: strings.TrimSpace(legacy.Region), + Bucket: strings.TrimSpace(legacy.Bucket), + AccessKeyID: strings.TrimSpace(legacy.AccessKeyID), + SecretAccessKey: legacy.SecretAccessKey, + Prefix: strings.TrimSpace(legacy.Prefix), + ForcePathStyle: legacy.ForcePathStyle, + CDNURL: strings.TrimSpace(legacy.CDNURL), + DefaultStorageQuotaBytes: maxInt64(legacy.DefaultStorageQuotaBytes, 0), + UpdatedAt: now, + }, + }, + }, nil +} + +func (s *SettingService) persistSoraS3ProfilesStore(ctx context.Context, store *soraS3ProfilesStore) error { + if store == nil { + return fmt.Errorf("sora s3 profiles store cannot be nil") + } + + normalized := normalizeSoraS3ProfilesStore(*store) + data, err := json.Marshal(normalized) + if err != nil { + return fmt.Errorf("marshal sora s3 profiles: %w", err) + } + + updates := map[string]string{ + 
SettingKeySoraS3Profiles: string(data), + } + + active := pickActiveSoraS3ProfileFromStore(normalized.Items, normalized.ActiveProfileID) + if active == nil { + updates[SettingKeySoraS3Enabled] = "false" + updates[SettingKeySoraS3Endpoint] = "" + updates[SettingKeySoraS3Region] = "" + updates[SettingKeySoraS3Bucket] = "" + updates[SettingKeySoraS3AccessKeyID] = "" + updates[SettingKeySoraS3Prefix] = "" + updates[SettingKeySoraS3ForcePathStyle] = "false" + updates[SettingKeySoraS3CDNURL] = "" + updates[SettingKeySoraDefaultStorageQuotaBytes] = "0" + updates[SettingKeySoraS3SecretAccessKey] = "" + } else { + updates[SettingKeySoraS3Enabled] = strconv.FormatBool(active.Enabled) + updates[SettingKeySoraS3Endpoint] = strings.TrimSpace(active.Endpoint) + updates[SettingKeySoraS3Region] = strings.TrimSpace(active.Region) + updates[SettingKeySoraS3Bucket] = strings.TrimSpace(active.Bucket) + updates[SettingKeySoraS3AccessKeyID] = strings.TrimSpace(active.AccessKeyID) + updates[SettingKeySoraS3Prefix] = strings.TrimSpace(active.Prefix) + updates[SettingKeySoraS3ForcePathStyle] = strconv.FormatBool(active.ForcePathStyle) + updates[SettingKeySoraS3CDNURL] = strings.TrimSpace(active.CDNURL) + updates[SettingKeySoraDefaultStorageQuotaBytes] = strconv.FormatInt(maxInt64(active.DefaultStorageQuotaBytes, 0), 10) + updates[SettingKeySoraS3SecretAccessKey] = active.SecretAccessKey + } + + if err := s.settingRepo.SetMultiple(ctx, updates); err != nil { + return err + } + + if s.onUpdate != nil { + s.onUpdate() + } + if s.onS3Update != nil { + s.onS3Update() + } + return nil +} + +func (s *SettingService) getLegacySoraS3Settings(ctx context.Context) (*SoraS3Settings, error) { + keys := []string{ + SettingKeySoraS3Enabled, + SettingKeySoraS3Endpoint, + SettingKeySoraS3Region, + SettingKeySoraS3Bucket, + SettingKeySoraS3AccessKeyID, + SettingKeySoraS3SecretAccessKey, + SettingKeySoraS3Prefix, + SettingKeySoraS3ForcePathStyle, + SettingKeySoraS3CDNURL, + 
SettingKeySoraDefaultStorageQuotaBytes, + } + + values, err := s.settingRepo.GetMultiple(ctx, keys) + if err != nil { + return nil, fmt.Errorf("get legacy sora s3 settings: %w", err) + } + + result := &SoraS3Settings{ + Enabled: values[SettingKeySoraS3Enabled] == "true", + Endpoint: values[SettingKeySoraS3Endpoint], + Region: values[SettingKeySoraS3Region], + Bucket: values[SettingKeySoraS3Bucket], + AccessKeyID: values[SettingKeySoraS3AccessKeyID], + SecretAccessKey: values[SettingKeySoraS3SecretAccessKey], + SecretAccessKeyConfigured: values[SettingKeySoraS3SecretAccessKey] != "", + Prefix: values[SettingKeySoraS3Prefix], + ForcePathStyle: values[SettingKeySoraS3ForcePathStyle] == "true", + CDNURL: values[SettingKeySoraS3CDNURL], + } + if v, parseErr := strconv.ParseInt(values[SettingKeySoraDefaultStorageQuotaBytes], 10, 64); parseErr == nil { + result.DefaultStorageQuotaBytes = v + } + return result, nil +} + +func normalizeSoraS3ProfilesStore(store soraS3ProfilesStore) soraS3ProfilesStore { + seen := make(map[string]struct{}, len(store.Items)) + normalized := soraS3ProfilesStore{ + ActiveProfileID: strings.TrimSpace(store.ActiveProfileID), + Items: make([]soraS3ProfileStoreItem, 0, len(store.Items)), + } + now := time.Now().UTC().Format(time.RFC3339) + + for idx := range store.Items { + item := store.Items[idx] + item.ProfileID = strings.TrimSpace(item.ProfileID) + if item.ProfileID == "" { + item.ProfileID = fmt.Sprintf("profile-%d", idx+1) + } + if _, exists := seen[item.ProfileID]; exists { + continue + } + seen[item.ProfileID] = struct{}{} + + item.Name = strings.TrimSpace(item.Name) + if item.Name == "" { + item.Name = item.ProfileID + } + item.Endpoint = strings.TrimSpace(item.Endpoint) + item.Region = strings.TrimSpace(item.Region) + item.Bucket = strings.TrimSpace(item.Bucket) + item.AccessKeyID = strings.TrimSpace(item.AccessKeyID) + item.Prefix = strings.TrimSpace(item.Prefix) + item.CDNURL = strings.TrimSpace(item.CDNURL) + 
item.DefaultStorageQuotaBytes = maxInt64(item.DefaultStorageQuotaBytes, 0) + item.UpdatedAt = strings.TrimSpace(item.UpdatedAt) + if item.UpdatedAt == "" { + item.UpdatedAt = now + } + normalized.Items = append(normalized.Items, item) + } + + if len(normalized.Items) == 0 { + normalized.ActiveProfileID = "" + return normalized + } + + if findSoraS3ProfileIndex(normalized.Items, normalized.ActiveProfileID) >= 0 { + return normalized + } + + normalized.ActiveProfileID = normalized.Items[0].ProfileID + return normalized +} + +func convertSoraS3ProfilesStore(store *soraS3ProfilesStore) *SoraS3ProfileList { + if store == nil { + return &SoraS3ProfileList{} + } + items := make([]SoraS3Profile, 0, len(store.Items)) + for idx := range store.Items { + item := store.Items[idx] + items = append(items, SoraS3Profile{ + ProfileID: item.ProfileID, + Name: item.Name, + IsActive: item.ProfileID == store.ActiveProfileID, + Enabled: item.Enabled, + Endpoint: item.Endpoint, + Region: item.Region, + Bucket: item.Bucket, + AccessKeyID: item.AccessKeyID, + SecretAccessKey: item.SecretAccessKey, + SecretAccessKeyConfigured: item.SecretAccessKey != "", + Prefix: item.Prefix, + ForcePathStyle: item.ForcePathStyle, + CDNURL: item.CDNURL, + DefaultStorageQuotaBytes: item.DefaultStorageQuotaBytes, + UpdatedAt: item.UpdatedAt, + }) + } + return &SoraS3ProfileList{ + ActiveProfileID: store.ActiveProfileID, + Items: items, + } +} + +func pickActiveSoraS3Profile(items []SoraS3Profile, activeProfileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileByID(items []SoraS3Profile, profileID string) *SoraS3Profile { + for idx := range items { + if items[idx].ProfileID == profileID { + return &items[idx] + } + } + return nil +} + +func pickActiveSoraS3ProfileFromStore(items []soraS3ProfileStoreItem, activeProfileID string) 
*soraS3ProfileStoreItem { + for idx := range items { + if items[idx].ProfileID == activeProfileID { + return &items[idx] + } + } + if len(items) == 0 { + return nil + } + return &items[0] +} + +func findSoraS3ProfileIndex(items []soraS3ProfileStoreItem, profileID string) int { + for idx := range items { + if items[idx].ProfileID == profileID { + return idx + } + } + return -1 +} + +func hasSoraS3ProfileID(items []soraS3ProfileStoreItem, profileID string) bool { + return findSoraS3ProfileIndex(items, profileID) >= 0 +} + +func isEmptyLegacySoraS3Settings(settings *SoraS3Settings) bool { + if settings == nil { + return true + } + if settings.Enabled { + return false + } + if strings.TrimSpace(settings.Endpoint) != "" { + return false + } + if strings.TrimSpace(settings.Region) != "" { + return false + } + if strings.TrimSpace(settings.Bucket) != "" { + return false + } + if strings.TrimSpace(settings.AccessKeyID) != "" { + return false + } + if settings.SecretAccessKey != "" { + return false + } + if strings.TrimSpace(settings.Prefix) != "" { + return false + } + if strings.TrimSpace(settings.CDNURL) != "" { + return false + } + return settings.DefaultStorageQuotaBytes == 0 +} + +func maxInt64(value int64, min int64) int64 { + if value < min { + return min + } + return value +} diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 0c7bab67..74166926 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -39,6 +39,7 @@ type SystemSettings struct { HideCcsImportButton bool PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool DefaultConcurrency int DefaultBalance float64 @@ -81,11 +82,52 @@ type PublicSettings struct { PurchaseSubscriptionEnabled bool PurchaseSubscriptionURL string + SoraClientEnabled bool LinuxDoOAuthEnabled bool Version string } +// SoraS3Settings Sora S3 存储配置 +type SoraS3Settings struct { + Enabled bool 
`json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"secret_access_key"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` +} + +// SoraS3Profile Sora S3 多配置项(服务内部模型) +type SoraS3Profile struct { + ProfileID string `json:"profile_id"` + Name string `json:"name"` + IsActive bool `json:"is_active"` + Enabled bool `json:"enabled"` + Endpoint string `json:"endpoint"` + Region string `json:"region"` + Bucket string `json:"bucket"` + AccessKeyID string `json:"access_key_id"` + SecretAccessKey string `json:"-"` // 仅内部使用,不直接返回前端 + SecretAccessKeyConfigured bool `json:"secret_access_key_configured"` // 前端展示用 + Prefix string `json:"prefix"` + ForcePathStyle bool `json:"force_path_style"` + CDNURL string `json:"cdn_url"` + DefaultStorageQuotaBytes int64 `json:"default_storage_quota_bytes"` + UpdatedAt string `json:"updated_at"` +} + +// SoraS3ProfileList Sora S3 多配置列表 +type SoraS3ProfileList struct { + ActiveProfileID string `json:"active_profile_id"` + Items []SoraS3Profile `json:"items"` +} + // StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制) type StreamTimeoutSettings struct { // Enabled 是否启用流超时处理 diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 4680538c..0a914d2d 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -43,6 +43,7 @@ type SoraVideoRequest struct { Frames int Model string Size string + VideoCount int MediaID string RemixTargetID string CameoIDs []string diff --git a/backend/internal/service/sora_gateway_service.go b/backend/internal/service/sora_gateway_service.go index b8241eef..ab6871bb 
100644 --- a/backend/internal/service/sora_gateway_service.go +++ b/backend/internal/service/sora_gateway_service.go @@ -21,6 +21,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" "github.com/gin-gonic/gin" ) @@ -63,8 +64,8 @@ var soraBlockedCIDRs = mustParseCIDRs([]string{ // SoraGatewayService handles forwarding requests to Sora upstream. type SoraGatewayService struct { soraClient SoraClient - mediaStorage *SoraMediaStorage rateLimitService *RateLimitService + httpUpstream HTTPUpstream // 用于 apikey 类型账号的 HTTP 透传 cfg *config.Config } @@ -100,14 +101,14 @@ type soraPreflightChecker interface { func NewSoraGatewayService( soraClient SoraClient, - mediaStorage *SoraMediaStorage, rateLimitService *RateLimitService, + httpUpstream HTTPUpstream, cfg *config.Config, ) *SoraGatewayService { return &SoraGatewayService{ soraClient: soraClient, - mediaStorage: mediaStorage, rateLimitService: rateLimitService, + httpUpstream: httpUpstream, cfg: cfg, } } @@ -115,6 +116,15 @@ func NewSoraGatewayService( func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte, clientStream bool) (*ForwardResult, error) { startTime := time.Now() + // apikey 类型账号:HTTP 透传到上游,不走 SoraSDKClient + if account.Type == AccountTypeAPIKey && account.GetBaseURL() != "" { + if s.httpUpstream == nil { + s.writeSoraError(c, http.StatusInternalServerError, "api_error", "HTTP upstream client not configured", clientStream) + return nil, errors.New("httpUpstream not configured for sora apikey forwarding") + } + return s.forwardToUpstream(ctx, c, account, body, clientStream, startTime) + } + if s.soraClient == nil || !s.soraClient.Enabled() { if c != nil { c.JSON(http.StatusServiceUnavailable, gin.H{ @@ -296,6 +306,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun taskID := "" var err error + videoCount := parseSoraVideoCount(reqBody) switch modelCfg.Type { case 
"image": taskID, err = s.soraClient.CreateImageTask(reqCtx, account, SoraImageRequest{ @@ -321,6 +332,7 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun Frames: modelCfg.Frames, Model: modelCfg.Model, Size: modelCfg.Size, + VideoCount: videoCount, MediaID: mediaID, RemixTargetID: remixTargetID, CameoIDs: extractSoraCameoIDs(reqBody), @@ -378,16 +390,9 @@ func (s *SoraGatewayService) Forward(ctx context.Context, c *gin.Context, accoun } } + // 直调路径(/sora/v1/chat/completions)保持纯透传,不执行本地/S3 媒体落盘。 + // 媒体存储由客户端 API 路径(/api/v1/sora/generate)的异步流程负责。 finalURLs := s.normalizeSoraMediaURLs(mediaURLs) - if len(mediaURLs) > 0 && s.mediaStorage != nil && s.mediaStorage.Enabled() { - stored, storeErr := s.mediaStorage.StoreFromURLs(reqCtx, mediaType, mediaURLs) - if storeErr != nil { - // 存储失败时降级使用原始 URL,不中断用户请求 - log.Printf("[Sora] StoreFromURLs failed, falling back to original URLs: %v", storeErr) - } else { - finalURLs = s.normalizeSoraMediaURLs(stored) - } - } if watermarkPostID != "" && watermarkOpts.DeletePost { if deleteErr := s.soraClient.DeletePost(reqCtx, account, watermarkPostID); deleteErr != nil { log.Printf("[Sora] delete post failed, post_id=%s err=%v", watermarkPostID, deleteErr) @@ -463,6 +468,20 @@ func parseSoraCharacterOptions(body map[string]any) soraCharacterOptions { } } +func parseSoraVideoCount(body map[string]any) int { + if body == nil { + return 1 + } + keys := []string{"video_count", "videos", "n_variants"} + for _, key := range keys { + count := parseIntWithDefault(body, key, 0) + if count > 0 { + return clampInt(count, 1, 3) + } + } + return 1 +} + func parseBoolWithDefault(body map[string]any, key string, def bool) bool { if body == nil { return def @@ -508,6 +527,42 @@ func parseStringWithDefault(body map[string]any, key, def string) string { return def } +func parseIntWithDefault(body map[string]any, key string, def int) int { + if body == nil { + return def + } + val, ok := body[key] + if !ok { + return def 
+ } + switch typed := val.(type) { + case int: + return typed + case int32: + return int(typed) + case int64: + return int(typed) + case float64: + return int(typed) + case string: + parsed, err := strconv.Atoi(strings.TrimSpace(typed)) + if err == nil { + return parsed + } + } + return def +} + +func clampInt(v, minVal, maxVal int) int { + if v < minVal { + return minVal + } + if v > maxVal { + return maxVal + } + return v +} + func extractSoraCameoIDs(body map[string]any) []string { if body == nil { return nil @@ -904,6 +959,21 @@ func (s *SoraGatewayService) handleSoraRequestError(ctx context.Context, account } var upstreamErr *SoraUpstreamError if errors.As(err, &upstreamErr) { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora", + "[SoraRawError] account_id=%d model=%s status=%d request_id=%s cf_ray=%s message=%s raw_body=%s", + accountID, + model, + upstreamErr.StatusCode, + strings.TrimSpace(upstreamErr.Headers.Get("x-request-id")), + strings.TrimSpace(upstreamErr.Headers.Get("cf-ray")), + strings.TrimSpace(upstreamErr.Message), + truncateForLog(upstreamErr.Body, 1024), + ) if s.rateLimitService != nil && account != nil { s.rateLimitService.HandleUpstreamError(ctx, account, upstreamErr.StatusCode, upstreamErr.Headers, upstreamErr.Body) } diff --git a/backend/internal/service/sora_gateway_service_test.go b/backend/internal/service/sora_gateway_service_test.go index 5888fe92..206636ff 100644 --- a/backend/internal/service/sora_gateway_service_test.go +++ b/backend/internal/service/sora_gateway_service_test.go @@ -179,6 +179,31 @@ func TestSoraGatewayService_ForwardStoryboardPrompt(t *testing.T) { require.True(t, client.storyboard) } +func TestSoraGatewayService_ForwardVideoCount(t *testing.T) { + client := &stubSoraClientForPoll{ + videoStatus: &SoraVideoTaskStatus{ + Status: "completed", + URLs: []string{"https://example.com/v.mp4"}, + }, + } + cfg := &config.Config{ + Sora: config.SoraConfig{ + 
Client: config.SoraClientConfig{ + PollIntervalSeconds: 1, + MaxPollAttempts: 1, + }, + }, + } + svc := NewSoraGatewayService(client, nil, nil, cfg) + account := &Account{ID: 1, Platform: PlatformSora, Status: StatusActive} + body := []byte(`{"model":"sora2-landscape-10s","messages":[{"role":"user","content":"cat running"}],"video_count":3,"stream":false}`) + + result, err := svc.Forward(context.Background(), nil, account, body, false) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 3, client.videoReq.VideoCount) +} + func TestSoraGatewayService_ForwardCharacterOnly(t *testing.T) { client := &stubSoraClientForPoll{} cfg := &config.Config{ @@ -524,3 +549,10 @@ func TestParseSoraWatermarkOptions_NumericBool(t *testing.T) { require.True(t, opts.Enabled) require.False(t, opts.FallbackOnFailure) } + +func TestParseSoraVideoCount(t *testing.T) { + require.Equal(t, 1, parseSoraVideoCount(nil)) + require.Equal(t, 2, parseSoraVideoCount(map[string]any{"video_count": float64(2)})) + require.Equal(t, 3, parseSoraVideoCount(map[string]any{"videos": "5"})) + require.Equal(t, 1, parseSoraVideoCount(map[string]any{"n_variants": 0})) +} diff --git a/backend/internal/service/sora_generation.go b/backend/internal/service/sora_generation.go new file mode 100644 index 00000000..a704454b --- /dev/null +++ b/backend/internal/service/sora_generation.go @@ -0,0 +1,63 @@ +package service + +import ( + "context" + "time" +) + +// SoraGeneration 代表一条 Sora 客户端生成记录。 +type SoraGeneration struct { + ID int64 `json:"id"` + UserID int64 `json:"user_id"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + Model string `json:"model"` + Prompt string `json:"prompt"` + MediaType string `json:"media_type"` // video / image + Status string `json:"status"` // pending / generating / completed / failed / cancelled + MediaURL string `json:"media_url"` // 主媒体 URL(预签名或 CDN) + MediaURLs []string `json:"media_urls"` // 多图时的 URL 数组 + FileSizeBytes int64 `json:"file_size_bytes"` + 
StorageType string `json:"storage_type"` // s3 / local / upstream / none + S3ObjectKeys []string `json:"s3_object_keys"` // S3 object key 数组 + UpstreamTaskID string `json:"upstream_task_id"` + ErrorMessage string `json:"error_message"` + CreatedAt time.Time `json:"created_at"` + CompletedAt *time.Time `json:"completed_at,omitempty"` +} + +// Sora 生成记录状态常量 +const ( + SoraGenStatusPending = "pending" + SoraGenStatusGenerating = "generating" + SoraGenStatusCompleted = "completed" + SoraGenStatusFailed = "failed" + SoraGenStatusCancelled = "cancelled" +) + +// Sora 存储类型常量 +const ( + SoraStorageTypeS3 = "s3" + SoraStorageTypeLocal = "local" + SoraStorageTypeUpstream = "upstream" + SoraStorageTypeNone = "none" +) + +// SoraGenerationListParams 查询生成记录的参数。 +type SoraGenerationListParams struct { + UserID int64 + Status string // 可选筛选 + StorageType string // 可选筛选 + MediaType string // 可选筛选 + Page int + PageSize int +} + +// SoraGenerationRepository 生成记录持久化接口。 +type SoraGenerationRepository interface { + Create(ctx context.Context, gen *SoraGeneration) error + GetByID(ctx context.Context, id int64) (*SoraGeneration, error) + Update(ctx context.Context, gen *SoraGeneration) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) + CountByUserAndStatus(ctx context.Context, userID int64, statuses []string) (int64, error) +} diff --git a/backend/internal/service/sora_generation_service.go b/backend/internal/service/sora_generation_service.go new file mode 100644 index 00000000..22d5b519 --- /dev/null +++ b/backend/internal/service/sora_generation_service.go @@ -0,0 +1,332 @@ +package service + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +var ( + // ErrSoraGenerationConcurrencyLimit 表示用户进行中的任务数超限。 + ErrSoraGenerationConcurrencyLimit = errors.New("sora generation concurrent limit exceeded") + // 
ErrSoraGenerationStateConflict 表示状态已发生变化(例如任务已取消)。 + ErrSoraGenerationStateConflict = errors.New("sora generation state conflict") + // ErrSoraGenerationNotActive 表示任务不在可取消状态。 + ErrSoraGenerationNotActive = errors.New("sora generation is not active") +) + +const soraGenerationActiveLimit = 3 + +type soraGenerationRepoAtomicCreator interface { + CreatePendingWithLimit(ctx context.Context, gen *SoraGeneration, activeStatuses []string, maxActive int64) error +} + +type soraGenerationRepoConditionalUpdater interface { + UpdateGeneratingIfPending(ctx context.Context, id int64, upstreamTaskID string) (bool, error) + UpdateCompletedIfActive(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64, completedAt time.Time) (bool, error) + UpdateFailedIfActive(ctx context.Context, id int64, errMsg string, completedAt time.Time) (bool, error) + UpdateCancelledIfActive(ctx context.Context, id int64, completedAt time.Time) (bool, error) + UpdateStorageIfCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) (bool, error) +} + +// SoraGenerationService 管理 Sora 客户端的生成记录 CRUD。 +type SoraGenerationService struct { + genRepo SoraGenerationRepository + s3Storage *SoraS3Storage + quotaService *SoraQuotaService +} + +// NewSoraGenerationService 创建生成记录服务。 +func NewSoraGenerationService( + genRepo SoraGenerationRepository, + s3Storage *SoraS3Storage, + quotaService *SoraQuotaService, +) *SoraGenerationService { + return &SoraGenerationService{ + genRepo: genRepo, + s3Storage: s3Storage, + quotaService: quotaService, + } +} + +// CreatePending 创建一条 pending 状态的生成记录。 +func (s *SoraGenerationService) CreatePending(ctx context.Context, userID int64, apiKeyID *int64, model, prompt, mediaType string) (*SoraGeneration, error) { + gen := &SoraGeneration{ + UserID: userID, + APIKeyID: apiKeyID, + Model: model, + Prompt: prompt, + MediaType: mediaType, 
+ Status: SoraGenStatusPending, + StorageType: SoraStorageTypeNone, + } + if atomicCreator, ok := s.genRepo.(soraGenerationRepoAtomicCreator); ok { + if err := atomicCreator.CreatePendingWithLimit( + ctx, + gen, + []string{SoraGenStatusPending, SoraGenStatusGenerating}, + soraGenerationActiveLimit, + ); err != nil { + if errors.Is(err, ErrSoraGenerationConcurrencyLimit) { + return nil, err + } + return nil, fmt.Errorf("create generation: %w", err) + } + logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model) + return gen, nil + } + + if err := s.genRepo.Create(ctx, gen); err != nil { + return nil, fmt.Errorf("create generation: %w", err) + } + logger.LegacyPrintf("service.sora_gen", "[SoraGen] 创建记录 id=%d user=%d model=%s", gen.ID, userID, model) + return gen, nil +} + +// MarkGenerating 标记为生成中。 +func (s *SoraGenerationService) MarkGenerating(ctx context.Context, id int64, upstreamTaskID string) error { + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateGeneratingIfPending(ctx, id, upstreamTaskID) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending { + return ErrSoraGenerationStateConflict + } + gen.Status = SoraGenStatusGenerating + gen.UpstreamTaskID = upstreamTaskID + return s.genRepo.Update(ctx, gen) +} + +// MarkCompleted 标记为已完成。 +func (s *SoraGenerationService) MarkCompleted(ctx context.Context, id int64, mediaURL string, mediaURLs []string, storageType string, s3Keys []string, fileSizeBytes int64) error { + now := time.Now() + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateCompletedIfActive(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes, now) + if err != nil { + return err + } + if !updated { + return 
ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating { + return ErrSoraGenerationStateConflict + } + gen.Status = SoraGenStatusCompleted + gen.MediaURL = mediaURL + gen.MediaURLs = mediaURLs + gen.StorageType = storageType + gen.S3ObjectKeys = s3Keys + gen.FileSizeBytes = fileSizeBytes + gen.CompletedAt = &now + return s.genRepo.Update(ctx, gen) +} + +// MarkFailed 标记为失败。 +func (s *SoraGenerationService) MarkFailed(ctx context.Context, id int64, errMsg string) error { + now := time.Now() + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateFailedIfActive(ctx, id, errMsg, now) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating { + return ErrSoraGenerationStateConflict + } + gen.Status = SoraGenStatusFailed + gen.ErrorMessage = errMsg + gen.CompletedAt = &now + return s.genRepo.Update(ctx, gen) +} + +// MarkCancelled 标记为已取消。 +func (s *SoraGenerationService) MarkCancelled(ctx context.Context, id int64) error { + now := time.Now() + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateCancelledIfActive(ctx, id, now) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationNotActive + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusPending && gen.Status != SoraGenStatusGenerating { + return ErrSoraGenerationNotActive + } + gen.Status = SoraGenStatusCancelled + gen.CompletedAt = &now + return s.genRepo.Update(ctx, gen) +} + +// UpdateStorageForCompleted 更新已完成记录的存储信息(不重置 completed_at)。 +func (s 
*SoraGenerationService) UpdateStorageForCompleted( + ctx context.Context, + id int64, + mediaURL string, + mediaURLs []string, + storageType string, + s3Keys []string, + fileSizeBytes int64, +) error { + if updater, ok := s.genRepo.(soraGenerationRepoConditionalUpdater); ok { + updated, err := updater.UpdateStorageIfCompleted(ctx, id, mediaURL, mediaURLs, storageType, s3Keys, fileSizeBytes) + if err != nil { + return err + } + if !updated { + return ErrSoraGenerationStateConflict + } + return nil + } + + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.Status != SoraGenStatusCompleted { + return ErrSoraGenerationStateConflict + } + gen.MediaURL = mediaURL + gen.MediaURLs = mediaURLs + gen.StorageType = storageType + gen.S3ObjectKeys = s3Keys + gen.FileSizeBytes = fileSizeBytes + return s.genRepo.Update(ctx, gen) +} + +// GetByID 获取记录详情(含权限校验)。 +func (s *SoraGenerationService) GetByID(ctx context.Context, id, userID int64) (*SoraGeneration, error) { + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + if gen.UserID != userID { + return nil, fmt.Errorf("无权访问此生成记录") + } + return gen, nil +} + +// List 查询生成记录列表(分页 + 筛选)。 +func (s *SoraGenerationService) List(ctx context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) { + if params.Page <= 0 { + params.Page = 1 + } + if params.PageSize <= 0 { + params.PageSize = 20 + } + if params.PageSize > 100 { + params.PageSize = 100 + } + return s.genRepo.List(ctx, params) +} + +// Delete 删除记录(联动 S3/本地文件清理 + 配额释放)。 +func (s *SoraGenerationService) Delete(ctx context.Context, id, userID int64) error { + gen, err := s.genRepo.GetByID(ctx, id) + if err != nil { + return err + } + if gen.UserID != userID { + return fmt.Errorf("无权删除此生成记录") + } + + // 清理 S3 文件 + if gen.StorageType == SoraStorageTypeS3 && len(gen.S3ObjectKeys) > 0 && s.s3Storage != nil { + if err := s.s3Storage.DeleteObjects(ctx, gen.S3ObjectKeys); err != nil { + 
logger.LegacyPrintf("service.sora_gen", "[SoraGen] S3 清理失败 id=%d err=%v", id, err) + } + } + + // 释放配额(S3/本地均释放) + if gen.FileSizeBytes > 0 && (gen.StorageType == SoraStorageTypeS3 || gen.StorageType == SoraStorageTypeLocal) && s.quotaService != nil { + if err := s.quotaService.ReleaseUsage(ctx, userID, gen.FileSizeBytes); err != nil { + logger.LegacyPrintf("service.sora_gen", "[SoraGen] 配额释放失败 id=%d err=%v", id, err) + } + } + + return s.genRepo.Delete(ctx, id) +} + +// CountActiveByUser 统计用户进行中的任务数(用于并发限制)。 +func (s *SoraGenerationService) CountActiveByUser(ctx context.Context, userID int64) (int64, error) { + return s.genRepo.CountByUserAndStatus(ctx, userID, []string{SoraGenStatusPending, SoraGenStatusGenerating}) +} + +// ResolveMediaURLs 为 S3 记录动态生成预签名 URL。 +func (s *SoraGenerationService) ResolveMediaURLs(ctx context.Context, gen *SoraGeneration) error { + if gen == nil || gen.StorageType != SoraStorageTypeS3 || s.s3Storage == nil { + return nil + } + if len(gen.S3ObjectKeys) == 0 { + return nil + } + + urls := make([]string, len(gen.S3ObjectKeys)) + var wg sync.WaitGroup + var firstErr error + var errMu sync.Mutex + + for idx, key := range gen.S3ObjectKeys { + wg.Add(1) + go func(i int, objectKey string) { + defer wg.Done() + url, err := s.s3Storage.GetAccessURL(ctx, objectKey) + if err != nil { + errMu.Lock() + if firstErr == nil { + firstErr = err + } + errMu.Unlock() + return + } + urls[i] = url + }(idx, key) + } + wg.Wait() + if firstErr != nil { + return firstErr + } + + gen.MediaURL = urls[0] + gen.MediaURLs = urls + + return nil +} diff --git a/backend/internal/service/sora_generation_service_test.go b/backend/internal/service/sora_generation_service_test.go new file mode 100644 index 00000000..820945f0 --- /dev/null +++ b/backend/internal/service/sora_generation_service_test.go @@ -0,0 +1,875 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + 
"github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/stretchr/testify/require" +) + +// ==================== Stub: SoraGenerationRepository ==================== + +var _ SoraGenerationRepository = (*stubGenRepo)(nil) + +type stubGenRepo struct { + gens map[int64]*SoraGeneration + nextID int64 + createErr error + getErr error + updateErr error + deleteErr error + listErr error + countErr error + countValue int64 +} + +func newStubGenRepo() *stubGenRepo { + return &stubGenRepo{gens: make(map[int64]*SoraGeneration), nextID: 1} +} + +func (r *stubGenRepo) Create(_ context.Context, gen *SoraGeneration) error { + if r.createErr != nil { + return r.createErr + } + gen.ID = r.nextID + gen.CreatedAt = time.Now() + r.nextID++ + r.gens[gen.ID] = gen + return nil +} + +func (r *stubGenRepo) GetByID(_ context.Context, id int64) (*SoraGeneration, error) { + if r.getErr != nil { + return nil, r.getErr + } + if gen, ok := r.gens[id]; ok { + return gen, nil + } + return nil, fmt.Errorf("not found") +} + +func (r *stubGenRepo) Update(_ context.Context, gen *SoraGeneration) error { + if r.updateErr != nil { + return r.updateErr + } + r.gens[gen.ID] = gen + return nil +} + +func (r *stubGenRepo) Delete(_ context.Context, id int64) error { + if r.deleteErr != nil { + return r.deleteErr + } + delete(r.gens, id) + return nil +} + +func (r *stubGenRepo) List(_ context.Context, params SoraGenerationListParams) ([]*SoraGeneration, int64, error) { + if r.listErr != nil { + return nil, 0, r.listErr + } + var result []*SoraGeneration + for _, gen := range r.gens { + if gen.UserID != params.UserID { + continue + } + if params.Status != "" && gen.Status != params.Status { + continue + } + if params.StorageType != "" && gen.StorageType != params.StorageType { + continue + } + if params.MediaType != "" && gen.MediaType != params.MediaType { + continue + } + result = append(result, gen) + } + return result, int64(len(result)), nil +} + +func (r *stubGenRepo) CountByUserAndStatus(_ 
context.Context, userID int64, statuses []string) (int64, error) { + if r.countErr != nil { + return 0, r.countErr + } + if r.countValue > 0 { + return r.countValue, nil + } + var count int64 + statusSet := make(map[string]struct{}) + for _, s := range statuses { + statusSet[s] = struct{}{} + } + for _, gen := range r.gens { + if gen.UserID == userID { + if _, ok := statusSet[gen.Status]; ok { + count++ + } + } + } + return count, nil +} + +// ==================== Stub: UserRepository (用于 SoraQuotaService) ==================== + +var _ UserRepository = (*stubUserRepoForQuota)(nil) + +type stubUserRepoForQuota struct { + users map[int64]*User + updateErr error +} + +func newStubUserRepoForQuota() *stubUserRepoForQuota { + return &stubUserRepoForQuota{users: make(map[int64]*User)} +} + +func (r *stubUserRepoForQuota) GetByID(_ context.Context, id int64) (*User, error) { + if u, ok := r.users[id]; ok { + return u, nil + } + return nil, fmt.Errorf("user not found") +} +func (r *stubUserRepoForQuota) Update(_ context.Context, user *User) error { + if r.updateErr != nil { + return r.updateErr + } + r.users[user.ID] = user + return nil +} +func (r *stubUserRepoForQuota) Create(context.Context, *User) error { return nil } +func (r *stubUserRepoForQuota) GetByEmail(context.Context, string) (*User, error) { + return nil, nil +} +func (r *stubUserRepoForQuota) GetFirstAdmin(context.Context) (*User, error) { return nil, nil } +func (r *stubUserRepoForQuota) Delete(context.Context, int64) error { return nil } +func (r *stubUserRepoForQuota) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubUserRepoForQuota) UpdateBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForQuota) 
DeductBalance(context.Context, int64, float64) error { return nil } +func (r *stubUserRepoForQuota) UpdateConcurrency(context.Context, int64, int) error { return nil } +func (r *stubUserRepoForQuota) ExistsByEmail(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubUserRepoForQuota) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubUserRepoForQuota) UpdateTotpSecret(context.Context, int64, *string) error { return nil } +func (r *stubUserRepoForQuota) EnableTotp(context.Context, int64) error { return nil } +func (r *stubUserRepoForQuota) DisableTotp(context.Context, int64) error { return nil } + +// ==================== 辅助函数:构造带 CDN 缓存的 SoraS3Storage ==================== + +// newS3StorageWithCDN 创建一个预缓存了 CDN 配置的 SoraS3Storage, +// 避免实际初始化 AWS 客户端。用于测试 GetAccessURL 的 CDN 路径。 +func newS3StorageWithCDN(cdnURL string) *SoraS3Storage { + storage := &SoraS3Storage{} + storage.cfg = &SoraS3Settings{ + Enabled: true, + Bucket: "test-bucket", + CDNURL: cdnURL, + } + // 需要 non-nil client 使 getClient 命中缓存 + storage.client = s3.New(s3.Options{}) + return storage +} + +// newS3StorageFailingDelete 创建一个 settingService=nil 的 SoraS3Storage, +// 使 DeleteObjects 返回错误(无法获取配置)。用于测试 Delete 方法 S3 清理失败但仍继续的场景。 +func newS3StorageFailingDelete() *SoraS3Storage { + return &SoraS3Storage{} // settingService 为 nil → getConfig 返回 error +} + +// ==================== CreatePending ==================== + +func TestCreatePending_Success(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "一只猫跳舞", "video") + require.NoError(t, err) + require.Equal(t, int64(1), gen.ID) + require.Equal(t, int64(1), gen.UserID) + require.Equal(t, "sora2-landscape-10s", gen.Model) + require.Equal(t, "一只猫跳舞", gen.Prompt) + require.Equal(t, "video", gen.MediaType) + require.Equal(t, SoraGenStatusPending, gen.Status) + 
require.Equal(t, SoraStorageTypeNone, gen.StorageType) + require.Nil(t, gen.APIKeyID) +} + +func TestCreatePending_WithAPIKeyID(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + apiKeyID := int64(42) + gen, err := svc.CreatePending(context.Background(), 1, &apiKeyID, "gpt-image", "画一朵花", "image") + require.NoError(t, err) + require.NotNil(t, gen.APIKeyID) + require.Equal(t, int64(42), *gen.APIKeyID) +} + +func TestCreatePending_RepoError(t *testing.T) { + repo := newStubGenRepo() + repo.createErr = fmt.Errorf("db write error") + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + require.Error(t, err) + require.Nil(t, gen) + require.Contains(t, err.Error(), "create generation") +} + +// ==================== MarkGenerating ==================== + +func TestMarkGenerating_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkGenerating(context.Background(), 1, "upstream-task-123") + require.NoError(t, err) + require.Equal(t, SoraGenStatusGenerating, repo.gens[1].Status) + require.Equal(t, "upstream-task-123", repo.gens[1].UpstreamTaskID) +} + +func TestMarkGenerating_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkGenerating(context.Background(), 999, "") + require.Error(t, err) +} + +func TestMarkGenerating_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkGenerating(context.Background(), 1, "") + require.Error(t, err) +} + +// ==================== MarkCompleted ==================== + +func 
TestMarkCompleted_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCompleted(context.Background(), 1, + "https://cdn.example.com/video.mp4", + []string{"https://cdn.example.com/video.mp4"}, + SoraStorageTypeS3, + []string{"sora/1/2024/01/01/uuid.mp4"}, + 1048576, + ) + require.NoError(t, err) + gen := repo.gens[1] + require.Equal(t, SoraGenStatusCompleted, gen.Status) + require.Equal(t, "https://cdn.example.com/video.mp4", gen.MediaURL) + require.Equal(t, []string{"https://cdn.example.com/video.mp4"}, gen.MediaURLs) + require.Equal(t, SoraStorageTypeS3, gen.StorageType) + require.Equal(t, []string{"sora/1/2024/01/01/uuid.mp4"}, gen.S3ObjectKeys) + require.Equal(t, int64(1048576), gen.FileSizeBytes) + require.NotNil(t, gen.CompletedAt) +} + +func TestMarkCompleted_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCompleted(context.Background(), 999, "", nil, "", nil, 0) + require.Error(t, err) +} + +func TestMarkCompleted_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCompleted(context.Background(), 1, "url", nil, SoraStorageTypeUpstream, nil, 0) + require.Error(t, err) +} + +// ==================== MarkFailed ==================== + +func TestMarkFailed_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkFailed(context.Background(), 1, "上游返回 500 错误") + require.NoError(t, err) + gen := repo.gens[1] + require.Equal(t, SoraGenStatusFailed, gen.Status) + require.Equal(t, "上游返回 500 错误", 
gen.ErrorMessage) + require.NotNil(t, gen.CompletedAt) +} + +func TestMarkFailed_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkFailed(context.Background(), 999, "error") + require.Error(t, err) +} + +func TestMarkFailed_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkFailed(context.Background(), 1, "err") + require.Error(t, err) +} + +// ==================== MarkCancelled ==================== + +func TestMarkCancelled_Pending(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status) + require.NotNil(t, repo.gens[1].CompletedAt) +} + +func TestMarkCancelled_Generating(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusGenerating} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCancelled, repo.gens[1].Status) +} + +func TestMarkCancelled_Completed(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) + require.ErrorIs(t, err, ErrSoraGenerationNotActive) +} + +func TestMarkCancelled_Failed(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusFailed} + svc := NewSoraGenerationService(repo, nil, nil) + + 
err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) +} + +func TestMarkCancelled_AlreadyCancelled(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCancelled} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) +} + +func TestMarkCancelled_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 999) + require.Error(t, err) +} + +func TestMarkCancelled_UpdateError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + repo.updateErr = fmt.Errorf("update failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.MarkCancelled(context.Background(), 1) + require.Error(t, err) +} + +// ==================== GetByID ==================== + +func TestGetByID_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, Model: "sora2-landscape-10s"} + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.GetByID(context.Background(), 1, 1) + require.NoError(t, err) + require.Equal(t, int64(1), gen.ID) + require.Equal(t, "sora2-landscape-10s", gen.Model) +} + +func TestGetByID_WrongUser(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.GetByID(context.Background(), 1, 1) + require.Error(t, err) + require.Nil(t, gen) + require.Contains(t, err.Error(), "无权访问") +} + +func TestGetByID_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, err := svc.GetByID(context.Background(), 999, 1) + require.Error(t, err) + require.Nil(t, gen) +} + +// 
==================== List ==================== + +func TestList_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, MediaType: "video"} + repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusPending, MediaType: "image"} + repo.gens[3] = &SoraGeneration{ID: 3, UserID: 2, Status: SoraGenStatusCompleted, MediaType: "video"} + svc := NewSoraGenerationService(repo, nil, nil) + + gens, total, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 20}) + require.NoError(t, err) + require.Len(t, gens, 2) // 只有 userID=1 的 + require.Equal(t, int64(2), total) +} + +func TestList_DefaultPagination(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + // page=0, pageSize=0 → 应修正为 page=1, pageSize=20 + _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1}) + require.NoError(t, err) +} + +func TestList_MaxPageSize(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + // pageSize > 100 → 应限制为 100 + _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1, Page: 1, PageSize: 200}) + require.NoError(t, err) +} + +func TestList_Error(t *testing.T) { + repo := newStubGenRepo() + repo.listErr = fmt.Errorf("db error") + svc := NewSoraGenerationService(repo, nil, nil) + + _, _, err := svc.List(context.Background(), SoraGenerationListParams{UserID: 1}) + require.Error(t, err) +} + +// ==================== Delete ==================== + +func TestDelete_Success(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted, StorageType: SoraStorageTypeUpstream} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + _, exists := repo.gens[1] + require.False(t, exists) +} + +func 
TestDelete_WrongUser(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 2, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "无权删除") +} + +func TestDelete_NotFound(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 999, 1) + require.Error(t, err) +} + +func TestDelete_S3Cleanup_NilS3(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) // s3Storage 为 nil,跳过清理 +} + +func TestDelete_QuotaRelease_NilQuota(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeS3, FileSizeBytes: 1024} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) // quotaService 为 nil,跳过释放 +} + +func TestDelete_NonS3NoCleanup(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeLocal, FileSizeBytes: 1024} + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) +} + +func TestDelete_DeleteRepoError(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, StorageType: SoraStorageTypeUpstream} + repo.deleteErr = fmt.Errorf("delete failed") + svc := NewSoraGenerationService(repo, nil, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.Error(t, err) +} + +// ==================== CountActiveByUser ==================== + +func TestCountActiveByUser_Success(t *testing.T) { + repo := 
newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusPending} + repo.gens[2] = &SoraGeneration{ID: 2, UserID: 1, Status: SoraGenStatusGenerating} + repo.gens[3] = &SoraGeneration{ID: 3, UserID: 1, Status: SoraGenStatusCompleted} // 不算 + svc := NewSoraGenerationService(repo, nil, nil) + + count, err := svc.CountActiveByUser(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(2), count) +} + +func TestCountActiveByUser_NoActive(t *testing.T) { + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ID: 1, UserID: 1, Status: SoraGenStatusCompleted} + svc := NewSoraGenerationService(repo, nil, nil) + + count, err := svc.CountActiveByUser(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(0), count) +} + +func TestCountActiveByUser_Error(t *testing.T) { + repo := newStubGenRepo() + repo.countErr = fmt.Errorf("db error") + svc := NewSoraGenerationService(repo, nil, nil) + + _, err := svc.CountActiveByUser(context.Background(), 1) + require.Error(t, err) +} + +// ==================== ResolveMediaURLs ==================== + +func TestResolveMediaURLs_NilGen(t *testing.T) { + svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + require.NoError(t, svc.ResolveMediaURLs(context.Background(), nil)) +} + +func TestResolveMediaURLs_NonS3(t *testing.T) { + svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + gen := &SoraGeneration{StorageType: SoraStorageTypeUpstream, MediaURL: "https://original.com/v.mp4"} + require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen)) + require.Equal(t, "https://original.com/v.mp4", gen.MediaURL) // 不变 +} + +func TestResolveMediaURLs_S3NilStorage(t *testing.T) { + svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + gen := &SoraGeneration{StorageType: SoraStorageTypeS3, S3ObjectKeys: []string{"key1"}} + require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen)) +} + +func TestResolveMediaURLs_Local(t *testing.T) { + 
svc := NewSoraGenerationService(newStubGenRepo(), nil, nil) + gen := &SoraGeneration{StorageType: SoraStorageTypeLocal, MediaURL: "/video/2024/01/01/file.mp4"} + require.NoError(t, svc.ResolveMediaURLs(context.Background(), gen)) + require.Equal(t, "/video/2024/01/01/file.mp4", gen.MediaURL) // 不变 +} + +// ==================== 状态流转完整测试 ==================== + +func TestStatusTransition_PendingToCompletedFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + // 1. 创建 pending + gen, err := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + require.NoError(t, err) + require.Equal(t, SoraGenStatusPending, gen.Status) + + // 2. 标记 generating + err = svc.MarkGenerating(context.Background(), gen.ID, "task-123") + require.NoError(t, err) + require.Equal(t, SoraGenStatusGenerating, repo.gens[gen.ID].Status) + + // 3. 标记 completed + err = svc.MarkCompleted(context.Background(), gen.ID, "https://s3.com/video.mp4", nil, SoraStorageTypeS3, []string{"key"}, 1024) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCompleted, repo.gens[gen.ID].Status) +} + +func TestStatusTransition_PendingToFailedFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + _ = svc.MarkGenerating(context.Background(), gen.ID, "") + + err := svc.MarkFailed(context.Background(), gen.ID, "上游超时") + require.NoError(t, err) + require.Equal(t, SoraGenStatusFailed, repo.gens[gen.ID].Status) + require.Equal(t, "上游超时", repo.gens[gen.ID].ErrorMessage) +} + +func TestStatusTransition_PendingToCancelledFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + err := svc.MarkCancelled(context.Background(), gen.ID) + require.NoError(t, err) + 
require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status) +} + +func TestStatusTransition_GeneratingToCancelledFlow(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + _ = svc.MarkGenerating(context.Background(), gen.ID, "") + err := svc.MarkCancelled(context.Background(), gen.ID) + require.NoError(t, err) + require.Equal(t, SoraGenStatusCancelled, repo.gens[gen.ID].Status) +} + +// ==================== 权限隔离测试 ==================== + +func TestUserIsolation_CannotAccessOthersRecord(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + + // 用户 2 尝试访问用户 1 的记录 + _, err := svc.GetByID(context.Background(), gen.ID, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "无权访问") +} + +func TestUserIsolation_CannotDeleteOthersRecord(t *testing.T) { + repo := newStubGenRepo() + svc := NewSoraGenerationService(repo, nil, nil) + + gen, _ := svc.CreatePending(context.Background(), 1, nil, "sora2-landscape-10s", "test", "video") + + err := svc.Delete(context.Background(), gen.ID, 2) + require.Error(t, err) + require.Contains(t, err.Error(), "无权删除") +} + +// ==================== Delete: S3 清理 + 配额释放路径 ==================== + +func TestDelete_S3Cleanup_WithS3Storage(t *testing.T) { + // S3 存储存在但 deleteObjects 会失败(settingService=nil), + // 验证 Delete 仍然成功(S3 错误只是记录日志) + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"sora/1/2024/01/01/abc.mp4"}, + } + s3Storage := newS3StorageFailingDelete() + svc := NewSoraGenerationService(repo, s3Storage, nil) + + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) // S3 清理失败不影响删除 + _, exists := repo.gens[1] + require.False(t, exists) +} + 
+func TestDelete_QuotaRelease_WithQuotaService(t *testing.T) { + // 有配额服务时,删除 S3 类型记录会释放配额 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + FileSizeBytes: 1048576, // 1MB + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2097152} // 2MB + quotaService := NewSoraQuotaService(userRepo, nil, nil) + + svc := NewSoraGenerationService(repo, nil, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + // 配额应被释放: 2MB - 1MB = 1MB + require.Equal(t, int64(1048576), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestDelete_S3Cleanup_And_QuotaRelease(t *testing.T) { + // S3 清理 + 配额释放同时触发 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"key1"}, + FileSizeBytes: 512, + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + quotaService := NewSoraQuotaService(userRepo, nil, nil) + s3Storage := newS3StorageFailingDelete() + + svc := NewSoraGenerationService(repo, s3Storage, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + _, exists := repo.gens[1] + require.False(t, exists) + require.Equal(t, int64(512), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestDelete_QuotaRelease_LocalStorage(t *testing.T) { + // 本地存储同样需要释放配额 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeLocal, + FileSizeBytes: 1024, + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 2048} + quotaService := NewSoraQuotaService(userRepo, nil, nil) + + svc := NewSoraGenerationService(repo, nil, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) +} + +func 
TestDelete_QuotaRelease_ZeroFileSize(t *testing.T) { + // FileSizeBytes=0 跳过配额释放 + repo := newStubGenRepo() + repo.gens[1] = &SoraGeneration{ + ID: 1, UserID: 1, + StorageType: SoraStorageTypeS3, + FileSizeBytes: 0, + } + + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + quotaService := NewSoraQuotaService(userRepo, nil, nil) + + svc := NewSoraGenerationService(repo, nil, quotaService) + err := svc.Delete(context.Background(), 1, 1) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) +} + +// ==================== ResolveMediaURLs: S3 + CDN 路径 ==================== + +func TestResolveMediaURLs_S3_CDN_SingleKey(t *testing.T) { + s3Storage := newS3StorageWithCDN("https://cdn.example.com") + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"}, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", gen.MediaURL) +} + +func TestResolveMediaURLs_S3_CDN_MultipleKeys(t *testing.T) { + s3Storage := newS3StorageWithCDN("https://cdn.example.com/") + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{ + "sora/1/2024/01/01/img1.png", + "sora/1/2024/01/01/img2.png", + "sora/1/2024/01/01/img3.png", + }, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.NoError(t, err) + // 主 URL 更新为第一个 key 的 CDN URL + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURL) + // 多图 URLs 全部更新 + require.Len(t, gen.MediaURLs, 3) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img1.png", gen.MediaURLs[0]) + require.Equal(t, 
"https://cdn.example.com/sora/1/2024/01/01/img2.png", gen.MediaURLs[1]) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/img3.png", gen.MediaURLs[2]) +} + +func TestResolveMediaURLs_S3_EmptyKeys(t *testing.T) { + s3Storage := newS3StorageWithCDN("https://cdn.example.com") + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{}, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.NoError(t, err) + require.Equal(t, "original", gen.MediaURL) // 不变 +} + +func TestResolveMediaURLs_S3_GetAccessURL_Error(t *testing.T) { + // 使用无 settingService 的 S3 Storage,getClient 会失败 + s3Storage := newS3StorageFailingDelete() // 同样 GetAccessURL 也会失败 + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{"sora/1/2024/01/01/video.mp4"}, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.Error(t, err) // GetAccessURL 失败应传播错误 +} + +func TestResolveMediaURLs_S3_MultiKey_ErrorOnSecond(t *testing.T) { + // 只有一个 key 时走主 URL 路径成功,但多 key 路径的错误也需覆盖 + s3Storage := newS3StorageFailingDelete() + svc := NewSoraGenerationService(newStubGenRepo(), s3Storage, nil) + + gen := &SoraGeneration{ + StorageType: SoraStorageTypeS3, + S3ObjectKeys: []string{ + "sora/1/2024/01/01/img1.png", + "sora/1/2024/01/01/img2.png", + }, + MediaURL: "original", + } + err := svc.ResolveMediaURLs(context.Background(), gen) + require.Error(t, err) // 第一个 key 的 GetAccessURL 就会失败 +} diff --git a/backend/internal/service/sora_media_storage.go b/backend/internal/service/sora_media_storage.go index eb363c4f..18783865 100644 --- a/backend/internal/service/sora_media_storage.go +++ b/backend/internal/service/sora_media_storage.go @@ -157,6 +157,64 @@ func (s *SoraMediaStorage) StoreFromURLs(ctx context.Context, mediaType string, return 
results, nil } +// TotalSizeByRelativePaths 统计本地存储路径总大小(仅统计 /image 和 /video 路径)。 +func (s *SoraMediaStorage) TotalSizeByRelativePaths(paths []string) (int64, error) { + if s == nil || len(paths) == 0 { + return 0, nil + } + var total int64 + for _, p := range paths { + localPath, err := s.resolveLocalPath(p) + if err != nil { + continue + } + info, err := os.Stat(localPath) + if err != nil { + if os.IsNotExist(err) { + continue + } + return 0, err + } + if info.Mode().IsRegular() { + total += info.Size() + } + } + return total, nil +} + +// DeleteByRelativePaths 删除本地媒体路径(仅删除 /image 和 /video 路径)。 +func (s *SoraMediaStorage) DeleteByRelativePaths(paths []string) error { + if s == nil || len(paths) == 0 { + return nil + } + var lastErr error + for _, p := range paths { + localPath, err := s.resolveLocalPath(p) + if err != nil { + continue + } + if err := os.Remove(localPath); err != nil && !os.IsNotExist(err) { + lastErr = err + } + } + return lastErr +} + +func (s *SoraMediaStorage) resolveLocalPath(relativePath string) (string, error) { + if s == nil || strings.TrimSpace(relativePath) == "" { + return "", errors.New("empty path") + } + cleaned := path.Clean(relativePath) + if !strings.HasPrefix(cleaned, "/image/") && !strings.HasPrefix(cleaned, "/video/") { + return "", errors.New("not a local media path") + } + if strings.TrimSpace(s.root) == "" { + return "", errors.New("storage root not configured") + } + relative := strings.TrimPrefix(cleaned, "/") + return filepath.Join(s.root, filepath.FromSlash(relative)), nil +} + func (s *SoraMediaStorage) downloadAndStore(ctx context.Context, mediaType, rawURL string) (string, error) { if strings.TrimSpace(rawURL) == "" { return "", errors.New("empty url") diff --git a/backend/internal/service/sora_models.go b/backend/internal/service/sora_models.go index 80b20a4b..53d4c788 100644 --- a/backend/internal/service/sora_models.go +++ b/backend/internal/service/sora_models.go @@ -1,6 +1,9 @@ package service import ( + "regexp" 
+ "sort" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/config" @@ -247,6 +250,218 @@ func GetSoraModelConfig(model string) (SoraModelConfig, bool) { return cfg, ok } +// SoraModelFamily 模型家族(前端 Sora 客户端使用) +type SoraModelFamily struct { + ID string `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Orientations []string `json:"orientations"` + Durations []int `json:"durations,omitempty"` +} + +var ( + videoSuffixRe = regexp.MustCompile(`-(landscape|portrait)-(\d+)s$`) + imageSuffixRe = regexp.MustCompile(`-(landscape|portrait)$`) + + soraFamilyNames = map[string]string{ + "sora2": "Sora 2", + "sora2pro": "Sora 2 Pro", + "sora2pro-hd": "Sora 2 Pro HD", + "gpt-image": "GPT Image", + } +) + +// BuildSoraModelFamilies 从 soraModelConfigs 自动聚合模型家族及其支持的方向和时长 +func BuildSoraModelFamilies() []SoraModelFamily { + type familyData struct { + modelType string + orientations map[string]bool + durations map[int]bool + } + families := make(map[string]*familyData) + + for id, cfg := range soraModelConfigs { + if cfg.Type == "prompt_enhance" { + continue + } + var famID, orientation string + var duration int + + switch cfg.Type { + case "video": + if m := videoSuffixRe.FindStringSubmatch(id); m != nil { + famID = id[:len(id)-len(m[0])] + orientation = m[1] + duration, _ = strconv.Atoi(m[2]) + } + case "image": + if m := imageSuffixRe.FindStringSubmatch(id); m != nil { + famID = id[:len(id)-len(m[0])] + orientation = m[1] + } else { + famID = id + orientation = "square" + } + } + if famID == "" { + continue + } + + fd, ok := families[famID] + if !ok { + fd = &familyData{ + modelType: cfg.Type, + orientations: make(map[string]bool), + durations: make(map[int]bool), + } + families[famID] = fd + } + if orientation != "" { + fd.orientations[orientation] = true + } + if duration > 0 { + fd.durations[duration] = true + } + } + + // 排序:视频在前、图像在后,同类按名称排序 + famIDs := make([]string, 0, len(families)) + for id := range families { + famIDs = append(famIDs, id) + 
} + sort.Slice(famIDs, func(i, j int) bool { + fi, fj := families[famIDs[i]], families[famIDs[j]] + if fi.modelType != fj.modelType { + return fi.modelType == "video" + } + return famIDs[i] < famIDs[j] + }) + + result := make([]SoraModelFamily, 0, len(famIDs)) + for _, famID := range famIDs { + fd := families[famID] + fam := SoraModelFamily{ + ID: famID, + Name: soraFamilyNames[famID], + Type: fd.modelType, + } + if fam.Name == "" { + fam.Name = famID + } + for o := range fd.orientations { + fam.Orientations = append(fam.Orientations, o) + } + sort.Strings(fam.Orientations) + for d := range fd.durations { + fam.Durations = append(fam.Durations, d) + } + sort.Ints(fam.Durations) + result = append(result, fam) + } + return result +} + +// BuildSoraModelFamiliesFromIDs 从任意模型 ID 列表聚合模型家族(用于解析上游返回的模型列表)。 +// 通过命名约定自动识别视频/图像模型并分组。 +func BuildSoraModelFamiliesFromIDs(modelIDs []string) []SoraModelFamily { + type familyData struct { + modelType string + orientations map[string]bool + durations map[int]bool + } + families := make(map[string]*familyData) + + for _, id := range modelIDs { + id = strings.ToLower(strings.TrimSpace(id)) + if id == "" || strings.HasPrefix(id, "prompt-enhance") { + continue + } + + var famID, orientation, modelType string + var duration int + + if m := videoSuffixRe.FindStringSubmatch(id); m != nil { + // 视频模型: {family}-{orientation}-{duration}s + famID = id[:len(id)-len(m[0])] + orientation = m[1] + duration, _ = strconv.Atoi(m[2]) + modelType = "video" + } else if m := imageSuffixRe.FindStringSubmatch(id); m != nil { + // 图像模型(带方向): {family}-{orientation} + famID = id[:len(id)-len(m[0])] + orientation = m[1] + modelType = "image" + } else if cfg, ok := soraModelConfigs[id]; ok && cfg.Type == "image" { + // 已知的无后缀图像模型(如 gpt-image) + famID = id + orientation = "square" + modelType = "image" + } else if strings.Contains(id, "image") { + // 未知但名称包含 image 的模型,推断为图像模型 + famID = id + orientation = "square" + modelType = "image" + } else { + continue + 
} + + if famID == "" { + continue + } + + fd, ok := families[famID] + if !ok { + fd = &familyData{ + modelType: modelType, + orientations: make(map[string]bool), + durations: make(map[int]bool), + } + families[famID] = fd + } + if orientation != "" { + fd.orientations[orientation] = true + } + if duration > 0 { + fd.durations[duration] = true + } + } + + famIDs := make([]string, 0, len(families)) + for id := range families { + famIDs = append(famIDs, id) + } + sort.Slice(famIDs, func(i, j int) bool { + fi, fj := families[famIDs[i]], families[famIDs[j]] + if fi.modelType != fj.modelType { + return fi.modelType == "video" + } + return famIDs[i] < famIDs[j] + }) + + result := make([]SoraModelFamily, 0, len(famIDs)) + for _, famID := range famIDs { + fd := families[famID] + fam := SoraModelFamily{ + ID: famID, + Name: soraFamilyNames[famID], + Type: fd.modelType, + } + if fam.Name == "" { + fam.Name = famID + } + for o := range fd.orientations { + fam.Orientations = append(fam.Orientations, o) + } + sort.Strings(fam.Orientations) + for d := range fd.durations { + fam.Durations = append(fam.Durations, d) + } + sort.Ints(fam.Durations) + result = append(result, fam) + } + return result +} + // DefaultSoraModels returns the default Sora model list. 
func DefaultSoraModels(cfg *config.Config) []openai.Model { models := make([]openai.Model, 0, len(soraModelIDs)) diff --git a/backend/internal/service/sora_quota_service.go b/backend/internal/service/sora_quota_service.go new file mode 100644 index 00000000..f0843374 --- /dev/null +++ b/backend/internal/service/sora_quota_service.go @@ -0,0 +1,257 @@ +package service + +import ( + "context" + "errors" + "fmt" + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// SoraQuotaService 管理 Sora 用户存储配额。 +// 配额优先级:用户级 → 分组级 → 系统默认值。 +type SoraQuotaService struct { + userRepo UserRepository + groupRepo GroupRepository + settingService *SettingService +} + +// NewSoraQuotaService 创建配额服务实例。 +func NewSoraQuotaService( + userRepo UserRepository, + groupRepo GroupRepository, + settingService *SettingService, +) *SoraQuotaService { + return &SoraQuotaService{ + userRepo: userRepo, + groupRepo: groupRepo, + settingService: settingService, + } +} + +// QuotaInfo 返回给客户端的配额信息。 +type QuotaInfo struct { + QuotaBytes int64 `json:"quota_bytes"` // 总配额(0 表示无限制) + UsedBytes int64 `json:"used_bytes"` // 已使用 + AvailableBytes int64 `json:"available_bytes"` // 剩余可用(无限制时为 0) + QuotaSource string `json:"quota_source"` // 配额来源:user / group / system / unlimited + Source string `json:"source,omitempty"` // 兼容旧字段 +} + +// ErrSoraStorageQuotaExceeded 表示配额不足。 +var ErrSoraStorageQuotaExceeded = errors.New("sora storage quota exceeded") + +// QuotaExceededError 包含配额不足的上下文信息。 +type QuotaExceededError struct { + QuotaBytes int64 + UsedBytes int64 +} + +func (e *QuotaExceededError) Error() string { + if e == nil { + return "存储配额不足" + } + return fmt.Sprintf("存储配额不足(已用 %d / 配额 %d 字节)", e.UsedBytes, e.QuotaBytes) +} + +type soraQuotaAtomicUserRepository interface { + AddSoraStorageUsageWithQuota(ctx context.Context, userID int64, deltaBytes int64, effectiveQuota int64) (int64, error) + ReleaseSoraStorageUsageAtomic(ctx context.Context, userID int64, deltaBytes int64) (int64, error) +} + +// 
GetQuota 获取用户的存储配额信息。 +// 优先级:用户级 > 用户所属分组级 > 系统默认值。 +func (s *SoraQuotaService) GetQuota(ctx context.Context, userID int64) (*QuotaInfo, error) { + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user: %w", err) + } + + info := &QuotaInfo{ + UsedBytes: user.SoraStorageUsedBytes, + } + + // 1. 用户级配额 + if user.SoraStorageQuotaBytes > 0 { + info.QuotaBytes = user.SoraStorageQuotaBytes + info.QuotaSource = "user" + info.Source = info.QuotaSource + info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes) + return info, nil + } + + // 2. 分组级配额(取用户可用分组中最大的配额) + if len(user.AllowedGroups) > 0 { + var maxGroupQuota int64 + for _, gid := range user.AllowedGroups { + group, err := s.groupRepo.GetByID(ctx, gid) + if err != nil { + continue + } + if group.SoraStorageQuotaBytes > maxGroupQuota { + maxGroupQuota = group.SoraStorageQuotaBytes + } + } + if maxGroupQuota > 0 { + info.QuotaBytes = maxGroupQuota + info.QuotaSource = "group" + info.Source = info.QuotaSource + info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes) + return info, nil + } + } + + // 3. 
系统默认值 + defaultQuota := s.getSystemDefaultQuota(ctx) + if defaultQuota > 0 { + info.QuotaBytes = defaultQuota + info.QuotaSource = "system" + info.Source = info.QuotaSource + info.AvailableBytes = calcAvailableBytes(info.QuotaBytes, info.UsedBytes) + return info, nil + } + + // 无配额限制 + info.QuotaSource = "unlimited" + info.Source = info.QuotaSource + info.AvailableBytes = 0 + return info, nil +} + +// CheckQuota 检查用户是否有足够的存储配额。 +// 返回 nil 表示配额充足或无限制。 +func (s *SoraQuotaService) CheckQuota(ctx context.Context, userID int64, additionalBytes int64) error { + quota, err := s.GetQuota(ctx, userID) + if err != nil { + return err + } + // 0 表示无限制 + if quota.QuotaBytes == 0 { + return nil + } + if quota.UsedBytes+additionalBytes > quota.QuotaBytes { + return &QuotaExceededError{ + QuotaBytes: quota.QuotaBytes, + UsedBytes: quota.UsedBytes, + } + } + return nil +} + +// AddUsage 原子累加用量(上传成功后调用)。 +func (s *SoraQuotaService) AddUsage(ctx context.Context, userID int64, bytes int64) error { + if bytes <= 0 { + return nil + } + + quota, err := s.GetQuota(ctx, userID) + if err != nil { + return err + } + + if quota.QuotaBytes > 0 && quota.UsedBytes+bytes > quota.QuotaBytes { + return &QuotaExceededError{ + QuotaBytes: quota.QuotaBytes, + UsedBytes: quota.UsedBytes, + } + } + + if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok { + newUsed, err := repo.AddSoraStorageUsageWithQuota(ctx, userID, bytes, quota.QuotaBytes) + if err != nil { + if errors.Is(err, ErrSoraStorageQuotaExceeded) { + return &QuotaExceededError{ + QuotaBytes: quota.QuotaBytes, + UsedBytes: quota.UsedBytes, + } + } + return fmt.Errorf("update user quota usage (atomic): %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, newUsed) + return nil + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user for quota update: %w", err) + } + user.SoraStorageUsedBytes += bytes + if err := 
s.userRepo.Update(ctx, user); err != nil { + return fmt.Errorf("update user quota usage: %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 累加用量 user=%d +%d total=%d", userID, bytes, user.SoraStorageUsedBytes) + return nil +} + +// ReleaseUsage 释放用量(删除文件后调用)。 +func (s *SoraQuotaService) ReleaseUsage(ctx context.Context, userID int64, bytes int64) error { + if bytes <= 0 { + return nil + } + + if repo, ok := s.userRepo.(soraQuotaAtomicUserRepository); ok { + newUsed, err := repo.ReleaseSoraStorageUsageAtomic(ctx, userID, bytes) + if err != nil { + return fmt.Errorf("update user quota release (atomic): %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, newUsed) + return nil + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return fmt.Errorf("get user for quota release: %w", err) + } + user.SoraStorageUsedBytes -= bytes + if user.SoraStorageUsedBytes < 0 { + user.SoraStorageUsedBytes = 0 + } + if err := s.userRepo.Update(ctx, user); err != nil { + return fmt.Errorf("update user quota release: %w", err) + } + logger.LegacyPrintf("service.sora_quota", "[SoraQuota] 释放用量 user=%d -%d total=%d", userID, bytes, user.SoraStorageUsedBytes) + return nil +} + +func calcAvailableBytes(quotaBytes, usedBytes int64) int64 { + if quotaBytes <= 0 { + return 0 + } + if usedBytes >= quotaBytes { + return 0 + } + return quotaBytes - usedBytes +} + +func (s *SoraQuotaService) getSystemDefaultQuota(ctx context.Context) int64 { + if s.settingService == nil { + return 0 + } + settings, err := s.settingService.GetSoraS3Settings(ctx) + if err != nil { + return 0 + } + return settings.DefaultStorageQuotaBytes +} + +// GetQuotaFromSettings 从系统设置获取默认配额(供外部使用)。 +func (s *SoraQuotaService) GetQuotaFromSettings(ctx context.Context) int64 { + return s.getSystemDefaultQuota(ctx) +} + +// SetUserSoraQuota 设置用户级配额(管理员操作)。 +func SetUserSoraQuota(ctx context.Context, userRepo UserRepository, userID int64, 
quotaBytes int64) error { + user, err := userRepo.GetByID(ctx, userID) + if err != nil { + return err + } + user.SoraStorageQuotaBytes = quotaBytes + return userRepo.Update(ctx, user) +} + +// ParseQuotaBytes 解析配额字符串为字节数。 +func ParseQuotaBytes(s string) int64 { + v, _ := strconv.ParseInt(s, 10, 64) + return v +} diff --git a/backend/internal/service/sora_quota_service_test.go b/backend/internal/service/sora_quota_service_test.go new file mode 100644 index 00000000..040e427d --- /dev/null +++ b/backend/internal/service/sora_quota_service_test.go @@ -0,0 +1,492 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +// ==================== Stub: GroupRepository (用于 SoraQuotaService) ==================== + +var _ GroupRepository = (*stubGroupRepoForQuota)(nil) + +type stubGroupRepoForQuota struct { + groups map[int64]*Group +} + +func newStubGroupRepoForQuota() *stubGroupRepoForQuota { + return &stubGroupRepoForQuota{groups: make(map[int64]*Group)} +} + +func (r *stubGroupRepoForQuota) GetByID(_ context.Context, id int64) (*Group, error) { + if g, ok := r.groups[id]; ok { + return g, nil + } + return nil, fmt.Errorf("group not found") +} +func (r *stubGroupRepoForQuota) Create(context.Context, *Group) error { return nil } +func (r *stubGroupRepoForQuota) GetByIDLite(_ context.Context, id int64) (*Group, error) { + return r.GetByID(context.Background(), id) +} +func (r *stubGroupRepoForQuota) Update(context.Context, *Group) error { return nil } +func (r *stubGroupRepoForQuota) Delete(context.Context, int64) error { return nil } +func (r *stubGroupRepoForQuota) DeleteCascade(context.Context, int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepoForQuota) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} 
+func (r *stubGroupRepoForQuota) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) { + return nil, nil, nil +} +func (r *stubGroupRepoForQuota) ListActive(context.Context) ([]Group, error) { return nil, nil } +func (r *stubGroupRepoForQuota) ListActiveByPlatform(context.Context, string) ([]Group, error) { + return nil, nil +} +func (r *stubGroupRepoForQuota) ExistsByName(context.Context, string) (bool, error) { + return false, nil +} +func (r *stubGroupRepoForQuota) GetAccountCount(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubGroupRepoForQuota) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) { + return 0, nil +} +func (r *stubGroupRepoForQuota) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) { + return nil, nil +} +func (r *stubGroupRepoForQuota) BindAccountsToGroup(context.Context, int64, []int64) error { + return nil +} +func (r *stubGroupRepoForQuota) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error { + return nil +} + +// ==================== Stub: SettingRepository (用于 SettingService) ==================== + +var _ SettingRepository = (*stubSettingRepoForQuota)(nil) + +type stubSettingRepoForQuota struct { + values map[string]string +} + +func newStubSettingRepoForQuota(values map[string]string) *stubSettingRepoForQuota { + if values == nil { + values = make(map[string]string) + } + return &stubSettingRepoForQuota{values: values} +} + +func (r *stubSettingRepoForQuota) Get(_ context.Context, key string) (*Setting, error) { + if v, ok := r.values[key]; ok { + return &Setting{Key: key, Value: v}, nil + } + return nil, ErrSettingNotFound +} +func (r *stubSettingRepoForQuota) GetValue(_ context.Context, key string) (string, error) { + if v, ok := r.values[key]; ok { + return v, nil + } + return "", ErrSettingNotFound +} +func (r *stubSettingRepoForQuota) Set(_ context.Context, key, value string) 
error { + r.values[key] = value + return nil +} +func (r *stubSettingRepoForQuota) GetMultiple(_ context.Context, keys []string) (map[string]string, error) { + result := make(map[string]string) + for _, k := range keys { + if v, ok := r.values[k]; ok { + result[k] = v + } + } + return result, nil +} +func (r *stubSettingRepoForQuota) SetMultiple(_ context.Context, settings map[string]string) error { + for k, v := range settings { + r.values[k] = v + } + return nil +} +func (r *stubSettingRepoForQuota) GetAll(_ context.Context) (map[string]string, error) { + return r.values, nil +} +func (r *stubSettingRepoForQuota) Delete(_ context.Context, key string) error { + delete(r.values, key) + return nil +} + +// ==================== GetQuota ==================== + +func TestGetQuota_UserLevel(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, // 10MB + SoraStorageUsedBytes: 3 * 1024 * 1024, // 3MB + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(10*1024*1024), quota.QuotaBytes) + require.Equal(t, int64(3*1024*1024), quota.UsedBytes) + require.Equal(t, "user", quota.Source) +} + +func TestGetQuota_GroupLevel(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, // 用户级无配额 + SoraStorageUsedBytes: 1024, + AllowedGroups: []int64{10, 20}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 5 * 1024 * 1024} + groupRepo.groups[20] = &Group{ID: 20, SoraStorageQuotaBytes: 20 * 1024 * 1024} + + svc := NewSoraQuotaService(userRepo, groupRepo, nil) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) // 取最大值 + require.Equal(t, "group", quota.Source) +} + +func 
TestGetQuota_SystemLevel(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 512} + + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + svc := NewSoraQuotaService(userRepo, nil, settingService) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(104857600), quota.QuotaBytes) + require.Equal(t, "system", quota.Source) +} + +func TestGetQuota_NoLimit(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0, SoraStorageUsedBytes: 0} + svc := NewSoraQuotaService(userRepo, nil, nil) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, int64(0), quota.QuotaBytes) + require.Equal(t, "unlimited", quota.Source) +} + +func TestGetQuota_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + svc := NewSoraQuotaService(userRepo, nil, nil) + + _, err := svc.GetQuota(context.Background(), 999) + require.Error(t, err) + require.Contains(t, err.Error(), "get user") +} + +func TestGetQuota_GroupRepoError(t *testing.T) { + // 分组获取失败时跳过该分组(不影响整体) + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, SoraStorageQuotaBytes: 0, + AllowedGroups: []int64{999}, // 不存在的分组 + } + + groupRepo := newStubGroupRepoForQuota() + svc := NewSoraQuotaService(userRepo, groupRepo, nil) + + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "unlimited", quota.Source) // 分组获取失败,回退到无限制 +} + +// ==================== CheckQuota ==================== + +func TestCheckQuota_Sufficient(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + 
SoraStorageUsedBytes: 3 * 1024 * 1024, + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.CheckQuota(context.Background(), 1, 1024) + require.NoError(t, err) +} + +func TestCheckQuota_Exceeded(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 10 * 1024 * 1024, + SoraStorageUsedBytes: 10 * 1024 * 1024, // 已满 + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.CheckQuota(context.Background(), 1, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "配额不足") +} + +func TestCheckQuota_NoLimit(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, // 无限制 + SoraStorageUsedBytes: 1000000000, + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.CheckQuota(context.Background(), 1, 999999999) + require.NoError(t, err) // 无限制时始终通过 +} + +func TestCheckQuota_ExactBoundary(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 1024, + SoraStorageUsedBytes: 1024, // 恰好满 + } + svc := NewSoraQuotaService(userRepo, nil, nil) + + // 额外 0 字节不超 + require.NoError(t, svc.CheckQuota(context.Background(), 1, 0)) + // 额外 1 字节超出 + require.Error(t, svc.CheckQuota(context.Background(), 1, 1)) +} + +// ==================== AddUsage ==================== + +func TestAddUsage_Success(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, 2048) + require.NoError(t, err) + require.Equal(t, int64(3072), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestAddUsage_ZeroBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, 0) 
+ require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestAddUsage_NegativeBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, -100) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestAddUsage_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 999, 1024) + require.Error(t, err) +} + +func TestAddUsage_UpdateError(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 0} + userRepo.updateErr = fmt.Errorf("db error") + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.AddUsage(context.Background(), 1, 1024) + require.Error(t, err) + require.Contains(t, err.Error(), "update user quota usage") +} + +// ==================== ReleaseUsage ==================== + +func TestReleaseUsage_Success(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 3072} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 1024) + require.NoError(t, err) + require.Equal(t, int64(2048), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestReleaseUsage_ClampToZero(t *testing.T) { + // 释放量大于已用量时,应 clamp 到 0 + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 500} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 1000) + require.NoError(t, err) + require.Equal(t, int64(0), userRepo.users[1].SoraStorageUsedBytes) +} + +func TestReleaseUsage_ZeroBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + 
userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 0) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestReleaseUsage_NegativeBytes(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, -50) + require.NoError(t, err) + require.Equal(t, int64(1024), userRepo.users[1].SoraStorageUsedBytes) // 不变 +} + +func TestReleaseUsage_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 999, 1024) + require.Error(t, err) +} + +func TestReleaseUsage_UpdateError(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageUsedBytes: 1024} + userRepo.updateErr = fmt.Errorf("db error") + svc := NewSoraQuotaService(userRepo, nil, nil) + + err := svc.ReleaseUsage(context.Background(), 1, 512) + require.Error(t, err) + require.Contains(t, err.Error(), "update user quota release") +} + +// ==================== GetQuotaFromSettings ==================== + +func TestGetQuotaFromSettings_NilSettingService(t *testing.T) { + svc := NewSoraQuotaService(nil, nil, nil) + require.Equal(t, int64(0), svc.GetQuotaFromSettings(context.Background())) +} + +func TestGetQuotaFromSettings_WithSettings(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + svc := NewSoraQuotaService(nil, nil, settingService) + + require.Equal(t, int64(52428800), svc.GetQuotaFromSettings(context.Background())) +} + +// ==================== SetUserSoraQuota 
==================== + +func TestSetUserSoraQuota_Success(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ID: 1, SoraStorageQuotaBytes: 0} + + err := SetUserSoraQuota(context.Background(), userRepo, 1, 10*1024*1024) + require.NoError(t, err) + require.Equal(t, int64(10*1024*1024), userRepo.users[1].SoraStorageQuotaBytes) +} + +func TestSetUserSoraQuota_UserNotFound(t *testing.T) { + userRepo := newStubUserRepoForQuota() + err := SetUserSoraQuota(context.Background(), userRepo, 999, 1024) + require.Error(t, err) +} + +// ==================== ParseQuotaBytes ==================== + +func TestParseQuotaBytes(t *testing.T) { + require.Equal(t, int64(1048576), ParseQuotaBytes("1048576")) + require.Equal(t, int64(0), ParseQuotaBytes("")) + require.Equal(t, int64(0), ParseQuotaBytes("abc")) + require.Equal(t, int64(-1), ParseQuotaBytes("-1")) +} + +// ==================== 优先级完整测试 ==================== + +func TestQuotaPriority_UserOverridesGroup(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 5 * 1024 * 1024, + AllowedGroups: []int64{10}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024} + + svc := NewSoraQuotaService(userRepo, groupRepo, nil) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "user", quota.Source) // 用户级优先 + require.Equal(t, int64(5*1024*1024), quota.QuotaBytes) +} + +func TestQuotaPriority_GroupOverridesSystem(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, + AllowedGroups: []int64{10}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 20 * 1024 * 1024} + + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "104857600", // 100MB + }) + 
settingService := NewSettingService(settingRepo, &config.Config{}) + + svc := NewSoraQuotaService(userRepo, groupRepo, settingService) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "group", quota.Source) // 分组级优先于系统 + require.Equal(t, int64(20*1024*1024), quota.QuotaBytes) +} + +func TestQuotaPriority_FallbackToSystem(t *testing.T) { + userRepo := newStubUserRepoForQuota() + userRepo.users[1] = &User{ + ID: 1, + SoraStorageQuotaBytes: 0, + AllowedGroups: []int64{10}, + } + + groupRepo := newStubGroupRepoForQuota() + groupRepo.groups[10] = &Group{ID: 10, SoraStorageQuotaBytes: 0} // 分组无配额 + + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraDefaultStorageQuotaBytes: "52428800", // 50MB + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + + svc := NewSoraQuotaService(userRepo, groupRepo, settingService) + quota, err := svc.GetQuota(context.Background(), 1) + require.NoError(t, err) + require.Equal(t, "system", quota.Source) + require.Equal(t, int64(52428800), quota.QuotaBytes) +} diff --git a/backend/internal/service/sora_s3_storage.go b/backend/internal/service/sora_s3_storage.go new file mode 100644 index 00000000..4c573905 --- /dev/null +++ b/backend/internal/service/sora_s3_storage.go @@ -0,0 +1,392 @@ +package service + +import ( + "context" + "fmt" + "io" + "net/http" + "path" + "strings" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/google/uuid" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" +) + +// SoraS3Storage 负责 Sora 媒体文件的 S3 存储操作。 +// 从 Settings 表读取 S3 配置,初始化并缓存 S3 客户端。 +type SoraS3Storage struct { + settingService *SettingService + + mu sync.RWMutex + client *s3.Client + cfg *SoraS3Settings // 上次加载的配置快照 + + healthCheckedAt time.Time + 
healthErr error + healthTTL time.Duration +} + +const defaultSoraS3HealthTTL = 30 * time.Second + +// UpstreamDownloadError 表示从上游下载媒体失败(包含 HTTP 状态码)。 +type UpstreamDownloadError struct { + StatusCode int +} + +func (e *UpstreamDownloadError) Error() string { + if e == nil { + return "upstream download failed" + } + return fmt.Sprintf("upstream returned %d", e.StatusCode) +} + +// NewSoraS3Storage 创建 S3 存储服务实例。 +func NewSoraS3Storage(settingService *SettingService) *SoraS3Storage { + return &SoraS3Storage{ + settingService: settingService, + healthTTL: defaultSoraS3HealthTTL, + } +} + +// Enabled 返回 S3 存储是否已启用且配置有效。 +func (s *SoraS3Storage) Enabled(ctx context.Context) bool { + cfg, err := s.getConfig(ctx) + if err != nil || cfg == nil { + return false + } + return cfg.Enabled && cfg.Bucket != "" +} + +// getConfig 获取当前 S3 配置(从 settings 表读取)。 +func (s *SoraS3Storage) getConfig(ctx context.Context) (*SoraS3Settings, error) { + if s.settingService == nil { + return nil, fmt.Errorf("setting service not available") + } + return s.settingService.GetSoraS3Settings(ctx) +} + +// getClient 获取或初始化 S3 客户端(带缓存)。 +// 配置变更时调用 RefreshClient 清除缓存。 +func (s *SoraS3Storage) getClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) { + s.mu.RLock() + if s.client != nil && s.cfg != nil { + client, cfg := s.client, s.cfg + s.mu.RUnlock() + return client, cfg, nil + } + s.mu.RUnlock() + + return s.initClient(ctx) +} + +func (s *SoraS3Storage) initClient(ctx context.Context) (*s3.Client, *SoraS3Settings, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // 双重检查 + if s.client != nil && s.cfg != nil { + return s.client, s.cfg, nil + } + + cfg, err := s.getConfig(ctx) + if err != nil { + return nil, nil, fmt.Errorf("load s3 config: %w", err) + } + if !cfg.Enabled { + return nil, nil, fmt.Errorf("sora s3 storage is disabled") + } + if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { + return nil, nil, fmt.Errorf("sora s3 config incomplete: bucket, 
access_key_id, secret_access_key are required") + } + + client, region, err := buildSoraS3Client(ctx, cfg) + if err != nil { + return nil, nil, err + } + + s.client = client + s.cfg = cfg + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端已初始化 bucket=%s endpoint=%s region=%s", cfg.Bucket, cfg.Endpoint, region) + return client, cfg, nil +} + +// RefreshClient 清除缓存的 S3 客户端,下次使用时重新初始化。 +// 应在系统设置中 S3 配置变更时调用。 +func (s *SoraS3Storage) RefreshClient() { + s.mu.Lock() + defer s.mu.Unlock() + s.client = nil + s.cfg = nil + s.healthCheckedAt = time.Time{} + s.healthErr = nil + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 客户端缓存已清除,下次使用将重新初始化") +} + +// TestConnection 测试 S3 连接(HeadBucket)。 +func (s *SoraS3Storage) TestConnection(ctx context.Context) error { + client, cfg, err := s.getClient(ctx) + if err != nil { + return err + } + _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &cfg.Bucket, + }) + if err != nil { + return fmt.Errorf("s3 HeadBucket failed: %w", err) + } + return nil +} + +// IsHealthy 返回 S3 健康状态(带短缓存,避免每次请求都触发 HeadBucket)。 +func (s *SoraS3Storage) IsHealthy(ctx context.Context) bool { + if s == nil { + return false + } + now := time.Now() + s.mu.RLock() + lastCheck := s.healthCheckedAt + lastErr := s.healthErr + ttl := s.healthTTL + s.mu.RUnlock() + + if ttl <= 0 { + ttl = defaultSoraS3HealthTTL + } + if !lastCheck.IsZero() && now.Sub(lastCheck) < ttl { + return lastErr == nil + } + + err := s.TestConnection(ctx) + s.mu.Lock() + s.healthCheckedAt = time.Now() + s.healthErr = err + s.mu.Unlock() + return err == nil +} + +// TestConnectionWithSettings 使用临时配置测试连接,不污染缓存的客户端。 +func (s *SoraS3Storage) TestConnectionWithSettings(ctx context.Context, cfg *SoraS3Settings) error { + if cfg == nil { + return fmt.Errorf("s3 config is required") + } + if !cfg.Enabled { + return fmt.Errorf("sora s3 storage is disabled") + } + if cfg.Endpoint == "" || cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" { + return fmt.Errorf("sora s3 
config incomplete: endpoint, bucket, access_key_id, secret_access_key are required") + } + client, _, err := buildSoraS3Client(ctx, cfg) + if err != nil { + return err + } + _, err = client.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &cfg.Bucket, + }) + if err != nil { + return fmt.Errorf("s3 HeadBucket failed: %w", err) + } + return nil +} + +// GenerateObjectKey 生成 S3 object key。 +// 格式: {prefix}sora/{userID}/{YYYY/MM/DD}/{uuid}.{ext} +func (s *SoraS3Storage) GenerateObjectKey(prefix string, userID int64, ext string) string { + if !strings.HasPrefix(ext, ".") { + ext = "." + ext + } + datePath := time.Now().Format("2006/01/02") + key := fmt.Sprintf("sora/%d/%s/%s%s", userID, datePath, uuid.NewString(), ext) + if prefix != "" { + prefix = strings.TrimRight(prefix, "/") + "/" + key = prefix + key + } + return key +} + +// UploadFromURL 从上游 URL 下载并流式上传到 S3。 +// 返回 S3 object key。 +func (s *SoraS3Storage) UploadFromURL(ctx context.Context, userID int64, sourceURL string) (string, int64, error) { + client, cfg, err := s.getClient(ctx) + if err != nil { + return "", 0, err + } + + // 下载源文件 + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL, nil) + if err != nil { + return "", 0, fmt.Errorf("create download request: %w", err) + } + httpClient := &http.Client{Timeout: 5 * time.Minute} + resp, err := httpClient.Do(req) + if err != nil { + return "", 0, fmt.Errorf("download from upstream: %w", err) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != http.StatusOK { + return "", 0, &UpstreamDownloadError{StatusCode: resp.StatusCode} + } + + // 推断文件扩展名 + ext := fileExtFromURL(sourceURL) + if ext == "" { + ext = fileExtFromContentType(resp.Header.Get("Content-Type")) + } + if ext == "" { + ext = ".bin" + } + + objectKey := s.GenerateObjectKey(cfg.Prefix, userID, ext) + + // 检测 Content-Type + contentType := resp.Header.Get("Content-Type") + if contentType == "" { + contentType = "application/octet-stream" + } + + reader, writer 
:= io.Pipe() + uploadErrCh := make(chan error, 1) + go func() { + defer close(uploadErrCh) + input := &s3.PutObjectInput{ + Bucket: &cfg.Bucket, + Key: &objectKey, + Body: reader, + ContentType: &contentType, + } + if resp.ContentLength >= 0 { + input.ContentLength = &resp.ContentLength + } + _, uploadErr := client.PutObject(ctx, input) + uploadErrCh <- uploadErr + }() + + written, copyErr := io.CopyBuffer(writer, resp.Body, make([]byte, 1024*1024)) + _ = writer.CloseWithError(copyErr) + uploadErr := <-uploadErrCh + if copyErr != nil { + return "", 0, fmt.Errorf("stream upload copy failed: %w", copyErr) + } + if uploadErr != nil { + return "", 0, fmt.Errorf("s3 upload: %w", uploadErr) + } + + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 上传完成 key=%s size=%d", objectKey, written) + return objectKey, written, nil +} + +func buildSoraS3Client(ctx context.Context, cfg *SoraS3Settings) (*s3.Client, string, error) { + if cfg == nil { + return nil, "", fmt.Errorf("s3 config is required") + } + region := cfg.Region + if region == "" { + region = "us-east-1" + } + + awsCfg, err := awsconfig.LoadDefaultConfig(ctx, + awsconfig.WithRegion(region), + awsconfig.WithCredentialsProvider( + credentials.NewStaticCredentialsProvider(cfg.AccessKeyID, cfg.SecretAccessKey, ""), + ), + ) + if err != nil { + return nil, "", fmt.Errorf("load aws config: %w", err) + } + + client := s3.NewFromConfig(awsCfg, func(o *s3.Options) { + if cfg.Endpoint != "" { + o.BaseEndpoint = &cfg.Endpoint + } + if cfg.ForcePathStyle { + o.UsePathStyle = true + } + o.APIOptions = append(o.APIOptions, v4.SwapComputePayloadSHA256ForUnsignedPayloadMiddleware) + // 兼容非 TLS 连接(如 MinIO)的流式上传,避免 io.Pipe checksum 校验失败 + o.RequestChecksumCalculation = aws.RequestChecksumCalculationWhenRequired + }) + return client, region, nil +} + +// DeleteObjects 删除一组 S3 object(遍历逐一删除)。 +func (s *SoraS3Storage) DeleteObjects(ctx context.Context, objectKeys []string) error { + if len(objectKeys) == 0 { + return nil + } + + client, 
cfg, err := s.getClient(ctx) + if err != nil { + return err + } + + var lastErr error + for _, key := range objectKeys { + k := key + _, err := client.DeleteObject(ctx, &s3.DeleteObjectInput{ + Bucket: &cfg.Bucket, + Key: &k, + }) + if err != nil { + logger.LegacyPrintf("service.sora_s3", "[SoraS3] 删除失败 key=%s err=%v", key, err) + lastErr = err + } + } + return lastErr +} + +// GetAccessURL 获取 S3 文件的访问 URL。 +// CDN URL 优先,否则生成 24h 预签名 URL。 +func (s *SoraS3Storage) GetAccessURL(ctx context.Context, objectKey string) (string, error) { + _, cfg, err := s.getClient(ctx) + if err != nil { + return "", err + } + + // CDN URL 优先 + if cfg.CDNURL != "" { + cdnBase := strings.TrimRight(cfg.CDNURL, "/") + return cdnBase + "/" + objectKey, nil + } + + // 生成 24h 预签名 URL + return s.GeneratePresignedURL(ctx, objectKey, 24*time.Hour) +} + +// GeneratePresignedURL 生成预签名 URL。 +func (s *SoraS3Storage) GeneratePresignedURL(ctx context.Context, objectKey string, ttl time.Duration) (string, error) { + client, cfg, err := s.getClient(ctx) + if err != nil { + return "", err + } + + presignClient := s3.NewPresignClient(client) + result, err := presignClient.PresignGetObject(ctx, &s3.GetObjectInput{ + Bucket: &cfg.Bucket, + Key: &objectKey, + }, s3.WithPresignExpires(ttl)) + if err != nil { + return "", fmt.Errorf("presign url: %w", err) + } + return result.URL, nil +} + +// GetMediaType 从 object key 推断媒体类型(image/video)。 +func GetMediaTypeFromKey(objectKey string) string { + ext := strings.ToLower(path.Ext(objectKey)) + switch ext { + case ".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv": + return "video" + default: + return "image" + } +} diff --git a/backend/internal/service/sora_s3_storage_test.go b/backend/internal/service/sora_s3_storage_test.go new file mode 100644 index 00000000..32ff9a6f --- /dev/null +++ b/backend/internal/service/sora_s3_storage_test.go @@ -0,0 +1,263 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + 
"github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// ==================== RefreshClient ==================== + +func TestRefreshClient(t *testing.T) { + s := newS3StorageWithCDN("https://cdn.example.com") + require.NotNil(t, s.client) + require.NotNil(t, s.cfg) + + s.RefreshClient() + require.Nil(t, s.client) + require.Nil(t, s.cfg) +} + +func TestRefreshClient_AlreadyNil(t *testing.T) { + s := NewSoraS3Storage(nil) + s.RefreshClient() // 不应 panic + require.Nil(t, s.client) + require.Nil(t, s.cfg) +} + +// ==================== GetMediaTypeFromKey ==================== + +func TestGetMediaTypeFromKey_VideoExtensions(t *testing.T) { + for _, ext := range []string{".mp4", ".mov", ".webm", ".m4v", ".avi", ".mkv", ".3gp", ".flv"} { + require.Equal(t, "video", GetMediaTypeFromKey("path/to/file"+ext), "ext=%s", ext) + } +} + +func TestGetMediaTypeFromKey_VideoUpperCase(t *testing.T) { + require.Equal(t, "video", GetMediaTypeFromKey("file.MP4")) + require.Equal(t, "video", GetMediaTypeFromKey("file.MOV")) +} + +func TestGetMediaTypeFromKey_ImageExtensions(t *testing.T) { + require.Equal(t, "image", GetMediaTypeFromKey("file.png")) + require.Equal(t, "image", GetMediaTypeFromKey("file.jpg")) + require.Equal(t, "image", GetMediaTypeFromKey("file.jpeg")) + require.Equal(t, "image", GetMediaTypeFromKey("file.gif")) + require.Equal(t, "image", GetMediaTypeFromKey("file.webp")) +} + +func TestGetMediaTypeFromKey_NoExtension(t *testing.T) { + require.Equal(t, "image", GetMediaTypeFromKey("file")) + require.Equal(t, "image", GetMediaTypeFromKey("path/to/file")) +} + +func TestGetMediaTypeFromKey_UnknownExtension(t *testing.T) { + require.Equal(t, "image", GetMediaTypeFromKey("file.bin")) + require.Equal(t, "image", GetMediaTypeFromKey("file.xyz")) +} + +// ==================== Enabled ==================== + +func TestEnabled_NilSettingService(t *testing.T) { + s := NewSoraS3Storage(nil) + require.False(t, s.Enabled(context.Background())) +} 
+ +func TestEnabled_ConfigDisabled(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "false", + SettingKeySoraS3Bucket: "test-bucket", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + require.False(t, s.Enabled(context.Background())) +} + +func TestEnabled_ConfigEnabledWithBucket(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "my-bucket", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + require.True(t, s.Enabled(context.Background())) +} + +func TestEnabled_ConfigEnabledEmptyBucket(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + require.False(t, s.Enabled(context.Background())) +} + +// ==================== initClient ==================== + +func TestInitClient_Disabled(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "false", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + _, _, err := s.getClient(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "disabled") +} + +func TestInitClient_IncompleteConfig(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "test-bucket", + // 缺少 access_key_id 和 secret_access_key + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + _, _, err := s.getClient(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "incomplete") +} + +func 
TestInitClient_DefaultRegion(t *testing.T) { + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "test-bucket", + SettingKeySoraS3AccessKeyID: "AKID", + SettingKeySoraS3SecretAccessKey: "SECRET", + // Region 为空 → 默认 us-east-1 + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + client, cfg, err := s.getClient(context.Background()) + require.NoError(t, err) + require.NotNil(t, client) + require.Equal(t, "test-bucket", cfg.Bucket) +} + +func TestInitClient_DoubleCheck(t *testing.T) { + // 验证双重检查锁定:第二次 getClient 命中缓存 + settingRepo := newStubSettingRepoForQuota(map[string]string{ + SettingKeySoraS3Enabled: "true", + SettingKeySoraS3Bucket: "test-bucket", + SettingKeySoraS3AccessKeyID: "AKID", + SettingKeySoraS3SecretAccessKey: "SECRET", + }) + settingService := NewSettingService(settingRepo, &config.Config{}) + s := NewSoraS3Storage(settingService) + + client1, _, err1 := s.getClient(context.Background()) + require.NoError(t, err1) + client2, _, err2 := s.getClient(context.Background()) + require.NoError(t, err2) + require.Equal(t, client1, client2) // 同一客户端实例 +} + +func TestInitClient_NilSettingService(t *testing.T) { + s := NewSoraS3Storage(nil) + _, _, err := s.getClient(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "setting service not available") +} + +// ==================== GenerateObjectKey ==================== + +func TestGenerateObjectKey_ExtWithoutDot(t *testing.T) { + s := NewSoraS3Storage(nil) + key := s.GenerateObjectKey("", 1, "mp4") + require.Contains(t, key, ".mp4") + require.True(t, len(key) > 0) +} + +func TestGenerateObjectKey_ExtWithDot(t *testing.T) { + s := NewSoraS3Storage(nil) + key := s.GenerateObjectKey("", 1, ".mp4") + require.Contains(t, key, ".mp4") + // 不应出现 ..mp4 + require.NotContains(t, key, "..mp4") +} + +func TestGenerateObjectKey_WithPrefix(t *testing.T) { + s := 
NewSoraS3Storage(nil) + key := s.GenerateObjectKey("uploads/", 42, ".png") + require.True(t, len(key) > 0) + require.Contains(t, key, "uploads/sora/42/") +} + +func TestGenerateObjectKey_PrefixWithoutTrailingSlash(t *testing.T) { + s := NewSoraS3Storage(nil) + key := s.GenerateObjectKey("uploads", 42, ".png") + require.Contains(t, key, "uploads/sora/42/") +} + +// ==================== GeneratePresignedURL ==================== + +func TestGeneratePresignedURL_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) // settingService=nil → getClient 失败 + _, err := s.GeneratePresignedURL(context.Background(), "key", 3600) + require.Error(t, err) +} + +// ==================== GetAccessURL ==================== + +func TestGetAccessURL_CDN(t *testing.T) { + s := newS3StorageWithCDN("https://cdn.example.com") + url, err := s.GetAccessURL(context.Background(), "sora/1/2024/01/01/video.mp4") + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/sora/1/2024/01/01/video.mp4", url) +} + +func TestGetAccessURL_CDNTrailingSlash(t *testing.T) { + s := newS3StorageWithCDN("https://cdn.example.com/") + url, err := s.GetAccessURL(context.Background(), "key.mp4") + require.NoError(t, err) + require.Equal(t, "https://cdn.example.com/key.mp4", url) +} + +func TestGetAccessURL_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + _, err := s.GetAccessURL(context.Background(), "key") + require.Error(t, err) +} + +// ==================== TestConnection ==================== + +func TestTestConnection_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.TestConnection(context.Background()) + require.Error(t, err) +} + +// ==================== UploadFromURL ==================== + +func TestUploadFromURL_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + _, _, err := s.UploadFromURL(context.Background(), 1, "https://example.com/file.mp4") + require.Error(t, err) +} + +// ==================== DeleteObjects ==================== + +func 
TestDeleteObjects_EmptyKeys(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.DeleteObjects(context.Background(), []string{}) + require.NoError(t, err) // 空列表直接返回 +} + +func TestDeleteObjects_NilKeys(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.DeleteObjects(context.Background(), nil) + require.NoError(t, err) // nil 列表直接返回 +} + +func TestDeleteObjects_GetClientError(t *testing.T) { + s := NewSoraS3Storage(nil) + err := s.DeleteObjects(context.Background(), []string{"key1", "key2"}) + require.Error(t, err) +} diff --git a/backend/internal/service/sora_sdk_client.go b/backend/internal/service/sora_sdk_client.go index 604c2749..f9221c5b 100644 --- a/backend/internal/service/sora_sdk_client.go +++ b/backend/internal/service/sora_sdk_client.go @@ -15,6 +15,7 @@ import ( "github.com/DouDOU-start/go-sora2api/sora" "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" openaioauth "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/Wei-Shaw/sub2api/internal/util/logredact" "github.com/tidwall/gjson" @@ -75,6 +76,17 @@ func (c *SoraSDKClient) PreflightCheck(ctx context.Context, account *Account, re } balance, err := sdkClient.GetCreditBalance(ctx, token) if err != nil { + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora_sdk", + "[PreflightCheckRawError] account_id=%d model=%s op=get_credit_balance raw_err=%s", + accountID, + requestedModel, + logredact.RedactText(err.Error()), + ) return &SoraUpstreamError{ StatusCode: http.StatusForbidden, Message: "当前账号未开通 Sora2 能力或无可用配额", @@ -170,9 +182,23 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r if size == "" { size = "small" } + videoCount := req.VideoCount + if videoCount <= 0 { + videoCount = 1 + } + if videoCount > 3 { + videoCount = 3 + } // Remix 模式 if strings.TrimSpace(req.RemixTargetID) != "" { + if videoCount > 1 { + accountID := int64(0) + if account != 
nil { + accountID = account.ID + } + c.debugLogf("video_count_ignored_for_remix account_id=%d count=%d", accountID, videoCount) + } styleID := "" // SDK ExtractStyle 可从 prompt 中提取 taskID, err := sdkClient.RemixVideo(ctx, token, sentinel, req.RemixTargetID, req.Prompt, orientation, nFrames, styleID) if err != nil { @@ -182,13 +208,60 @@ func (c *SoraSDKClient) CreateVideoTask(ctx context.Context, account *Account, r } // 普通视频(文生视频或图生视频) - taskID, err := sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "") + var taskID string + if videoCount <= 1 { + taskID, err = sdkClient.CreateVideoTaskWithOptions(ctx, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, "") + } else { + taskID, err = c.createVideoTaskWithVariants(ctx, account, token, sentinel, req.Prompt, orientation, nFrames, model, size, req.MediaID, videoCount) + } if err != nil { return "", c.wrapSDKError(err, account) } return taskID, nil } +func (c *SoraSDKClient) createVideoTaskWithVariants( + ctx context.Context, + account *Account, + accessToken string, + sentinelToken string, + prompt string, + orientation string, + nFrames int, + model string, + size string, + mediaID string, + videoCount int, +) (string, error) { + inpaintItems := make([]any, 0, 1) + if strings.TrimSpace(mediaID) != "" { + inpaintItems = append(inpaintItems, map[string]any{ + "kind": "upload", + "upload_id": mediaID, + }) + } + payload := map[string]any{ + "kind": "video", + "prompt": prompt, + "orientation": orientation, + "size": size, + "n_frames": nFrames, + "n_variants": videoCount, + "model": model, + "inpaint_items": inpaintItems, + "style_id": nil, + } + raw, err := c.doSoraBackendJSON(ctx, account, http.MethodPost, "/nf/create", accessToken, sentinelToken, payload) + if err != nil { + return "", err + } + taskID := strings.TrimSpace(gjson.GetBytes(raw, "id").String()) + if taskID == "" { + return "", errors.New("create video task 
response missing id") + } + return taskID, nil +} + func (c *SoraSDKClient) CreateStoryboardTask(ctx context.Context, account *Account, req SoraStoryboardRequest) (string, error) { token, err := c.getAccessToken(ctx, account) if err != nil { @@ -512,7 +585,7 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task } // 任务不在 pending 中,查询 drafts 获取下载链接 - downloadURL, err := sdkClient.GetDownloadURL(ctx, token, taskID) + downloadURLs, err := c.getVideoTaskDownloadURLs(ctx, account, token, taskID) if err != nil { errMsg := err.Error() if strings.Contains(errMsg, "内容违规") || strings.Contains(errMsg, "Content violates") { @@ -528,13 +601,147 @@ func (c *SoraSDKClient) GetVideoTask(ctx context.Context, account *Account, task Status: "processing", }, nil } + if len(downloadURLs) == 0 { + return &SoraVideoTaskStatus{ + ID: taskID, + Status: "processing", + }, nil + } return &SoraVideoTaskStatus{ ID: taskID, Status: "completed", - URLs: []string{downloadURL}, + URLs: downloadURLs, }, nil } +func (c *SoraSDKClient) getVideoTaskDownloadURLs(ctx context.Context, account *Account, accessToken, taskID string) ([]string, error) { + raw, err := c.doSoraBackendJSON(ctx, account, http.MethodGet, "/project_y/profile/drafts?limit=30", accessToken, "", nil) + if err != nil { + return nil, err + } + items := gjson.GetBytes(raw, "items") + if !items.Exists() || !items.IsArray() { + return nil, fmt.Errorf("drafts response missing items for task %s", taskID) + } + urlSet := make(map[string]struct{}, 4) + urls := make([]string, 0, 4) + items.ForEach(func(_, item gjson.Result) bool { + if strings.TrimSpace(item.Get("task_id").String()) != taskID { + return true + } + kind := strings.TrimSpace(item.Get("kind").String()) + reason := strings.TrimSpace(item.Get("reason_str").String()) + markdownReason := strings.TrimSpace(item.Get("markdown_reason_str").String()) + if kind == "sora_content_violation" || reason != "" || markdownReason != "" { + if reason == "" { + reason 
= markdownReason + } + if reason == "" { + reason = "内容违规" + } + err = fmt.Errorf("内容违规: %s", reason) + return false + } + url := strings.TrimSpace(item.Get("downloadable_url").String()) + if url == "" { + url = strings.TrimSpace(item.Get("url").String()) + } + if url == "" { + return true + } + if _, exists := urlSet[url]; exists { + return true + } + urlSet[url] = struct{}{} + urls = append(urls, url) + return true + }) + if err != nil { + return nil, err + } + if len(urls) > 0 { + return urls, nil + } + + // 兼容旧 SDK 的兜底逻辑 + sdkClient, sdkErr := c.getSDKClient(account) + if sdkErr != nil { + return nil, sdkErr + } + downloadURL, sdkErr := sdkClient.GetDownloadURL(ctx, accessToken, taskID) + if sdkErr != nil { + return nil, sdkErr + } + if strings.TrimSpace(downloadURL) == "" { + return nil, nil + } + return []string{downloadURL}, nil +} + +func (c *SoraSDKClient) doSoraBackendJSON( + ctx context.Context, + account *Account, + method string, + path string, + accessToken string, + sentinelToken string, + payload map[string]any, +) ([]byte, error) { + endpoint := "https://sora.chatgpt.com/backend" + path + var body io.Reader + if payload != nil { + raw, err := json.Marshal(payload) + if err != nil { + return nil, err + } + body = bytes.NewReader(raw) + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint, body) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+accessToken) + req.Header.Set("Accept", "application/json, text/plain, */*") + req.Header.Set("Origin", "https://sora.chatgpt.com") + req.Header.Set("Referer", "https://sora.chatgpt.com/") + req.Header.Set("User-Agent", "Sora/1.2026.007 (Android 15; 24122RKC7C; build 2600700)") + if payload != nil { + req.Header.Set("Content-Type", "application/json") + } + if strings.TrimSpace(sentinelToken) != "" { + req.Header.Set("openai-sentinel-token", sentinelToken) + } + + proxyURL := c.resolveProxyURL(account) + accountID := int64(0) + accountConcurrency := 0 + if 
account != nil { + accountID = account.ID + accountConcurrency = account.Concurrency + } + + var resp *http.Response + if c.httpUpstream != nil { + resp, err = c.httpUpstream.Do(req, proxyURL, accountID, accountConcurrency) + } else { + resp, err = http.DefaultClient.Do(req) + } + if err != nil { + return nil, err + } + defer func() { _ = resp.Body.Close() }() + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 4<<20)) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("HTTP %d: %s", resp.StatusCode, truncateForLog(raw, 256)) + } + return raw, nil +} + // --- 内部方法 --- // getSDKClient 获取或创建指定代理的 SDK 客户端实例 @@ -791,6 +998,17 @@ func (c *SoraSDKClient) wrapSDKError(err error, account *Account) error { } else if strings.Contains(msg, "HTTP 404") { statusCode = http.StatusNotFound } + accountID := int64(0) + if account != nil { + accountID = account.ID + } + logger.LegacyPrintf( + "service.sora_sdk", + "[WrapSDKError] account_id=%d mapped_status=%d raw_err=%s", + accountID, + statusCode, + logredact.RedactText(msg), + ) return &SoraUpstreamError{ StatusCode: statusCode, Message: msg, diff --git a/backend/internal/service/sora_upstream_forwarder.go b/backend/internal/service/sora_upstream_forwarder.go new file mode 100644 index 00000000..cdf9570b --- /dev/null +++ b/backend/internal/service/sora_upstream_forwarder.go @@ -0,0 +1,149 @@ +package service + +import ( + "bytes" + "context" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/gin-gonic/gin" +) + +// forwardToUpstream 将请求 HTTP 透传到上游 Sora 服务(用于 apikey 类型账号)。 +// 上游地址为 account.GetBaseURL() + "/sora/v1/chat/completions", +// 使用 account.GetCredential("api_key") 作为 Bearer Token。 +// 支持流式和非流式响应的直接透传。 +func (s *SoraGatewayService) forwardToUpstream( + ctx context.Context, + c *gin.Context, + account *Account, + body []byte, + clientStream bool, + startTime 
time.Time, +) (*ForwardResult, error) { + apiKey := account.GetCredential("api_key") + if apiKey == "" { + s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing api_key credential", clientStream) + return nil, fmt.Errorf("sora apikey account %d missing api_key", account.ID) + } + + baseURL := account.GetBaseURL() + if baseURL == "" { + s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey account missing base_url", clientStream) + return nil, fmt.Errorf("sora apikey account %d missing base_url", account.ID) + } + // 校验 scheme 合法性(仅允许 http/https) + if !strings.HasPrefix(baseURL, "http://") && !strings.HasPrefix(baseURL, "https://") { + s.writeSoraError(c, http.StatusBadGateway, "upstream_error", "Sora apikey base_url must start with http:// or https://", clientStream) + return nil, fmt.Errorf("sora apikey account %d invalid base_url scheme: %s", account.ID, baseURL) + } + upstreamURL := strings.TrimRight(baseURL, "/") + "/sora/v1/chat/completions" + + // 构建上游请求 + upstreamReq, err := http.NewRequestWithContext(ctx, http.MethodPost, upstreamURL, bytes.NewReader(body)) + if err != nil { + s.writeSoraError(c, http.StatusInternalServerError, "api_error", "Failed to create upstream request", clientStream) + return nil, fmt.Errorf("create upstream request: %w", err) + } + + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Authorization", "Bearer "+apiKey) + + // 透传客户端的部分请求头 + for _, header := range []string{"Accept", "Accept-Encoding"} { + if v := c.GetHeader(header); v != "" { + upstreamReq.Header.Set(header, v) + } + } + + logger.LegacyPrintf("service.sora", "[ForwardUpstream] account=%d url=%s", account.ID, upstreamURL) + + // 获取代理 URL + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + + // 发送请求 + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + if err != nil { + s.writeSoraError(c, 
http.StatusBadGateway, "upstream_error", "Failed to connect to upstream Sora service", clientStream) + return nil, &UpstreamFailoverError{ + StatusCode: http.StatusBadGateway, + } + } + defer func() { + _ = resp.Body.Close() + }() + + // 错误响应处理 + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024)) + + if s.shouldFailoverUpstreamError(resp.StatusCode) { + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + ResponseHeaders: resp.Header.Clone(), + } + } + + // 非转移错误,直接透传给客户端 + c.Status(resp.StatusCode) + for key, values := range resp.Header { + for _, v := range values { + c.Writer.Header().Add(key, v) + } + } + if _, err := c.Writer.Write(respBody); err != nil { + return nil, fmt.Errorf("write upstream error response: %w", err) + } + return nil, fmt.Errorf("upstream error: %d", resp.StatusCode) + } + + // 成功响应 — 直接透传 + c.Status(resp.StatusCode) + for key, values := range resp.Header { + lower := strings.ToLower(key) + // 透传内容相关头部 + if lower == "content-type" || lower == "transfer-encoding" || + lower == "cache-control" || lower == "x-request-id" { + for _, v := range values { + c.Writer.Header().Add(key, v) + } + } + } + + // 流式复制响应体 + if flusher, ok := c.Writer.(http.Flusher); ok && clientStream { + buf := make([]byte, 4096) + for { + n, readErr := resp.Body.Read(buf) + if n > 0 { + if _, err := c.Writer.Write(buf[:n]); err != nil { + return nil, fmt.Errorf("stream upstream response write: %w", err) + } + flusher.Flush() + } + if readErr != nil { + break + } + } + } else { + if _, err := io.Copy(c.Writer, resp.Body); err != nil { + return nil, fmt.Errorf("copy upstream response: %w", err) + } + } + + duration := time.Since(startTime) + return &ForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Model: "", // 由调用方填充 + Stream: clientStream, + Duration: duration, + }, nil +} diff --git a/backend/internal/service/usage_cleanup.go b/backend/internal/service/usage_cleanup.go 
index 7e3ffbb9..6e32f3c0 100644 --- a/backend/internal/service/usage_cleanup.go +++ b/backend/internal/service/usage_cleanup.go @@ -33,6 +33,7 @@ type UsageCleanupFilters struct { AccountID *int64 `json:"account_id,omitempty"` GroupID *int64 `json:"group_id,omitempty"` Model *string `json:"model,omitempty"` + RequestType *int16 `json:"request_type,omitempty"` Stream *bool `json:"stream,omitempty"` BillingType *int8 `json:"billing_type,omitempty"` } diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go index ee795aa4..5600542e 100644 --- a/backend/internal/service/usage_cleanup_service.go +++ b/backend/internal/service/usage_cleanup_service.go @@ -68,6 +68,9 @@ func describeUsageCleanupFilters(filters UsageCleanupFilters) string { if filters.Model != nil { parts = append(parts, "model="+strings.TrimSpace(*filters.Model)) } + if filters.RequestType != nil { + parts = append(parts, "request_type="+RequestTypeFromInt16(*filters.RequestType).String()) + } if filters.Stream != nil { parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream)) } @@ -368,6 +371,16 @@ func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) { filters.Model = &model } } + if filters.RequestType != nil { + requestType := RequestType(*filters.RequestType) + if !requestType.IsValid() { + filters.RequestType = nil + } else { + value := int16(requestType.Normalize()) + filters.RequestType = &value + filters.Stream = nil + } + } if filters.BillingType != nil && *filters.BillingType < 0 { filters.BillingType = nil } diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go index 1f9f4776..0fdbfd47 100644 --- a/backend/internal/service/usage_cleanup_service_test.go +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -257,6 +257,53 @@ func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) { require.Equal(t, int64(9), task.CreatedBy) } +func 
TestSanitizeUsageCleanupFiltersRequestTypePriority(t *testing.T) { + requestType := int16(RequestTypeWSV2) + stream := false + model := " gpt-5 " + filters := UsageCleanupFilters{ + Model: &model, + RequestType: &requestType, + Stream: &stream, + } + + sanitizeUsageCleanupFilters(&filters) + + require.NotNil(t, filters.RequestType) + require.Equal(t, int16(RequestTypeWSV2), *filters.RequestType) + require.Nil(t, filters.Stream) + require.NotNil(t, filters.Model) + require.Equal(t, "gpt-5", *filters.Model) +} + +func TestSanitizeUsageCleanupFiltersInvalidRequestType(t *testing.T) { + requestType := int16(99) + stream := true + filters := UsageCleanupFilters{ + RequestType: &requestType, + Stream: &stream, + } + + sanitizeUsageCleanupFilters(&filters) + + require.Nil(t, filters.RequestType) + require.NotNil(t, filters.Stream) + require.True(t, *filters.Stream) +} + +func TestDescribeUsageCleanupFiltersIncludesRequestType(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + requestType := int16(RequestTypeWSV2) + desc := describeUsageCleanupFilters(UsageCleanupFilters{ + StartTime: start, + EndTime: end, + RequestType: &requestType, + }) + + require.Contains(t, desc, "request_type=ws_v2") +} + func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) { repo := &cleanupRepoStub{} cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index f9824183..c1a95541 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -1,12 +1,96 @@ package service -import "time" +import ( + "fmt" + "strings" + "time" +) const ( BillingTypeBalance int8 = 0 // 钱包余额 BillingTypeSubscription int8 = 1 // 订阅套餐 ) +type RequestType int16 + +const ( + RequestTypeUnknown RequestType = 0 + RequestTypeSync RequestType = 1 + RequestTypeStream RequestType = 2 + RequestTypeWSV2 RequestType 
= 3 +) + +func (t RequestType) IsValid() bool { + switch t { + case RequestTypeUnknown, RequestTypeSync, RequestTypeStream, RequestTypeWSV2: + return true + default: + return false + } +} + +func (t RequestType) Normalize() RequestType { + if t.IsValid() { + return t + } + return RequestTypeUnknown +} + +func (t RequestType) String() string { + switch t.Normalize() { + case RequestTypeSync: + return "sync" + case RequestTypeStream: + return "stream" + case RequestTypeWSV2: + return "ws_v2" + default: + return "unknown" + } +} + +func RequestTypeFromInt16(v int16) RequestType { + return RequestType(v).Normalize() +} + +func ParseUsageRequestType(value string) (RequestType, error) { + switch strings.ToLower(strings.TrimSpace(value)) { + case "unknown": + return RequestTypeUnknown, nil + case "sync": + return RequestTypeSync, nil + case "stream": + return RequestTypeStream, nil + case "ws_v2": + return RequestTypeWSV2, nil + default: + return RequestTypeUnknown, fmt.Errorf("invalid request_type, allowed values: unknown, sync, stream, ws_v2") + } +} + +func RequestTypeFromLegacy(stream bool, openAIWSMode bool) RequestType { + if openAIWSMode { + return RequestTypeWSV2 + } + if stream { + return RequestTypeStream + } + return RequestTypeSync +} + +func ApplyLegacyRequestFields(requestType RequestType, fallbackStream bool, fallbackOpenAIWSMode bool) (stream bool, openAIWSMode bool) { + switch requestType.Normalize() { + case RequestTypeSync: + return false, false + case RequestTypeStream: + return true, false + case RequestTypeWSV2: + return true, true + default: + return fallbackStream, fallbackOpenAIWSMode + } +} + type UsageLog struct { ID int64 UserID int64 @@ -40,7 +124,9 @@ type UsageLog struct { AccountRateMultiplier *float64 BillingType int8 + RequestType RequestType Stream bool + OpenAIWSMode bool DurationMs *int FirstTokenMs *int UserAgent *string @@ -66,3 +152,22 @@ type UsageLog struct { func (u *UsageLog) TotalTokens() int { return u.InputTokens + 
u.OutputTokens + u.CacheCreationTokens + u.CacheReadTokens } + +func (u *UsageLog) EffectiveRequestType() RequestType { + if u == nil { + return RequestTypeUnknown + } + if normalized := u.RequestType.Normalize(); normalized != RequestTypeUnknown { + return normalized + } + return RequestTypeFromLegacy(u.Stream, u.OpenAIWSMode) +} + +func (u *UsageLog) SyncRequestTypeAndLegacyFields() { + if u == nil { + return + } + requestType := u.EffectiveRequestType() + u.RequestType = requestType + u.Stream, u.OpenAIWSMode = ApplyLegacyRequestFields(requestType, u.Stream, u.OpenAIWSMode) +} diff --git a/backend/internal/service/usage_log_test.go b/backend/internal/service/usage_log_test.go new file mode 100644 index 00000000..280237c2 --- /dev/null +++ b/backend/internal/service/usage_log_test.go @@ -0,0 +1,112 @@ +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseUsageRequestType(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + input string + want RequestType + wantErr bool + } + + cases := []testCase{ + {name: "unknown", input: "unknown", want: RequestTypeUnknown}, + {name: "sync", input: "sync", want: RequestTypeSync}, + {name: "stream", input: "stream", want: RequestTypeStream}, + {name: "ws_v2", input: "ws_v2", want: RequestTypeWSV2}, + {name: "case_insensitive", input: "WS_V2", want: RequestTypeWSV2}, + {name: "trim_spaces", input: " stream ", want: RequestTypeStream}, + {name: "invalid", input: "xxx", wantErr: true}, + } + + for _, tc := range cases { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got, err := ParseUsageRequestType(tc.input) + if tc.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestRequestTypeNormalizeAndString(t *testing.T) { + t.Parallel() + + require.Equal(t, RequestTypeUnknown, RequestType(99).Normalize()) + require.Equal(t, "unknown", RequestType(99).String()) + 
require.Equal(t, "sync", RequestTypeSync.String()) + require.Equal(t, "stream", RequestTypeStream.String()) + require.Equal(t, "ws_v2", RequestTypeWSV2.String()) +} + +func TestRequestTypeFromLegacy(t *testing.T) { + t.Parallel() + + require.Equal(t, RequestTypeWSV2, RequestTypeFromLegacy(false, true)) + require.Equal(t, RequestTypeStream, RequestTypeFromLegacy(true, false)) + require.Equal(t, RequestTypeSync, RequestTypeFromLegacy(false, false)) +} + +func TestApplyLegacyRequestFields(t *testing.T) { + t.Parallel() + + stream, ws := ApplyLegacyRequestFields(RequestTypeSync, true, true) + require.False(t, stream) + require.False(t, ws) + + stream, ws = ApplyLegacyRequestFields(RequestTypeStream, false, true) + require.True(t, stream) + require.False(t, ws) + + stream, ws = ApplyLegacyRequestFields(RequestTypeWSV2, false, false) + require.True(t, stream) + require.True(t, ws) + + stream, ws = ApplyLegacyRequestFields(RequestTypeUnknown, true, false) + require.True(t, stream) + require.False(t, ws) +} + +func TestUsageLogSyncRequestTypeAndLegacyFields(t *testing.T) { + t.Parallel() + + log := &UsageLog{RequestType: RequestTypeWSV2, Stream: false, OpenAIWSMode: false} + log.SyncRequestTypeAndLegacyFields() + + require.Equal(t, RequestTypeWSV2, log.RequestType) + require.True(t, log.Stream) + require.True(t, log.OpenAIWSMode) +} + +func TestUsageLogEffectiveRequestTypeFallback(t *testing.T) { + t.Parallel() + + log := &UsageLog{RequestType: RequestTypeUnknown, Stream: true, OpenAIWSMode: true} + require.Equal(t, RequestTypeWSV2, log.EffectiveRequestType()) +} + +func TestUsageLogEffectiveRequestTypeNilReceiver(t *testing.T) { + t.Parallel() + + var log *UsageLog + require.Equal(t, RequestTypeUnknown, log.EffectiveRequestType()) +} + +func TestUsageLogSyncRequestTypeAndLegacyFieldsNilReceiver(t *testing.T) { + t.Parallel() + + var log *UsageLog + log.SyncRequestTypeAndLegacyFields() +} diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go 
index e56d83bf..487f12da 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -25,6 +25,10 @@ type User struct { // map[groupID]rateMultiplier GroupRates map[int64]float64 + // Sora 存储配额 + SoraStorageQuotaBytes int64 // 用户级 Sora 存储配额(0 表示使用分组或系统默认值) + SoraStorageUsedBytes int64 // Sora 存储已用量 + // TOTP 双因素认证字段 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpEnabled bool // 是否启用 TOTP diff --git a/backend/internal/service/user_service_test.go b/backend/internal/service/user_service_test.go index 0f355d70..7c3c984f 100644 --- a/backend/internal/service/user_service_test.go +++ b/backend/internal/service/user_service_test.go @@ -21,12 +21,12 @@ type mockUserRepo struct { updateBalanceFn func(ctx context.Context, id int64, amount float64) error } -func (m *mockUserRepo) Create(context.Context, *User) error { return nil } -func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } -func (m *mockUserRepo) Update(context.Context, *User) error { return nil } -func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } +func (m *mockUserRepo) Create(context.Context, *User) error { return nil } +func (m *mockUserRepo) GetByID(context.Context, int64) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetByEmail(context.Context, string) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) GetFirstAdmin(context.Context) (*User, error) { return &User{}, nil } +func (m *mockUserRepo) Update(context.Context, *User) error { return nil } +func (m *mockUserRepo) Delete(context.Context, int64) error { return nil } func (m *mockUserRepo) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) { return nil, nil, nil } @@ -56,8 +56,8 @@ type 
mockAuthCacheInvalidator struct { mu sync.Mutex } -func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {} -func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByKey(context.Context, string) {} +func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByGroupID(context.Context, int64) {} func (m *mockAuthCacheInvalidator) InvalidateAuthCacheByUserID(_ context.Context, userID int64) { m.mu.Lock() defer m.mu.Unlock() @@ -73,9 +73,9 @@ type mockBillingCache struct { mu sync.Mutex } -func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil } -func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil } -func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) GetUserBalance(context.Context, int64) (float64, error) { return 0, nil } +func (m *mockBillingCache) SetUserBalance(context.Context, int64, float64) error { return nil } +func (m *mockBillingCache) DeductUserBalance(context.Context, int64, float64) error { return nil } func (m *mockBillingCache) InvalidateUserBalance(_ context.Context, userID int64) error { m.invalidateCallCount.Add(1) m.mu.Lock() diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index f04acc00..68deace9 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -327,6 +327,7 @@ var ProviderSet = wire.NewSet( NewAccountUsageService, NewAccountTestService, NewSettingService, + NewDataManagementService, ProvideOpsSystemLogSink, NewOpsService, ProvideOpsMetricsCollector, diff --git a/backend/internal/testutil/stubs.go b/backend/internal/testutil/stubs.go index 3569db17..217a5f56 100644 --- a/backend/internal/testutil/stubs.go +++ b/backend/internal/testutil/stubs.go @@ -66,6 +66,13 @@ func (c StubConcurrencyCache) 
GetUsersLoadBatch(_ context.Context, users []servi } return result, nil } +func (c StubConcurrencyCache) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) { + result := make(map[int64]int, len(accountIDs)) + for _, id := range accountIDs { + result[id] = 0 + } + return result, nil +} func (c StubConcurrencyCache) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return nil } diff --git a/backend/internal/util/logredact/redact.go b/backend/internal/util/logredact/redact.go index 492d875c..9249b761 100644 --- a/backend/internal/util/logredact/redact.go +++ b/backend/internal/util/logredact/redact.go @@ -3,7 +3,9 @@ package logredact import ( "encoding/json" "regexp" + "sort" "strings" + "sync" ) // maxRedactDepth 限制递归深度以防止栈溢出 @@ -31,9 +33,18 @@ var defaultSensitiveKeyList = []string{ "password", } +type textRedactPatterns struct { + reJSONLike *regexp.Regexp + reQueryLike *regexp.Regexp + rePlain *regexp.Regexp +} + var ( reGOCSPX = regexp.MustCompile(`GOCSPX-[0-9A-Za-z_-]{24,}`) reAIza = regexp.MustCompile(`AIza[0-9A-Za-z_-]{35}`) + + defaultTextRedactPatterns = compileTextRedactPatterns(nil) + extraTextPatternCache sync.Map // map[string]*textRedactPatterns ) func RedactMap(input map[string]any, extraKeys ...string) map[string]any { @@ -83,23 +94,71 @@ func RedactText(input string, extraKeys ...string) string { return RedactJSON(raw, extraKeys...) } - keyAlt := buildKeyAlternation(extraKeys) - // JSON-like: "access_token":"..." - reJSONLike := regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`) - // Query-like: access_token=... - reQueryLike := regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`) - // Plain: access_token: ... / access_token = ... 
- rePlain := regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`) + patterns := getTextRedactPatterns(extraKeys) out := input out = reGOCSPX.ReplaceAllString(out, "GOCSPX-***") out = reAIza.ReplaceAllString(out, "AIza***") - out = reJSONLike.ReplaceAllString(out, `$1***$3`) - out = reQueryLike.ReplaceAllString(out, `$1=***`) - out = rePlain.ReplaceAllString(out, `$1$2***`) + out = patterns.reJSONLike.ReplaceAllString(out, `$1***$3`) + out = patterns.reQueryLike.ReplaceAllString(out, `$1=***`) + out = patterns.rePlain.ReplaceAllString(out, `$1$2***`) return out } +func compileTextRedactPatterns(extraKeys []string) *textRedactPatterns { + keyAlt := buildKeyAlternation(extraKeys) + return &textRedactPatterns{ + // JSON-like: "access_token":"..." + reJSONLike: regexp.MustCompile(`(?i)("(?:` + keyAlt + `)"\s*:\s*")([^"]*)(")`), + // Query-like: access_token=... + reQueryLike: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))=([^&\s]+)`), + // Plain: access_token: ... / access_token = ... 
+ rePlain: regexp.MustCompile(`(?i)\b((?:` + keyAlt + `))\b(\s*[:=]\s*)([^,\s]+)`), + } +} + +func getTextRedactPatterns(extraKeys []string) *textRedactPatterns { + normalizedExtraKeys := normalizeAndSortExtraKeys(extraKeys) + if len(normalizedExtraKeys) == 0 { + return defaultTextRedactPatterns + } + + cacheKey := strings.Join(normalizedExtraKeys, ",") + if cached, ok := extraTextPatternCache.Load(cacheKey); ok { + if patterns, ok := cached.(*textRedactPatterns); ok { + return patterns + } + } + + compiled := compileTextRedactPatterns(normalizedExtraKeys) + actual, _ := extraTextPatternCache.LoadOrStore(cacheKey, compiled) + if patterns, ok := actual.(*textRedactPatterns); ok { + return patterns + } + return compiled +} + +func normalizeAndSortExtraKeys(extraKeys []string) []string { + if len(extraKeys) == 0 { + return nil + } + seen := make(map[string]struct{}, len(extraKeys)) + keys := make([]string, 0, len(extraKeys)) + for _, key := range extraKeys { + normalized := normalizeKey(key) + if normalized == "" { + continue + } + if _, ok := seen[normalized]; ok { + continue + } + seen[normalized] = struct{}{} + keys = append(keys, normalized) + } + sort.Strings(keys) + return keys +} + func buildKeyAlternation(extraKeys []string) string { seen := make(map[string]struct{}, len(defaultSensitiveKeyList)+len(extraKeys)) keys := make([]string, 0, len(defaultSensitiveKeyList)+len(extraKeys)) diff --git a/backend/internal/util/logredact/redact_test.go b/backend/internal/util/logredact/redact_test.go index 64a7b3cf..266db69d 100644 --- a/backend/internal/util/logredact/redact_test.go +++ b/backend/internal/util/logredact/redact_test.go @@ -37,3 +37,48 @@ func TestRedactText_GOCSPX(t *testing.T) { t.Fatalf("expected key redacted, got %q", out) } } + +func TestRedactText_ExtraKeyCacheUsesNormalizedSortedKey(t *testing.T) { + clearExtraTextPatternCache() + + out1 := RedactText("custom_secret=abc", "Custom_Secret", " custom_secret ") + out2 := RedactText("custom_secret=xyz", 
"custom_secret") + if !strings.Contains(out1, "custom_secret=***") { + t.Fatalf("expected custom key redacted in first call, got %q", out1) + } + if !strings.Contains(out2, "custom_secret=***") { + t.Fatalf("expected custom key redacted in second call, got %q", out2) + } + + if got := countExtraTextPatternCacheEntries(); got != 1 { + t.Fatalf("expected 1 cached pattern set, got %d", got) + } +} + +func TestRedactText_DefaultPathDoesNotUseExtraCache(t *testing.T) { + clearExtraTextPatternCache() + + out := RedactText("access_token=abc") + if !strings.Contains(out, "access_token=***") { + t.Fatalf("expected default key redacted, got %q", out) + } + if got := countExtraTextPatternCacheEntries(); got != 0 { + t.Fatalf("expected extra cache to remain empty, got %d", got) + } +} + +func clearExtraTextPatternCache() { + extraTextPatternCache.Range(func(key, value any) bool { + extraTextPatternCache.Delete(key) + return true + }) +} + +func countExtraTextPatternCacheEntries() int { + count := 0 + extraTextPatternCache.Range(func(key, value any) bool { + count++ + return true + }) + return count +} diff --git a/backend/internal/util/responseheaders/responseheaders.go b/backend/internal/util/responseheaders/responseheaders.go index 86c3f624..7f7baca6 100644 --- a/backend/internal/util/responseheaders/responseheaders.go +++ b/backend/internal/util/responseheaders/responseheaders.go @@ -41,7 +41,14 @@ var hopByHopHeaders = map[string]struct{}{ "connection": {}, } -func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header { +type CompiledHeaderFilter struct { + allowed map[string]struct{} + forceRemove map[string]struct{} +} + +var defaultCompiledHeaderFilter = CompileHeaderFilter(config.ResponseHeaderConfig{}) + +func CompileHeaderFilter(cfg config.ResponseHeaderConfig) *CompiledHeaderFilter { allowed := make(map[string]struct{}, len(defaultAllowed)+len(cfg.AdditionalAllowed)) for key := range defaultAllowed { allowed[key] = struct{}{} @@ -69,13 +76,24 
@@ func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header } } + return &CompiledHeaderFilter{ + allowed: allowed, + forceRemove: forceRemove, + } +} + +func FilterHeaders(src http.Header, filter *CompiledHeaderFilter) http.Header { + if filter == nil { + filter = defaultCompiledHeaderFilter + } + filtered := make(http.Header, len(src)) for key, values := range src { lower := strings.ToLower(key) - if _, blocked := forceRemove[lower]; blocked { + if _, blocked := filter.forceRemove[lower]; blocked { continue } - if _, ok := allowed[lower]; !ok { + if _, ok := filter.allowed[lower]; !ok { continue } // 跳过 hop-by-hop 头部,这些由 HTTP 库自动处理 @@ -89,8 +107,8 @@ func FilterHeaders(src http.Header, cfg config.ResponseHeaderConfig) http.Header return filtered } -func WriteFilteredHeaders(dst http.Header, src http.Header, cfg config.ResponseHeaderConfig) { - filtered := FilterHeaders(src, cfg) +func WriteFilteredHeaders(dst http.Header, src http.Header, filter *CompiledHeaderFilter) { + filtered := FilterHeaders(src, filter) for key, values := range filtered { for _, value := range values { dst.Add(key, value) diff --git a/backend/internal/util/responseheaders/responseheaders_test.go b/backend/internal/util/responseheaders/responseheaders_test.go index f7343267..d817559e 100644 --- a/backend/internal/util/responseheaders/responseheaders_test.go +++ b/backend/internal/util/responseheaders/responseheaders_test.go @@ -20,7 +20,7 @@ func TestFilterHeadersDisabledUsesDefaultAllowlist(t *testing.T) { ForceRemove: []string{"x-request-id"}, } - filtered := FilterHeaders(src, cfg) + filtered := FilterHeaders(src, CompileHeaderFilter(cfg)) if filtered.Get("Content-Type") != "application/json" { t.Fatalf("expected Content-Type passthrough, got %q", filtered.Get("Content-Type")) } @@ -51,7 +51,7 @@ func TestFilterHeadersEnabledUsesAllowlist(t *testing.T) { ForceRemove: []string{"x-remove"}, } - filtered := FilterHeaders(src, cfg) + filtered := FilterHeaders(src, 
CompileHeaderFilter(cfg)) if filtered.Get("Content-Type") != "application/json" { t.Fatalf("expected Content-Type allowed, got %q", filtered.Get("Content-Type")) } diff --git a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql index de9d5776..93af0da7 100644 --- a/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql +++ b/backend/migrations/060_add_gemini31_flash_image_to_model_mapping.sql @@ -1,45 +1,36 @@ --- Add gemini-3.1-flash-image and gemini-3.1-flash-image-preview to model_mapping +-- Add gemini-3.1-flash-image mapping keys without wiping existing custom mappings. -- -- Background: --- Antigravity now supports gemini-3.1-flash-image as the latest image generation model, --- replacing the previous gemini-3-pro-image. +-- Antigravity now supports gemini-3.1-flash-image as the latest image generation model. +-- Existing accounts may still contain gemini-3-pro-image aliases. -- -- Strategy: --- Directly overwrite the entire model_mapping with updated mappings --- This ensures consistency with DefaultAntigravityModelMapping in constants.go +-- Incrementally upsert only image-related keys in credentials.model_mapping: +-- 1) add canonical 3.1 image keys +-- 2) keep legacy 3-pro-image keys but remap them to 3.1 image for compatibility +-- This preserves user custom mappings and avoids full mapping overwrite. 
UPDATE accounts SET credentials = jsonb_set( - credentials, - '{model_mapping}', - '{ - "claude-opus-4-6-thinking": "claude-opus-4-6-thinking", - "claude-opus-4-6": "claude-opus-4-6-thinking", - "claude-opus-4-5-thinking": "claude-opus-4-6-thinking", - "claude-opus-4-5-20251101": "claude-opus-4-6-thinking", - "claude-sonnet-4-6": "claude-sonnet-4-6", - "claude-sonnet-4-5": "claude-sonnet-4-5", - "claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking", - "claude-sonnet-4-5-20250929": "claude-sonnet-4-5", - "claude-haiku-4-5": "claude-sonnet-4-5", - "claude-haiku-4-5-20251001": "claude-sonnet-4-5", - "gemini-2.5-flash": "gemini-2.5-flash", - "gemini-2.5-flash-lite": "gemini-2.5-flash-lite", - "gemini-2.5-flash-thinking": "gemini-2.5-flash-thinking", - "gemini-2.5-pro": "gemini-2.5-pro", - "gemini-3-flash": "gemini-3-flash", - "gemini-3-pro-high": "gemini-3-pro-high", - "gemini-3-pro-low": "gemini-3-pro-low", - "gemini-3-flash-preview": "gemini-3-flash", - "gemini-3-pro-preview": "gemini-3-pro-high", - "gemini-3.1-pro-high": "gemini-3.1-pro-high", - "gemini-3.1-pro-low": "gemini-3.1-pro-low", - "gemini-3.1-pro-preview": "gemini-3.1-pro-high", - "gemini-3.1-flash-image": "gemini-3.1-flash-image", - "gemini-3.1-flash-image-preview": "gemini-3.1-flash-image", - "gpt-oss-120b-medium": "gpt-oss-120b-medium", - "tab_flash_lite_preview": "tab_flash_lite_preview" - }'::jsonb + jsonb_set( + jsonb_set( + jsonb_set( + credentials, + '{model_mapping,gemini-3.1-flash-image}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3.1-flash-image-preview}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3-pro-image}', + '"gemini-3.1-flash-image"'::jsonb, + true + ), + '{model_mapping,gemini-3-pro-image-preview}', + '"gemini-3.1-flash-image"'::jsonb, + true ) WHERE platform = 'antigravity' AND deleted_at IS NULL diff --git a/backend/migrations/060_add_usage_log_openai_ws_mode.sql 
b/backend/migrations/060_add_usage_log_openai_ws_mode.sql new file mode 100644 index 00000000..b7d22414 --- /dev/null +++ b/backend/migrations/060_add_usage_log_openai_ws_mode.sql @@ -0,0 +1,2 @@ +-- Add openai_ws_mode flag to usage_logs to persist exact OpenAI WS transport type. +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS openai_ws_mode BOOLEAN NOT NULL DEFAULT FALSE; diff --git a/backend/migrations/061_add_usage_log_request_type.sql b/backend/migrations/061_add_usage_log_request_type.sql new file mode 100644 index 00000000..68a33d51 --- /dev/null +++ b/backend/migrations/061_add_usage_log_request_type.sql @@ -0,0 +1,29 @@ +-- Add request_type enum for usage_logs while keeping legacy stream/openai_ws_mode compatibility. +ALTER TABLE usage_logs + ADD COLUMN IF NOT EXISTS request_type SMALLINT NOT NULL DEFAULT 0; + +DO $$ +BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_constraint + WHERE conname = 'usage_logs_request_type_check' + ) THEN + ALTER TABLE usage_logs + ADD CONSTRAINT usage_logs_request_type_check + CHECK (request_type IN (0, 1, 2, 3)); + END IF; +END +$$; + +CREATE INDEX IF NOT EXISTS idx_usage_logs_request_type_created_at + ON usage_logs (request_type, created_at); + +-- Backfill from legacy fields. openai_ws_mode has higher priority than stream. 
+UPDATE usage_logs +SET request_type = CASE + WHEN openai_ws_mode = TRUE THEN 3 + WHEN stream = TRUE THEN 2 + ELSE 1 +END +WHERE request_type = 0; diff --git a/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql b/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql new file mode 100644 index 00000000..c6139338 --- /dev/null +++ b/backend/migrations/062_add_scheduler_and_usage_composite_indexes_notx.sql @@ -0,0 +1,15 @@ +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_accounts_schedulable_hot + ON accounts (platform, priority) + WHERE deleted_at IS NULL AND status = 'active' AND schedulable = true; + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_accounts_active_schedulable + ON accounts (priority, status) + WHERE deleted_at IS NULL AND schedulable = true; + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_user_subscriptions_user_status_expires_active + ON user_subscriptions (user_id, status, expires_at) + WHERE deleted_at IS NULL; + +CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_usage_logs_group_created_at_not_null + ON usage_logs (group_id, created_at) + WHERE group_id IS NOT NULL; diff --git a/backend/migrations/063_add_sora_client_tables.sql b/backend/migrations/063_add_sora_client_tables.sql new file mode 100644 index 00000000..69197f10 --- /dev/null +++ b/backend/migrations/063_add_sora_client_tables.sql @@ -0,0 +1,56 @@ +-- Migration: 063_add_sora_client_tables +-- Sora 客户端功能所需的数据库变更: +-- 1. 新增 sora_generations 表:记录 Sora 客户端 UI 的生成历史 +-- 2. users 表新增存储配额字段 +-- 3. groups 表新增存储配额字段 + +-- ============================================================ +-- 1. 
sora_generations 表(生成记录) +-- ============================================================ +CREATE TABLE IF NOT EXISTS sora_generations ( + id BIGSERIAL PRIMARY KEY, + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + api_key_id BIGINT, + + -- 生成参数 + model VARCHAR(64) NOT NULL, + prompt TEXT NOT NULL DEFAULT '', + media_type VARCHAR(16) NOT NULL DEFAULT 'video', -- video / image + + -- 结果 + status VARCHAR(16) NOT NULL DEFAULT 'pending', -- pending / generating / completed / failed / cancelled + media_url TEXT NOT NULL DEFAULT '', + media_urls JSONB, -- 多图时的 URL 数组 + file_size_bytes BIGINT NOT NULL DEFAULT 0, + storage_type VARCHAR(16) NOT NULL DEFAULT 'none', -- s3 / local / upstream / none + s3_object_keys JSONB, -- S3 object key 数组 + + -- 上游信息 + upstream_task_id VARCHAR(128) NOT NULL DEFAULT '', + error_message TEXT NOT NULL DEFAULT '', + + -- 时间 + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + completed_at TIMESTAMPTZ +); + +-- 按用户+时间查询(作品库列表、历史记录) +CREATE INDEX IF NOT EXISTS idx_sora_gen_user_created + ON sora_generations(user_id, created_at DESC); + +-- 按用户+状态查询(恢复进行中任务) +CREATE INDEX IF NOT EXISTS idx_sora_gen_user_status + ON sora_generations(user_id, status); + +-- ============================================================ +-- 2. users 表新增 Sora 存储配额字段 +-- ============================================================ +ALTER TABLE users + ADD COLUMN IF NOT EXISTS sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0, + ADD COLUMN IF NOT EXISTS sora_storage_used_bytes BIGINT NOT NULL DEFAULT 0; + +-- ============================================================ +-- 3. 
groups 表新增 Sora 存储配额字段 +-- ============================================================ +ALTER TABLE groups + ADD COLUMN IF NOT EXISTS sora_storage_quota_bytes BIGINT NOT NULL DEFAULT 0; diff --git a/backend/migrations/README.md b/backend/migrations/README.md index 3fe328e6..47f6fa35 100644 --- a/backend/migrations/README.md +++ b/backend/migrations/README.md @@ -12,6 +12,26 @@ Format: `NNN_description.sql` Example: `017_add_gemini_tier_id.sql` +### `_notx.sql` 命名与执行语义(并发索引专用) + +当迁移包含 `CREATE INDEX CONCURRENTLY` 或 `DROP INDEX CONCURRENTLY` 时,必须使用 `_notx.sql` 后缀,例如: + +- `062_add_accounts_priority_indexes_notx.sql` +- `063_drop_legacy_indexes_notx.sql` + +运行规则: + +1. `*.sql`(不带 `_notx`)按事务执行。 +2. `*_notx.sql` 按非事务执行,不会包裹在 `BEGIN/COMMIT` 中。 +3. `*_notx.sql` 仅允许并发索引语句,不允许混入事务控制语句或其他 DDL/DML。 + +幂等要求(必须): + +- 创建索引:`CREATE INDEX CONCURRENTLY IF NOT EXISTS ...` +- 删除索引:`DROP INDEX CONCURRENTLY IF EXISTS ...` + +这样可以保证灾备重放、重复执行时不会因对象已存在/不存在而失败。 + ## Migration File Structure ```sql diff --git a/deploy/.env.example b/deploy/.env.example index 290f918a..9f2ff13e 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -66,11 +66,15 @@ LOG_SAMPLING_INITIAL=100 # 之后每 N 条保留 1 条 LOG_SAMPLING_THEREAFTER=100 -# Global max request body size in bytes (default: 100MB) -# 全局最大请求体大小(字节,默认 100MB) +# Global max request body size in bytes (default: 256MB) +# 全局最大请求体大小(字节,默认 256MB) # Applies to all requests, especially important for h2c first request memory protection # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 -SERVER_MAX_REQUEST_BODY_SIZE=104857600 +SERVER_MAX_REQUEST_BODY_SIZE=268435456 + +# Gateway max request body size in bytes (default: 256MB) +# 网关请求体最大字节数(默认 256MB) +GATEWAY_MAX_BODY_SIZE=268435456 # Enable HTTP/2 Cleartext (h2c) for client connections # 启用 HTTP/2 Cleartext (h2c) 客户端连接 diff --git a/deploy/DATAMANAGEMENTD_CN.md b/deploy/DATAMANAGEMENTD_CN.md new file mode 100644 index 00000000..774f03ae --- /dev/null +++ b/deploy/DATAMANAGEMENTD_CN.md @@ -0,0 +1,78 @@ +# datamanagementd 
部署说明(数据管理) + +本文说明如何在宿主机部署 `datamanagementd`,并与主进程联动开启“数据管理”功能。 + +## 1. 关键约束 + +- 主进程固定探测路径:`/tmp/sub2api-datamanagement.sock` +- 仅当该 Unix Socket 可连通且 `Health` 成功时,后台“数据管理”才会启用 +- `datamanagementd` 使用 SQLite 持久化元数据,不依赖主库 + +## 2. 宿主机构建与运行 + +```bash +cd /opt/sub2api-src/datamanagement +go build -o /opt/sub2api/datamanagementd ./cmd/datamanagementd + +mkdir -p /var/lib/sub2api/datamanagement +chown -R sub2api:sub2api /var/lib/sub2api/datamanagement +``` + +手动启动示例: + +```bash +/opt/sub2api/datamanagementd \ + -socket-path /tmp/sub2api-datamanagement.sock \ + -sqlite-path /var/lib/sub2api/datamanagement/datamanagementd.db \ + -version 1.0.0 +``` + +## 3. systemd 托管(推荐) + +仓库已提供示例服务文件:`deploy/sub2api-datamanagementd.service` + +```bash +sudo cp deploy/sub2api-datamanagementd.service /etc/systemd/system/ +sudo systemctl daemon-reload +sudo systemctl enable --now sub2api-datamanagementd +sudo systemctl status sub2api-datamanagementd +``` + +查看日志: + +```bash +sudo journalctl -u sub2api-datamanagementd -f +``` + +也可以使用一键安装脚本(自动安装二进制 + 注册 systemd): + +```bash +# 方式一:使用现成二进制 +sudo ./deploy/install-datamanagementd.sh --binary /path/to/datamanagementd + +# 方式二:从源码构建后安装 +sudo ./deploy/install-datamanagementd.sh --source /path/to/sub2api +``` + +## 4. Docker 部署联动 + +若 `sub2api` 运行在 Docker 容器中,需要将宿主机 Socket 挂载到容器同路径: + +```yaml +services: + sub2api: + volumes: + - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock +``` + +建议在 `docker-compose.override.yml` 中维护该挂载,避免覆盖主 compose 文件。 + +## 5. 依赖检查 + +`datamanagementd` 执行备份时依赖以下工具: + +- `pg_dump` +- `redis-cli` +- `docker`(仅 `source_mode=docker_exec` 时) + +缺失依赖会导致对应任务失败,并在任务详情中体现错误信息。 diff --git a/deploy/README.md b/deploy/README.md index 3292e81a..807bf510 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -19,7 +19,10 @@ This directory contains files for deploying Sub2API on Linux servers. 
| `.env.example` | Docker environment variables template | | `DOCKER.md` | Docker Hub documentation | | `install.sh` | One-click binary installation script | +| `install-datamanagementd.sh` | datamanagementd 一键安装脚本 | | `sub2api.service` | Systemd service unit file | +| `sub2api-datamanagementd.service` | datamanagementd systemd service unit file | +| `DATAMANAGEMENTD_CN.md` | datamanagementd 部署与联动说明(中文) | | `config.example.yaml` | Example configuration file | --- @@ -145,6 +148,14 @@ SELECT (SELECT COUNT(*) FROM user_allowed_groups) AS new_pair_count; ``` +### datamanagementd(数据管理)联动 + +如需启用管理后台“数据管理”功能,请额外部署宿主机 `datamanagementd`: + +- 主进程固定探测 `/tmp/sub2api-datamanagement.sock` +- Docker 场景下需把宿主机 Socket 挂载到容器内同路径 +- 详细步骤见:`deploy/DATAMANAGEMENTD_CN.md` + ### Commands For **local directory version** (docker-compose.local.yml): @@ -575,7 +586,7 @@ gateway: name: "Profile 2" cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] curves: [29, 23, 24] - point_formats: [0] + point_formats: [0] # Another custom profile profile_3: diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 46a91ad6..faa85854 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -27,11 +27,11 @@ server: # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. 
# 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] - # Global max request body size in bytes (default: 100MB) - # 全局最大请求体大小(字节,默认 100MB) + # Global max request body size in bytes (default: 256MB) + # 全局最大请求体大小(字节,默认 256MB) # Applies to all requests, especially important for h2c first request memory protection # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 - max_request_body_size: 104857600 + max_request_body_size: 268435456 # HTTP/2 Cleartext (h2c) configuration # HTTP/2 Cleartext (h2c) 配置 h2c: @@ -143,9 +143,9 @@ gateway: # Timeout for waiting upstream response headers (seconds) # 等待上游响应头超时时间(秒) response_header_timeout: 600 - # Max request body size in bytes (default: 100MB) - # 请求体最大字节数(默认 100MB) - max_body_size: 104857600 + # Max request body size in bytes (default: 256MB) + # 请求体最大字节数(默认 256MB) + max_body_size: 268435456 # Max bytes to read for non-stream upstream responses (default: 8MB) # 非流式上游响应体读取上限(默认 8MB) upstream_response_read_max_bytes: 8388608 @@ -199,6 +199,83 @@ gateway: # OpenAI 透传模式是否放行客户端超时头(如 x-stainless-timeout) # 默认 false:过滤超时头,降低上游提前断流风险。 openai_passthrough_allow_timeout_headers: false + # OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) + openai_ws: + # 新版 WS mode 路由(默认关闭)。关闭时保持当前 legacy 实现行为。 + mode_router_v2_enabled: false + # ingress 默认模式:off|shared|dedicated(仅 mode_router_v2_enabled=true 生效) + ingress_mode_default: shared + # 全局总开关,默认 true;关闭时所有请求保持原有 HTTP/SSE 路由 + enabled: true + # 按账号类型细分开关 + oauth_enabled: true + apikey_enabled: true + # 全局强制 HTTP(紧急回滚开关) + force_http: false + # 允许在 WSv2 下按策略恢复 store=true(默认 false) + allow_store_recovery: false + # ingress 模式收到 previous_response_not_found 时,自动去掉 previous_response_id 重试一次(默认 true) + ingress_previous_response_recovery_enabled: true + # store=false 且无可复用会话连接时的策略: + # strict=强制新建连接(隔离优先),adaptive=仅在高风险失败后强制新建,off=尽量复用(性能优先) + store_disabled_conn_mode: strict + # store=false 且无可复用会话连接时,是否强制新建连接(默认 true,优先会话隔离) + # 兼容旧配置:仅在 store_disabled_conn_mode 未配置时生效 + 
store_disabled_force_new_conn: true + # 是否启用 WSv2 generate=false 预热(默认 false) + prewarm_generate_enabled: false + # 协议 feature 开关,v2 优先于 v1 + responses_websockets: false + responses_websockets_v2: true + # 连接池参数(按账号池化复用) + max_conns_per_account: 128 + min_idle_per_account: 4 + max_idle_per_account: 12 + # 是否按账号并发动态计算连接池上限: + # effective_max_conns = min(max_conns_per_account, ceil(account.concurrency * factor)) + dynamic_max_conns_by_account_concurrency_enabled: true + # 按账号类型分别设置系数(OAuth / API Key) + oauth_max_conns_factor: 1.0 + apikey_max_conns_factor: 1.0 + dial_timeout_seconds: 10 + read_timeout_seconds: 900 + write_timeout_seconds: 120 + pool_target_utilization: 0.7 + queue_limit_per_conn: 64 + # 流式写出批量 flush 参数 + event_flush_batch_size: 1 + event_flush_interval_ms: 10 + # 预热触发冷却(毫秒) + prewarm_cooldown_ms: 300 + # WS 回退到 HTTP 后的冷却时间(秒),用于避免 WS/HTTP 来回抖动;0 表示关闭冷却 + fallback_cooldown_seconds: 30 + # WS 重试退避参数(毫秒) + retry_backoff_initial_ms: 120 + retry_backoff_max_ms: 2000 + # 抖动比例(0-1) + retry_jitter_ratio: 0.2 + # 单次请求 WS 重试总预算(毫秒);建议设置为有限值,避免重试拉高 TTFT 长尾 + retry_total_budget_ms: 5000 + # payload_schema 日志采样率(0-1);降低热路径日志放大 + payload_log_sample_rate: 0.2 + # 调度与粘连参数 + lb_top_k: 7 + sticky_session_ttl_seconds: 3600 + # 会话哈希迁移兼容开关:新 key 未命中时回退读取旧 SHA-256 key + session_hash_read_old_fallback: true + # 会话哈希迁移兼容开关:写入时双写旧 SHA-256 key(短 TTL) + session_hash_dual_write_old: true + # context 元数据迁移兼容开关:保留旧 ctxkey.* 读取/注入桥接 + metadata_bridge_enabled: true + sticky_response_id_ttl_seconds: 3600 + # 兼容旧键:当 sticky_response_id_ttl_seconds 缺失时回退该值 + sticky_previous_response_ttl_seconds: 3600 + scheduler_score_weights: + priority: 1.0 + load: 1.0 + queue: 0.7 + error_rate: 0.8 + ttft: 0.5 # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) # Max idle connections across all hosts @@ -779,12 +856,12 @@ rate_limit: # 定价数据源(可选) # ============================================================================= pricing: - # 
URL to fetch model pricing data (default: LiteLLM) - # 获取模型定价数据的 URL(默认:LiteLLM) - remote_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.json" + # URL to fetch model pricing data (default: pinned model-price-repo commit) + # 获取模型定价数据的 URL(默认:固定 commit 的 model-price-repo) + remote_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.json" # Hash verification URL (optional) # 哈希校验 URL(可选) - hash_url: "https://github.com/Wei-Shaw/model-price-repo/raw/refs/heads/main/model_prices_and_context_window.sha256" + hash_url: "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/c7947e9871687e664180bc971d4837f1fc2784a9/model_prices_and_context_window.sha256" # Local data directory for caching # 本地数据缓存目录 data_dir: "./data" diff --git a/deploy/docker-compose.override.yml.example b/deploy/docker-compose.override.yml.example index 297724f5..7157f212 100644 --- a/deploy/docker-compose.override.yml.example +++ b/deploy/docker-compose.override.yml.example @@ -127,6 +127,19 @@ services: # - ./logs:/app/logs # - ./backups:/app/backups +# ============================================================================= +# Scenario 6: 启用宿主机 datamanagementd(数据管理) +# ============================================================================= +# 说明: +# - datamanagementd 运行在宿主机(systemd 或手动) +# - 主进程固定探测 /tmp/sub2api-datamanagement.sock +# - 需要把宿主机 socket 挂载到容器内同路径 +# +# services: +# sub2api: +# volumes: +# - /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock + # ============================================================================= # Additional Notes # ============================================================================= diff --git a/deploy/install-datamanagementd.sh b/deploy/install-datamanagementd.sh new file mode 100755 index 00000000..8d53134b --- /dev/null +++ b/deploy/install-datamanagementd.sh @@ -0,0 +1,123 @@ 
+#!/usr/bin/env bash + +set -euo pipefail + +# 用法: +# sudo ./install-datamanagementd.sh --binary /path/to/datamanagementd +# 或: +# sudo ./install-datamanagementd.sh --source /path/to/sub2api/repo + +BIN_PATH="" +SOURCE_PATH="" +INSTALL_DIR="/opt/sub2api" +DATA_DIR="/var/lib/sub2api/datamanagement" +SERVICE_FILE_NAME="sub2api-datamanagementd.service" + +function print_help() { + cat <<'EOF' +用法: + install-datamanagementd.sh [--binary ] [--source <仓库路径>] + +参数: + --binary 指定已构建的 datamanagementd 二进制路径 + --source 指定 sub2api 仓库路径(脚本会执行 go build) + -h, --help 显示帮助 + +示例: + sudo ./install-datamanagementd.sh --binary ./datamanagement/datamanagementd + sudo ./install-datamanagementd.sh --source /opt/sub2api-src +EOF +} + +while [[ $# -gt 0 ]]; do + case "$1" in + --binary) + BIN_PATH="${2:-}" + shift 2 + ;; + --source) + SOURCE_PATH="${2:-}" + shift 2 + ;; + -h|--help) + print_help + exit 0 + ;; + *) + echo "未知参数: $1" + print_help + exit 1 + ;; + esac +done + +if [[ -n "$BIN_PATH" && -n "$SOURCE_PATH" ]]; then + echo "错误: --binary 与 --source 只能二选一" + exit 1 +fi + +if [[ -z "$BIN_PATH" && -z "$SOURCE_PATH" ]]; then + echo "错误: 必须提供 --binary 或 --source" + exit 1 +fi + +if [[ "$(id -u)" -ne 0 ]]; then + echo "错误: 请使用 root 权限执行(例如 sudo)" + exit 1 +fi + +if [[ -n "$SOURCE_PATH" ]]; then + if [[ ! -d "$SOURCE_PATH/datamanagement" ]]; then + echo "错误: 无效仓库路径,未找到 $SOURCE_PATH/datamanagement" + exit 1 + fi + echo "[1/6] 从源码构建 datamanagementd..." + (cd "$SOURCE_PATH/datamanagement" && go build -o datamanagementd ./cmd/datamanagementd) + BIN_PATH="$SOURCE_PATH/datamanagement/datamanagementd" +fi + +if [[ ! -f "$BIN_PATH" ]]; then + echo "错误: 二进制文件不存在: $BIN_PATH" + exit 1 +fi + +if ! id sub2api >/dev/null 2>&1; then + echo "[2/6] 创建系统用户 sub2api..." + useradd --system --no-create-home --shell /usr/sbin/nologin sub2api +else + echo "[2/6] 系统用户 sub2api 已存在,跳过创建" +fi + +echo "[3/6] 安装 datamanagementd 二进制..." 
+mkdir -p "$INSTALL_DIR" +install -m 0755 "$BIN_PATH" "$INSTALL_DIR/datamanagementd" + +echo "[4/6] 准备数据目录..." +mkdir -p "$DATA_DIR" +chown -R sub2api:sub2api /var/lib/sub2api +chmod 0750 "$DATA_DIR" + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +SERVICE_TEMPLATE="$SCRIPT_DIR/$SERVICE_FILE_NAME" +if [[ ! -f "$SERVICE_TEMPLATE" ]]; then + echo "错误: 未找到服务模板 $SERVICE_TEMPLATE" + exit 1 +fi + +echo "[5/6] 安装 systemd 服务..." +cp "$SERVICE_TEMPLATE" "/etc/systemd/system/$SERVICE_FILE_NAME" +systemctl daemon-reload +systemctl enable --now sub2api-datamanagementd + +echo "[6/6] 完成,当前状态:" +systemctl --no-pager --full status sub2api-datamanagementd || true + +cat <<'EOF' + +下一步建议: +1. 查看日志:sudo journalctl -u sub2api-datamanagementd -f +2. 在 sub2api(容器部署时)挂载 socket: + /tmp/sub2api-datamanagement.sock:/tmp/sub2api-datamanagement.sock +3. 进入管理后台“数据管理”页面确认 agent=enabled + +EOF diff --git a/deploy/sub2api-datamanagementd.service b/deploy/sub2api-datamanagementd.service new file mode 100644 index 00000000..b32733b7 --- /dev/null +++ b/deploy/sub2api-datamanagementd.service @@ -0,0 +1,22 @@ +[Unit] +Description=Sub2API Data Management Daemon +After=network.target +Wants=network.target + +[Service] +Type=simple +User=sub2api +Group=sub2api +WorkingDirectory=/opt/sub2api +ExecStart=/opt/sub2api/datamanagementd \ + -socket-path /tmp/sub2api-datamanagement.sock \ + -sqlite-path /var/lib/sub2api/datamanagement/datamanagementd.db \ + -version 1.0.0 +Restart=always +RestartSec=5s +LimitNOFILE=100000 +NoNewPrivileges=true +PrivateTmp=false + +[Install] +WantedBy=multi-user.target diff --git a/frontend/src/api/__tests__/sora.spec.ts b/frontend/src/api/__tests__/sora.spec.ts new file mode 100644 index 00000000..88c0c416 --- /dev/null +++ b/frontend/src/api/__tests__/sora.spec.ts @@ -0,0 +1,80 @@ +import { describe, expect, it } from 'vitest' +import { + normalizeGenerationListResponse, + normalizeModelFamiliesResponse +} from '../sora' + +describe('sora api normalizers', () => 
{ + it('normalizes generation list from data shape', () => { + const result = normalizeGenerationListResponse({ + data: [{ id: 1, status: 'pending' }], + total: 9, + page: 2 + }) + + expect(result.data).toHaveLength(1) + expect(result.total).toBe(9) + expect(result.page).toBe(2) + }) + + it('normalizes generation list from items shape', () => { + const result = normalizeGenerationListResponse({ + items: [{ id: 1, status: 'completed' }], + total: 1 + }) + + expect(result.data).toHaveLength(1) + expect(result.total).toBe(1) + expect(result.page).toBe(1) + }) + + it('falls back to empty generation list on invalid payload', () => { + const result = normalizeGenerationListResponse(null) + expect(result).toEqual({ data: [], total: 0, page: 1 }) + }) + + it('normalizes family model payload', () => { + const result = normalizeModelFamiliesResponse({ + data: [ + { + id: 'sora2', + name: 'Sora 2', + type: 'video', + orientations: ['landscape', 'portrait'], + durations: [10, 15] + } + ] + }) + + expect(result).toHaveLength(1) + expect(result[0].id).toBe('sora2') + expect(result[0].orientations).toEqual(['landscape', 'portrait']) + expect(result[0].durations).toEqual([10, 15]) + }) + + it('normalizes legacy flat model list into families', () => { + const result = normalizeModelFamiliesResponse({ + items: [ + { id: 'sora2-landscape-10s', type: 'video' }, + { id: 'sora2-portrait-15s', type: 'video' }, + { id: 'gpt-image-square', type: 'image' } + ] + }) + + const sora2 = result.find((m) => m.id === 'sora2') + expect(sora2).toBeTruthy() + expect(sora2?.orientations).toEqual(['landscape', 'portrait']) + expect(sora2?.durations).toEqual([10, 15]) + + const image = result.find((m) => m.id === 'gpt-image') + expect(image).toBeTruthy() + expect(image?.type).toBe('image') + expect(image?.orientations).toEqual(['square']) + }) + + it('falls back to empty families on invalid payload', () => { + expect(normalizeModelFamiliesResponse(undefined)).toEqual([]) + 
expect(normalizeModelFamiliesResponse({})).toEqual([]) + }) +}) + diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 1b8ae9ad..56571699 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -369,6 +369,22 @@ export async function getTodayStats(id: number): Promise { return data } +export interface BatchTodayStatsResponse { + stats: Record +} + +/** + * 批量获取多个账号的今日统计 + * @param accountIds - 账号 ID 列表 + * @returns 以账号 ID(字符串)为键的统计映射 + */ +export async function getBatchTodayStats(accountIds: number[]): Promise { + const { data } = await apiClient.post('/admin/accounts/today-stats/batch', { + account_ids: accountIds + }) + return data +} + /** * Set account schedulable status * @param id - Account ID @@ -556,6 +572,7 @@ export const accountsAPI = { clearError, getUsage, getTodayStats, + getBatchTodayStats, clearRateLimit, getTempUnschedulableStatus, resetTempUnschedulable, diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index ae48bec2..a5113dd1 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -9,7 +9,8 @@ import type { TrendDataPoint, ModelStat, ApiKeyUsageTrendPoint, - UserUsageTrendPoint + UserUsageTrendPoint, + UsageRequestType } from '@/types' /** @@ -49,6 +50,7 @@ export interface TrendParams { model?: string account_id?: number group_id?: number + request_type?: UsageRequestType stream?: boolean billing_type?: number | null } @@ -78,6 +80,7 @@ export interface ModelStatsParams { model?: string account_id?: number group_id?: number + request_type?: UsageRequestType stream?: boolean billing_type?: number | null } diff --git a/frontend/src/api/admin/dataManagement.ts b/frontend/src/api/admin/dataManagement.ts new file mode 100644 index 00000000..cec71446 --- /dev/null +++ b/frontend/src/api/admin/dataManagement.ts @@ -0,0 +1,332 @@ +import { apiClient } from '../client' + +export type BackupType = 'postgres' 
| 'redis' | 'full'

/*
 * Admin data-management API client: backup-agent health, Postgres/Redis
 * source profiles, S3 destination profiles and backup jobs.
 *
 * NOTE(review): reconstructed from a garbled patch in which HTML-tag-like
 * generic arguments (`Promise<T>`, `apiClient.get<T>(...)`) were stripped.
 * The restored type arguments below are inferred from the surrounding
 * declarations — confirm against the original file. `apiClient` is imported
 * at the (unseen) top of this module.
 */

export type BackupJobStatus = 'queued' | 'running' | 'succeeded' | 'failed' | 'partial_succeeded'

export interface BackupAgentInfo {
  status: string
  version: string
  uptime_seconds: number
}

export interface BackupAgentHealth {
  enabled: boolean
  reason: string
  socket_path: string
  agent?: BackupAgentInfo
}

export interface DataManagementPostgresConfig {
  host: string
  port: number
  user: string
  password?: string
  password_configured?: boolean
  database: string
  ssl_mode: string
  container_name: string
}

export interface DataManagementRedisConfig {
  addr: string
  username: string
  password?: string
  password_configured?: boolean
  db: number
  container_name: string
}

export interface DataManagementS3Config {
  enabled: boolean
  endpoint: string
  region: string
  bucket: string
  access_key_id: string
  secret_access_key?: string
  secret_access_key_configured?: boolean
  prefix: string
  force_path_style: boolean
  use_ssl: boolean
}

export interface DataManagementConfig {
  source_mode: 'direct' | 'docker_exec'
  backup_root: string
  sqlite_path?: string
  retention_days: number
  keep_last: number
  active_postgres_profile_id?: string
  active_redis_profile_id?: string
  active_s3_profile_id?: string
  postgres: DataManagementPostgresConfig
  redis: DataManagementRedisConfig
  s3: DataManagementS3Config
}

export type SourceType = 'postgres' | 'redis'

/** Union of Postgres and Redis connection fields; unused fields stay zero-valued. */
export interface DataManagementSourceConfig {
  host: string
  port: number
  user: string
  password?: string
  database: string
  ssl_mode: string
  addr: string
  username: string
  db: number
  container_name: string
}

export interface DataManagementSourceProfile {
  source_type: SourceType
  profile_id: string
  name: string
  is_active: boolean
  password_configured?: boolean
  config: DataManagementSourceConfig
  created_at?: string
  updated_at?: string
}

export interface TestS3Request {
  endpoint: string
  region: string
  bucket: string
  access_key_id: string
  secret_access_key: string
  prefix?: string
  force_path_style?: boolean
  use_ssl?: boolean
}

export interface TestS3Response {
  ok: boolean
  message: string
}

export interface CreateBackupJobRequest {
  backup_type: BackupType
  upload_to_s3?: boolean
  s3_profile_id?: string
  postgres_profile_id?: string
  redis_profile_id?: string
  idempotency_key?: string
}

export interface CreateBackupJobResponse {
  job_id: string
  status: BackupJobStatus
}

export interface BackupArtifactInfo {
  local_path: string
  size_bytes: number
  sha256: string
}

export interface BackupS3Info {
  bucket: string
  key: string
  etag: string
}

export interface BackupJob {
  job_id: string
  backup_type: BackupType
  status: BackupJobStatus
  triggered_by: string
  s3_profile_id?: string
  postgres_profile_id?: string
  redis_profile_id?: string
  started_at?: string
  finished_at?: string
  error_message?: string
  artifact?: BackupArtifactInfo
  s3?: BackupS3Info
}

export interface ListSourceProfilesResponse {
  items: DataManagementSourceProfile[]
}

export interface CreateSourceProfileRequest {
  profile_id: string
  name: string
  config: DataManagementSourceConfig
  set_active?: boolean
}

export interface UpdateSourceProfileRequest {
  name: string
  config: DataManagementSourceConfig
}

export interface DataManagementS3Profile {
  profile_id: string
  name: string
  is_active: boolean
  s3: DataManagementS3Config
  secret_access_key_configured?: boolean
  created_at?: string
  updated_at?: string
}

export interface ListS3ProfilesResponse {
  items: DataManagementS3Profile[]
}

export interface CreateS3ProfileRequest {
  profile_id: string
  name: string
  enabled: boolean
  endpoint: string
  region: string
  bucket: string
  access_key_id: string
  secret_access_key?: string
  prefix?: string
  force_path_style?: boolean
  use_ssl?: boolean
  set_active?: boolean
}

export interface UpdateS3ProfileRequest {
  name: string
  enabled: boolean
  endpoint: string
  region: string
  bucket: string
  access_key_id: string
  secret_access_key?: string
  prefix?: string
  force_path_style?: boolean
  use_ssl?: boolean
}

export interface ListBackupJobsRequest {
  page_size?: number
  page_token?: string
  status?: BackupJobStatus
  backup_type?: BackupType
}

export interface ListBackupJobsResponse {
  items: BackupJob[]
  next_page_token?: string
}

/** Health/status of the local backup agent. */
export async function getAgentHealth(): Promise<BackupAgentHealth> {
  const { data } = await apiClient.get<BackupAgentHealth>('/admin/data-management/agent/health')
  return data
}

/** Fetch the current data-management configuration. */
export async function getConfig(): Promise<DataManagementConfig> {
  const { data } = await apiClient.get<DataManagementConfig>('/admin/data-management/config')
  return data
}

/** Replace the data-management configuration; returns the persisted config. */
export async function updateConfig(request: DataManagementConfig): Promise<DataManagementConfig> {
  const { data } = await apiClient.put<DataManagementConfig>('/admin/data-management/config', request)
  return data
}

/** Probe an S3 endpoint with the supplied credentials. */
export async function testS3(request: TestS3Request): Promise<TestS3Response> {
  const { data } = await apiClient.post<TestS3Response>('/admin/data-management/s3/test', request)
  return data
}

/** List connection profiles for a given source type (postgres/redis). */
export async function listSourceProfiles(sourceType: SourceType): Promise<ListSourceProfilesResponse> {
  const { data } = await apiClient.get<ListSourceProfilesResponse>(`/admin/data-management/sources/${sourceType}/profiles`)
  return data
}

/** Create a source profile; returns the created profile. */
export async function createSourceProfile(sourceType: SourceType, request: CreateSourceProfileRequest): Promise<DataManagementSourceProfile> {
  const { data } = await apiClient.post<DataManagementSourceProfile>(`/admin/data-management/sources/${sourceType}/profiles`, request)
  return data
}

/** Update a source profile; returns the updated profile. */
export async function updateSourceProfile(sourceType: SourceType, profileID: string, request: UpdateSourceProfileRequest): Promise<DataManagementSourceProfile> {
  const { data } = await apiClient.put<DataManagementSourceProfile>(`/admin/data-management/sources/${sourceType}/profiles/${profileID}`, request)
  return data
}

/** Delete a source profile. */
export async function deleteSourceProfile(sourceType: SourceType, profileID: string): Promise<void> {
  await apiClient.delete(`/admin/data-management/sources/${sourceType}/profiles/${profileID}`)
}

/** Mark a source profile as the active one; returns the activated profile. */
export async function setActiveSourceProfile(sourceType: SourceType, profileID: string): Promise<DataManagementSourceProfile> {
  const { data } = await apiClient.post<DataManagementSourceProfile>(`/admin/data-management/sources/${sourceType}/profiles/${profileID}/activate`)
  return data
}

/** List S3 destination profiles. */
export async function listS3Profiles(): Promise<ListS3ProfilesResponse> {
  const { data } = await apiClient.get<ListS3ProfilesResponse>('/admin/data-management/s3/profiles')
  return data
}

/** Create an S3 profile; returns the created profile. */
export async function createS3Profile(request: CreateS3ProfileRequest): Promise<DataManagementS3Profile> {
  const { data } = await apiClient.post<DataManagementS3Profile>('/admin/data-management/s3/profiles', request)
  return data
}

/** Update an S3 profile; returns the updated profile. */
export async function updateS3Profile(profileID: string, request: UpdateS3ProfileRequest): Promise<DataManagementS3Profile> {
  const { data } = await apiClient.put<DataManagementS3Profile>(`/admin/data-management/s3/profiles/${profileID}`, request)
  return data
}

/** Delete an S3 profile. */
export async function deleteS3Profile(profileID: string): Promise<void> {
  await apiClient.delete(`/admin/data-management/s3/profiles/${profileID}`)
}

/** Mark an S3 profile as the active one; returns the activated profile. */
export async function setActiveS3Profile(profileID: string): Promise<DataManagementS3Profile> {
  const { data } = await apiClient.post<DataManagementS3Profile>(`/admin/data-management/s3/profiles/${profileID}/activate`)
  return data
}

/**
 * Enqueue a backup job. When an idempotency key is supplied it is forwarded
 * in the X-Idempotency-Key header so retries do not create duplicate jobs.
 */
export async function createBackupJob(request: CreateBackupJobRequest): Promise<CreateBackupJobResponse> {
  const headers = request.idempotency_key
    ? { 'X-Idempotency-Key': request.idempotency_key }
    : undefined

  const { data } = await apiClient.post<CreateBackupJobResponse>(
    '/admin/data-management/backups',
    request,
    { headers }
  )
  return data
}

/** Page through backup jobs, optionally filtered by status / backup type. */
export async function listBackupJobs(request?: ListBackupJobsRequest): Promise<ListBackupJobsResponse> {
  const { data } = await apiClient.get<ListBackupJobsResponse>('/admin/data-management/backups', {
    params: request
  })
  return data
}

/** Fetch a single backup job by id. */
export async function getBackupJob(jobID: string): Promise<BackupJob> {
  const { data } = await apiClient.get<BackupJob>(`/admin/data-management/backups/${jobID}`)
  return data
}

export const dataManagementAPI = {
  getAgentHealth,
  getConfig,
  updateConfig,
  listSourceProfiles,
  createSourceProfile,
  updateSourceProfile,
  deleteSourceProfile,
  setActiveSourceProfile,
  testS3,
  listS3Profiles,
  createS3Profile,
  updateS3Profile,
  deleteS3Profile,
  setActiveS3Profile,
  createBackupJob,
  listBackupJobs,
  getBackupJob
}

export default dataManagementAPI

/*
 * ---- frontend/src/api/admin/index.ts (same patch, partial hunk) ----
 * The patch also registers this module on the aggregate admin API:
 *   import dataManagementAPI from './dataManagement'
 *   adminAPI = { ..., dataManagement: dataManagementAPI }
 *   export { ..., dataManagementAPI }
 * and begins a type re-export statement that continues in the next chunk:
 *   export type { ...
 */
ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' +export type { BackupAgentHealth, DataManagementConfig } from './dataManagement' diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 3dc76fe7..858dd147 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -31,6 +31,7 @@ export interface SystemSettings { hide_ccs_import_button: boolean purchase_subscription_enabled: boolean purchase_subscription_url: string + sora_client_enabled: boolean // SMTP settings smtp_host: string smtp_port: number @@ -87,6 +88,7 @@ export interface UpdateSettingsRequest { hide_ccs_import_button?: boolean purchase_subscription_enabled?: boolean purchase_subscription_url?: string + sora_client_enabled?: boolean smtp_host?: string smtp_port?: number smtp_username?: string @@ -251,6 +253,142 @@ export async function updateStreamTimeoutSettings( return data } +// ==================== Sora S3 Settings ==================== + +export interface SoraS3Settings { + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key_configured: boolean + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface SoraS3Profile { + profile_id: string + name: string + is_active: boolean + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key_configured: boolean + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number + updated_at: string +} + +export interface ListSoraS3ProfilesResponse { + active_profile_id: string + items: SoraS3Profile[] +} + +export interface UpdateSoraS3SettingsRequest { + profile_id?: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + 
cdn_url: string + default_storage_quota_bytes: number +} + +export interface CreateSoraS3ProfileRequest { + profile_id: string + name: string + set_active?: boolean + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface UpdateSoraS3ProfileRequest { + name: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes: number +} + +export interface TestSoraS3ConnectionRequest { + profile_id?: string + enabled: boolean + endpoint: string + region: string + bucket: string + access_key_id: string + secret_access_key?: string + prefix: string + force_path_style: boolean + cdn_url: string + default_storage_quota_bytes?: number +} + +export async function getSoraS3Settings(): Promise { + const { data } = await apiClient.get('/admin/settings/sora-s3') + return data +} + +export async function updateSoraS3Settings(settings: UpdateSoraS3SettingsRequest): Promise { + const { data } = await apiClient.put('/admin/settings/sora-s3', settings) + return data +} + +export async function testSoraS3Connection( + settings: TestSoraS3ConnectionRequest +): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>('/admin/settings/sora-s3/test', settings) + return data +} + +export async function listSoraS3Profiles(): Promise { + const { data } = await apiClient.get('/admin/settings/sora-s3/profiles') + return data +} + +export async function createSoraS3Profile(request: CreateSoraS3ProfileRequest): Promise { + const { data } = await apiClient.post('/admin/settings/sora-s3/profiles', request) + return data +} + +export async function updateSoraS3Profile(profileID: string, request: 
UpdateSoraS3ProfileRequest): Promise { + const { data } = await apiClient.put(`/admin/settings/sora-s3/profiles/${profileID}`, request) + return data +} + +export async function deleteSoraS3Profile(profileID: string): Promise { + await apiClient.delete(`/admin/settings/sora-s3/profiles/${profileID}`) +} + +export async function setActiveSoraS3Profile(profileID: string): Promise { + const { data } = await apiClient.post(`/admin/settings/sora-s3/profiles/${profileID}/activate`) + return data +} + export const settingsAPI = { getSettings, updateSettings, @@ -260,7 +398,15 @@ export const settingsAPI = { regenerateAdminApiKey, deleteAdminApiKey, getStreamTimeoutSettings, - updateStreamTimeoutSettings + updateStreamTimeoutSettings, + getSoraS3Settings, + updateSoraS3Settings, + testSoraS3Connection, + listSoraS3Profiles, + createSoraS3Profile, + updateSoraS3Profile, + deleteSoraS3Profile, + setActiveSoraS3Profile } export default settingsAPI diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index 94f7b57b..66c84410 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { AdminUsageLog, UsageQueryParams, PaginatedResponse } from '@/types' +import type { AdminUsageLog, UsageQueryParams, PaginatedResponse, UsageRequestType } from '@/types' // ==================== Types ==================== @@ -39,6 +39,7 @@ export interface UsageCleanupFilters { account_id?: number group_id?: number model?: string | null + request_type?: UsageRequestType | null stream?: boolean | null billing_type?: number | null } @@ -66,6 +67,7 @@ export interface CreateUsageCleanupTaskRequest { account_id?: number group_id?: number model?: string | null + request_type?: UsageRequestType | null stream?: boolean | null billing_type?: number | null timezone?: string @@ -104,6 +106,7 @@ export async function getStats(params: { account_id?: number group_id?: number model?: string + 
request_type?: UsageRequestType stream?: boolean period?: string start_date?: string diff --git a/frontend/src/api/sora.ts b/frontend/src/api/sora.ts new file mode 100644 index 00000000..45108454 --- /dev/null +++ b/frontend/src/api/sora.ts @@ -0,0 +1,307 @@ +/** + * Sora 客户端 API + * 封装所有 Sora 生成、作品库、配额等接口调用 + */ + +import { apiClient } from './client' + +// ==================== 类型定义 ==================== + +export interface SoraGeneration { + id: number + user_id: number + model: string + prompt: string + media_type: string + status: string // pending | generating | completed | failed | cancelled + storage_type: string // upstream | s3 | local + media_url: string + media_urls: string[] + s3_object_keys: string[] + file_size_bytes: number + error_message: string + created_at: string + completed_at?: string +} + +export interface GenerateRequest { + model: string + prompt: string + video_count?: number + media_type?: string + image_input?: string + api_key_id?: number +} + +export interface GenerateResponse { + generation_id: number + status: string +} + +export interface GenerationListResponse { + data: SoraGeneration[] + total: number + page: number +} + +export interface QuotaInfo { + quota_bytes: number + used_bytes: number + available_bytes: number + quota_source: string // user | group | system | unlimited + source?: string // 兼容旧字段 +} + +export interface StorageStatus { + s3_enabled: boolean + s3_healthy: boolean + local_enabled: boolean +} + +/** 单个扁平模型(旧接口,保留兼容) */ +export interface SoraModel { + id: string + name: string + type: string // video | image + orientation?: string + duration?: number +} + +/** 模型家族(新接口 — 后端从 soraModelConfigs 自动聚合) */ +export interface SoraModelFamily { + id: string // 家族 ID,如 "sora2" + name: string // 显示名,如 "Sora 2" + type: string // "video" | "image" + orientations: string[] // ["landscape", "portrait"] 或 ["landscape", "portrait", "square"] + durations?: number[] // [10, 15, 25](仅视频模型) +} + +type LooseRecord = Record + +function 
asRecord(value: unknown): LooseRecord | null { + return value !== null && typeof value === 'object' ? value as LooseRecord : null +} + +function asArray(value: unknown): T[] { + return Array.isArray(value) ? value as T[] : [] +} + +function asPositiveInt(value: unknown): number | null { + const n = Number(value) + if (!Number.isFinite(n) || n <= 0) return null + return Math.round(n) +} + +function dedupeStrings(values: string[]): string[] { + return Array.from(new Set(values)) +} + +function extractOrientationFromModelID(modelID: string): string | null { + const m = modelID.match(/-(landscape|portrait|square)(?:-\d+s)?$/i) + return m ? m[1].toLowerCase() : null +} + +function extractDurationFromModelID(modelID: string): number | null { + const m = modelID.match(/-(\d+)s$/i) + return m ? asPositiveInt(m[1]) : null +} + +function normalizeLegacyFamilies(candidates: unknown[]): SoraModelFamily[] { + const familyMap = new Map() + + for (const item of candidates) { + const model = asRecord(item) + if (!model || typeof model.id !== 'string' || model.id.trim() === '') continue + + const rawID = model.id.trim() + const type = model.type === 'image' ? 'image' : 'video' + const name = typeof model.name === 'string' && model.name.trim() ? model.name.trim() : rawID + const baseID = rawID.replace(/-(landscape|portrait|square)(?:-\d+s)?$/i, '') + const orientation = + typeof model.orientation === 'string' && model.orientation + ? model.orientation.toLowerCase() + : extractOrientationFromModelID(rawID) + const duration = asPositiveInt(model.duration) ?? extractDurationFromModelID(rawID) + const familyKey = baseID || rawID + + const family = familyMap.get(familyKey) ?? 
{ + id: familyKey, + name, + type, + orientations: [], + durations: [] + } + + if (orientation) { + family.orientations.push(orientation) + } + if (type === 'video' && duration) { + family.durations = family.durations || [] + family.durations.push(duration) + } + + familyMap.set(familyKey, family) + } + + return Array.from(familyMap.values()) + .map((family) => ({ + ...family, + orientations: + family.orientations.length > 0 + ? dedupeStrings(family.orientations) + : (family.type === 'image' ? ['square'] : ['landscape']), + durations: + family.type === 'video' + ? Array.from(new Set((family.durations || []).filter((d): d is number => Number.isFinite(d)))).sort((a, b) => a - b) + : [] + })) + .filter((family) => family.id !== '') +} + +function normalizeModelFamilyRecord(item: unknown): SoraModelFamily | null { + const model = asRecord(item) + if (!model || typeof model.id !== 'string' || model.id.trim() === '') return null + // 仅把明确的“家族结构”识别为 family;老结构(单模型)走 legacy 聚合逻辑。 + if (!Array.isArray(model.orientations) && !Array.isArray(model.durations)) return null + + const orientations = asArray(model.orientations).filter((o): o is string => typeof o === 'string' && o.length > 0) + const durations = asArray(model.durations) + .map(asPositiveInt) + .filter((d): d is number => d !== null) + + return { + id: model.id.trim(), + name: typeof model.name === 'string' && model.name.trim() ? model.name.trim() : model.id.trim(), + type: model.type === 'image' ? 
'image' : 'video', + orientations: dedupeStrings(orientations), + durations: Array.from(new Set(durations)).sort((a, b) => a - b) + } +} + +function extractCandidateArray(payload: unknown): unknown[] { + if (Array.isArray(payload)) return payload + const record = asRecord(payload) + if (!record) return [] + + const keys: Array = ['data', 'items', 'models', 'families'] + for (const key of keys) { + if (Array.isArray(record[key])) { + return record[key] as unknown[] + } + } + return [] +} + +export function normalizeModelFamiliesResponse(payload: unknown): SoraModelFamily[] { + const candidates = extractCandidateArray(payload) + if (candidates.length === 0) return [] + + const normalized = candidates + .map(normalizeModelFamilyRecord) + .filter((item): item is SoraModelFamily => item !== null) + + if (normalized.length > 0) return normalized + return normalizeLegacyFamilies(candidates) +} + +export function normalizeGenerationListResponse(payload: unknown): GenerationListResponse { + const record = asRecord(payload) + if (!record) { + return { data: [], total: 0, page: 1 } + } + + const data = Array.isArray(record.data) + ? (record.data as SoraGeneration[]) + : Array.isArray(record.items) + ? (record.items as SoraGeneration[]) + : [] + + const total = Number(record.total) + const page = Number(record.page) + + return { + data, + total: Number.isFinite(total) ? total : data.length, + page: Number.isFinite(page) && page > 0 ? 
page : 1 + } +} + +// ==================== API 方法 ==================== + +/** 异步生成 — 创建 pending 记录后立即返回 */ +export async function generate(req: GenerateRequest): Promise { + const { data } = await apiClient.post('/sora/generate', req) + return data +} + +/** 查询生成记录列表 */ +export async function listGenerations(params?: { + page?: number + page_size?: number + status?: string + storage_type?: string + media_type?: string +}): Promise { + const { data } = await apiClient.get('/sora/generations', { params }) + return normalizeGenerationListResponse(data) +} + +/** 查询生成记录详情 */ +export async function getGeneration(id: number): Promise { + const { data } = await apiClient.get(`/sora/generations/${id}`) + return data +} + +/** 删除生成记录 */ +export async function deleteGeneration(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/sora/generations/${id}`) + return data +} + +/** 取消生成任务 */ +export async function cancelGeneration(id: number): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>(`/sora/generations/${id}/cancel`) + return data +} + +/** 手动保存到 S3 */ +export async function saveToStorage( + id: number +): Promise<{ message: string; object_key: string; object_keys?: string[] }> { + const { data } = await apiClient.post<{ message: string; object_key: string; object_keys?: string[] }>( + `/sora/generations/${id}/save` + ) + return data +} + +/** 查询配额信息 */ +export async function getQuota(): Promise { + const { data } = await apiClient.get('/sora/quota') + return data +} + +/** 获取可用模型家族列表 */ +export async function getModels(): Promise { + const { data } = await apiClient.get('/sora/models') + return normalizeModelFamiliesResponse(data) +} + +/** 获取存储状态 */ +export async function getStorageStatus(): Promise { + const { data } = await apiClient.get('/sora/storage-status') + return data +} + +const soraAPI = { + generate, + listGenerations, + getGeneration, + deleteGeneration, + 
cancelGeneration, + saveToStorage, + getQuota, + getModels, + getStorageStatus +} + +export default soraAPI diff --git a/frontend/src/components/account/AccountTodayStatsCell.vue b/frontend/src/components/account/AccountTodayStatsCell.vue index a920f314..a422d1f0 100644 --- a/frontend/src/components/account/AccountTodayStatsCell.vue +++ b/frontend/src/components/account/AccountTodayStatsCell.vue @@ -1,26 +1,26 @@ diff --git a/frontend/src/components/account/AccountUsageCell.vue b/frontend/src/components/account/AccountUsageCell.vue index 12fab57d..859bd7c9 100644 --- a/frontend/src/components/account/AccountUsageCell.vue +++ b/frontend/src/components/account/AccountUsageCell.vue @@ -398,7 +398,9 @@ const antigravity3ProUsageFromAPI = computed(() => const antigravity3FlashUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3-flash'])) // Gemini Image from API -const antigravity3ImageUsageFromAPI = computed(() => getAntigravityUsageFromAPI(['gemini-3.1-flash-image'])) +const antigravity3ImageUsageFromAPI = computed(() => + getAntigravityUsageFromAPI(['gemini-3.1-flash-image', 'gemini-3-pro-image']) +) // Claude from API (all Claude model variants) const antigravityClaudeUsageFromAPI = computed(() => diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index a9632a92..4a19e35e 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -688,7 +688,7 @@ const isMixedPlatform = computed(() => props.selectedPlatforms.length > 1) const platformModelPrefix: Record = { anthropic: ['claude-'], - antigravity: ['claude-'], + antigravity: ['claude-', 'gemini-', 'gpt-oss-', 'tab_'], openai: ['gpt-'], gemini: ['gemini-'], sora: [] @@ -766,6 +766,8 @@ const allModels = [ { value: 'gemini-2.0-flash', label: 'Gemini 2.0 Flash' }, { value: 'gemini-2.5-flash', label: 'Gemini 2.5 Flash' }, { value: 'gemini-2.5-pro', label: 
'Gemini 2.5 Pro' }, + { value: 'gemini-3.1-flash-image', label: 'Gemini 3.1 Flash Image' }, + { value: 'gemini-3-pro-image', label: 'Gemini 3 Pro Image (Legacy)' }, { value: 'gemini-3-flash-preview', label: 'Gemini 3 Flash Preview' }, { value: 'gemini-3-pro-preview', label: 'Gemini 3 Pro Preview' } ] @@ -844,6 +846,18 @@ const presetMappings = [ to: 'claude-sonnet-4-5-20250929', color: 'bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-400' }, + { + label: 'Gemini 3.1 Image', + from: 'gemini-3.1-flash-image', + to: 'gemini-3.1-flash-image', + color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' + }, + { + label: 'G3 Image→3.1', + from: 'gemini-3-pro-image', + to: 'gemini-3.1-flash-image', + color: 'bg-sky-100 text-sky-700 hover:bg-sky-200 dark:bg-sky-900/30 dark:text-sky-400' + }, { label: 'GPT-5.3 Codex', from: 'gpt-5.3-codex', diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index 72d74318..a0ad4eb4 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -175,13 +175,13 @@
-
+
+
@@ -879,14 +904,14 @@ type="text" class="input" :placeholder=" - form.platform === 'openai' + form.platform === 'openai' || form.platform === 'sora' ? 'https://api.openai.com' : form.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' : 'https://api.anthropic.com' " /> -

{{ baseUrlHint }}

+

{{ form.platform === 'sora' ? t('admin.accounts.soraUpstreamBaseUrlHint') : baseUrlHint }}

@@ -1669,6 +1694,27 @@
+ +
+
+
+ +

+ {{ t('admin.accounts.openai.wsModeDesc') }} +

+

+ {{ t('admin.accounts.openai.wsModeConcurrencyHint') }} +

+
+
+ +
+
+
+
('5m') // OpenAI 自动透传开关(OAuth/API Key) const openaiPassthroughEnabled = ref(false) +const openaiOAuthResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) +const openaiAPIKeyResponsesWebSocketV2Mode = ref(OPENAI_WS_MODE_OFF) const codexCLIOnlyEnabled = ref(false) const anthropicPassthroughEnabled = ref(false) +const openAIWSModeOptions = computed(() => [ + { value: OPENAI_WS_MODE_OFF, label: t('admin.accounts.openai.wsModeOff') }, + { value: OPENAI_WS_MODE_SHARED, label: t('admin.accounts.openai.wsModeShared') }, + { value: OPENAI_WS_MODE_DEDICATED, label: t('admin.accounts.openai.wsModeDedicated') } +]) +const openaiResponsesWebSocketV2Mode = computed({ + get: () => { + if (props.account?.type === 'apikey') { + return openaiAPIKeyResponsesWebSocketV2Mode.value + } + return openaiOAuthResponsesWebSocketV2Mode.value + }, + set: (mode: OpenAIWSMode) => { + if (props.account?.type === 'apikey') { + openaiAPIKeyResponsesWebSocketV2Mode.value = mode + return + } + openaiOAuthResponsesWebSocketV2Mode.value = mode + } +}) const isOpenAIModelRestrictionDisabled = computed(() => props.account?.platform === 'openai' && openaiPassthroughEnabled.value ) @@ -1269,7 +1320,7 @@ const tempUnschedPresets = computed(() => [ // Computed: default base URL based on platform const defaultBaseUrl = computed(() => { - if (props.account?.platform === 'openai') return 'https://api.openai.com' + if (props.account?.platform === 'openai' || props.account?.platform === 'sora') return 'https://api.openai.com' if (props.account?.platform === 'gemini') return 'https://generativelanguage.googleapis.com' return 'https://api.anthropic.com' }) @@ -1336,10 +1387,24 @@ watch( // Load OpenAI passthrough toggle (OpenAI OAuth/API Key) openaiPassthroughEnabled.value = false + openaiOAuthResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF + openaiAPIKeyResponsesWebSocketV2Mode.value = OPENAI_WS_MODE_OFF codexCLIOnlyEnabled.value = false anthropicPassthroughEnabled.value = false if (newAccount.platform === 
'openai' && (newAccount.type === 'oauth' || newAccount.type === 'apikey')) { openaiPassthroughEnabled.value = extra?.openai_passthrough === true || extra?.openai_oauth_passthrough === true + openaiOAuthResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_oauth_responses_websockets_v2_mode', + enabledKey: 'openai_oauth_responses_websockets_v2_enabled', + fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], + defaultMode: OPENAI_WS_MODE_OFF + }) + openaiAPIKeyResponsesWebSocketV2Mode.value = resolveOpenAIWSModeFromExtra(extra, { + modeKey: 'openai_apikey_responses_websockets_v2_mode', + enabledKey: 'openai_apikey_responses_websockets_v2_enabled', + fallbackEnabledKeys: ['responses_websockets_v2_enabled', 'openai_ws_enabled'], + defaultMode: OPENAI_WS_MODE_OFF + }) if (newAccount.type === 'oauth') { codexCLIOnlyEnabled.value = extra?.codex_cli_only === true } @@ -1389,7 +1454,7 @@ watch( if (newAccount.type === 'apikey' && newAccount.credentials) { const credentials = newAccount.credentials as Record const platformDefaultUrl = - newAccount.platform === 'openai' + newAccount.platform === 'openai' || newAccount.platform === 'sora' ? 'https://api.openai.com' : newAccount.platform === 'gemini' ? 'https://generativelanguage.googleapis.com' @@ -1435,7 +1500,7 @@ watch( editBaseUrl.value = (credentials.base_url as string) || '' } else { const platformDefaultUrl = - newAccount.platform === 'openai' + newAccount.platform === 'openai' || newAccount.platform === 'sora' ? 'https://api.openai.com' : newAccount.platform === 'gemini' ? 
'https://generativelanguage.googleapis.com' @@ -2021,6 +2086,12 @@ const handleSubmit = async () => { const currentExtra = (props.account.extra as Record) || {} const newExtra: Record = { ...currentExtra } const hadCodexCLIOnlyEnabled = currentExtra.codex_cli_only === true + newExtra.openai_oauth_responses_websockets_v2_mode = openaiOAuthResponsesWebSocketV2Mode.value + newExtra.openai_apikey_responses_websockets_v2_mode = openaiAPIKeyResponsesWebSocketV2Mode.value + newExtra.openai_oauth_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiOAuthResponsesWebSocketV2Mode.value) + newExtra.openai_apikey_responses_websockets_v2_enabled = isOpenAIWSModeEnabled(openaiAPIKeyResponsesWebSocketV2Mode.value) + delete newExtra.responses_websockets_v2_enabled + delete newExtra.openai_ws_enabled if (openaiPassthroughEnabled.value) { newExtra.openai_passthrough = true } else { diff --git a/frontend/src/components/account/OAuthAuthorizationFlow.vue b/frontend/src/components/account/OAuthAuthorizationFlow.vue index 94d417dc..cc74f8ce 100644 --- a/frontend/src/components/account/OAuthAuthorizationFlow.vue +++ b/frontend/src/components/account/OAuthAuthorizationFlow.vue @@ -171,7 +171,7 @@ class="mb-2 flex items-center gap-2 text-sm font-semibold text-gray-700 dark:text-gray-300" > - Session Token + {{ t(getOAuthKey('sessionTokenRawLabel')) }} +

+ {{ t(getOAuthKey('sessionTokenRawHint')) }} +

+
+ + +
+

+ {{ soraSessionUrl }} +

+

+ {{ t(getOAuthKey('sessionUrlHint')) }} +

+
+
+ + +

+ {{ t(getOAuthKey('parsedSessionTokensEmpty')) }} +

+
+ +
+ + +
+
+
{ .filter((rt) => rt).length }) +const parsedSoraRawTokens = computed(() => parseSoraRawTokens(sessionTokenInput.value)) + const parsedSessionTokenCount = computed(() => { - return sessionTokenInput.value - .split('\n') - .map((st) => st.trim()) - .filter((st) => st).length + return parsedSoraRawTokens.value.sessionTokens.length }) +const parsedSessionTokensText = computed(() => { + return parsedSoraRawTokens.value.sessionTokens.join('\n') +}) + +const parsedAccessTokenFromSessionInputCount = computed(() => { + return parsedSoraRawTokens.value.accessTokens.length +}) + +const parsedAccessTokensText = computed(() => { + return parsedSoraRawTokens.value.accessTokens.join('\n') +}) + +const soraSessionUrl = 'https://sora.chatgpt.com/api/auth/session' + const parsedAccessTokenCount = computed(() => { return accessTokenInput.value .split('\n') @@ -863,11 +950,19 @@ const handleValidateRefreshToken = () => { } const handleValidateSessionToken = () => { - if (sessionTokenInput.value.trim()) { - emit('validate-session-token', sessionTokenInput.value.trim()) + if (parsedSessionTokenCount.value > 0) { + emit('validate-session-token', parsedSessionTokensText.value) } } +const handleOpenSoraSessionUrl = () => { + window.open(soraSessionUrl, '_blank', 'noopener,noreferrer') +} + +const handleCopySoraSessionUrl = () => { + copyToClipboard(soraSessionUrl, 'URL copied to clipboard') +} + const handleImportAccessToken = () => { if (accessTokenInput.value.trim()) { emit('import-access-token', accessTokenInput.value.trim()) diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts new file mode 100644 index 00000000..0b61b3bd --- /dev/null +++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts @@ -0,0 +1,70 @@ +import { describe, expect, it, vi, beforeEach } from 'vitest' +import { flushPromises, mount } from '@vue/test-utils' +import AccountUsageCell from 
'../AccountUsageCell.vue' + +const { getUsage } = vi.hoisted(() => ({ + getUsage: vi.fn() +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + accounts: { + getUsage + } + } +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +describe('AccountUsageCell', () => { + beforeEach(() => { + getUsage.mockReset() + }) + + it('Antigravity 图片用量会聚合新旧 image 模型', async () => { + getUsage.mockResolvedValue({ + antigravity_quota: { + 'gemini-3.1-flash-image': { + utilization: 20, + reset_time: '2026-03-01T10:00:00Z' + }, + 'gemini-3-pro-image': { + utilization: 70, + reset_time: '2026-03-01T09:00:00Z' + } + } + }) + + const wrapper = mount(AccountUsageCell, { + props: { + account: { + id: 1001, + platform: 'antigravity', + type: 'oauth', + extra: {} + } as any + }, + global: { + stubs: { + UsageProgressBar: { + props: ['label', 'utilization', 'resetsAt', 'color'], + template: '
{{ label }}|{{ utilization }}|{{ resetsAt }}
' + }, + AccountQuotaInfo: true + } + } + }) + + await flushPromises() + + expect(wrapper.text()).toContain('admin.accounts.usageWindow.gemini3Image|70|2026-03-01T09:00:00Z') + }) +}) diff --git a/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts new file mode 100644 index 00000000..28ac61ec --- /dev/null +++ b/frontend/src/components/account/__tests__/BulkEditAccountModal.spec.ts @@ -0,0 +1,72 @@ +import { describe, expect, it, vi } from 'vitest' +import { mount } from '@vue/test-utils' +import BulkEditAccountModal from '../BulkEditAccountModal.vue' + +vi.mock('@/stores/app', () => ({ + useAppStore: () => ({ + showError: vi.fn(), + showSuccess: vi.fn(), + showInfo: vi.fn() + }) +})) + +vi.mock('@/api/admin', () => ({ + adminAPI: { + accounts: { + bulkEdit: vi.fn() + } + } +})) + +vi.mock('vue-i18n', async () => { + const actual = await vi.importActual('vue-i18n') + return { + ...actual, + useI18n: () => ({ + t: (key: string) => key + }) + } +}) + +function mountModal() { + return mount(BulkEditAccountModal, { + props: { + show: true, + accountIds: [1, 2], + selectedPlatforms: ['antigravity'], + proxies: [], + groups: [] + } as any, + global: { + stubs: { + BaseDialog: { template: '
' }, + Select: true, + ProxySelector: true, + GroupSelector: true, + Icon: true + } + } + }) +} + +describe('BulkEditAccountModal', () => { + it('antigravity 白名单包含 Gemini 图片模型且过滤掉普通 GPT 模型', () => { + const wrapper = mountModal() + + expect(wrapper.text()).toContain('Gemini 3.1 Flash Image') + expect(wrapper.text()).toContain('Gemini 3 Pro Image (Legacy)') + expect(wrapper.text()).not.toContain('GPT-5.3 Codex') + }) + + it('antigravity 映射预设包含图片映射并过滤 OpenAI 预设', async () => { + const wrapper = mountModal() + + const mappingTab = wrapper.findAll('button').find((btn) => btn.text().includes('admin.accounts.modelMapping')) + expect(mappingTab).toBeTruthy() + await mappingTab!.trigger('click') + + expect(wrapper.text()).toContain('Gemini 3.1 Image') + expect(wrapper.text()).toContain('G3 Image→3.1') + expect(wrapper.text()).not.toContain('GPT-5.3 Codex') + }) +}) diff --git a/frontend/src/components/admin/usage/UsageCleanupDialog.vue b/frontend/src/components/admin/usage/UsageCleanupDialog.vue index d5e81e72..3218be30 100644 --- a/frontend/src/components/admin/usage/UsageCleanupDialog.vue +++ b/frontend/src/components/admin/usage/UsageCleanupDialog.vue @@ -125,6 +125,7 @@ import Pagination from '@/components/common/Pagination.vue' import UsageFilters from '@/components/admin/usage/UsageFilters.vue' import { adminUsageAPI } from '@/api/admin/usage' import type { AdminUsageQueryParams, UsageCleanupTask, CreateUsageCleanupTaskRequest } from '@/api/admin/usage' +import { requestTypeToLegacyStream } from '@/utils/usageRequestType' interface Props { show: boolean @@ -310,7 +311,13 @@ const buildPayload = (): CreateUsageCleanupTaskRequest | null => { if (localFilters.value.model) { payload.model = localFilters.value.model } - if (localFilters.value.stream !== null && localFilters.value.stream !== undefined) { + if (localFilters.value.request_type) { + payload.request_type = localFilters.value.request_type + const legacyStream = 
requestTypeToLegacyStream(localFilters.value.request_type) + if (legacyStream !== null && legacyStream !== undefined) { + payload.stream = legacyStream + } + } else if (localFilters.value.stream !== null && localFilters.value.stream !== undefined) { payload.stream = localFilters.value.stream } if (localFilters.value.billing_type !== null && localFilters.value.billing_type !== undefined) { diff --git a/frontend/src/components/admin/usage/UsageFilters.vue b/frontend/src/components/admin/usage/UsageFilters.vue index d305dc18..a632a76e 100644 --- a/frontend/src/components/admin/usage/UsageFilters.vue +++ b/frontend/src/components/admin/usage/UsageFilters.vue @@ -121,10 +121,10 @@
- +
-
@@ -233,10 +233,11 @@ let accountSearchTimeout: ReturnType | null = null const modelOptions = ref([{ value: null, label: t('admin.usage.allModels') }]) const groupOptions = ref([{ value: null, label: t('admin.usage.allGroups') }]) -const streamTypeOptions = ref([ +const requestTypeOptions = ref([ { value: null, label: t('admin.usage.allTypes') }, - { value: true, label: t('usage.stream') }, - { value: false, label: t('usage.sync') } + { value: 'ws_v2', label: t('usage.ws') }, + { value: 'stream', label: t('usage.stream') }, + { value: 'sync', label: t('usage.sync') } ]) const billingTypeOptions = ref([ diff --git a/frontend/src/components/admin/usage/UsageTable.vue b/frontend/src/components/admin/usage/UsageTable.vue index c8acf6b8..14c434d6 100644 --- a/frontend/src/components/admin/usage/UsageTable.vue +++ b/frontend/src/components/admin/usage/UsageTable.vue @@ -35,8 +35,8 @@ @@ -271,6 +271,7 @@ import { ref } from 'vue' import { useI18n } from 'vue-i18n' import { formatDateTime, formatReasoningEffort } from '@/utils/format' +import { resolveUsageRequestType } from '@/utils/usageRequestType' import DataTable from '@/components/common/DataTable.vue' import EmptyState from '@/components/common/EmptyState.vue' import Icon from '@/components/icons/Icon.vue' @@ -289,6 +290,21 @@ const tokenTooltipVisible = ref(false) const tokenTooltipPosition = ref({ x: 0, y: 0 }) const tokenTooltipData = ref(null) +const getRequestTypeLabel = (row: AdminUsageLog): string => { + const requestType = resolveUsageRequestType(row) + if (requestType === 'ws_v2') return t('usage.ws') + if (requestType === 'stream') return t('usage.stream') + if (requestType === 'sync') return t('usage.sync') + return t('usage.unknown') +} + +const getRequestTypeBadgeClass = (row: AdminUsageLog): string => { + const requestType = resolveUsageRequestType(row) + if (requestType === 'ws_v2') return 'bg-violet-100 text-violet-800 dark:bg-violet-900 dark:text-violet-200' + if (requestType === 'stream') return 
'bg-blue-100 text-blue-800 dark:bg-blue-900 dark:text-blue-200' + if (requestType === 'sync') return 'bg-gray-100 text-gray-800 dark:bg-gray-700 dark:text-gray-200' + return 'bg-amber-100 text-amber-800 dark:bg-amber-900 dark:text-amber-200' +} const formatCacheTokens = (tokens: number): string => { if (tokens >= 1000000) return `${(tokens / 1000000).toFixed(1)}M` if (tokens >= 1000) return `${(tokens / 1000).toFixed(1)}K` diff --git a/frontend/src/components/admin/user/UserEditModal.vue b/frontend/src/components/admin/user/UserEditModal.vue index 70ebd2d3..e537dbf6 100644 --- a/frontend/src/components/admin/user/UserEditModal.vue +++ b/frontend/src/components/admin/user/UserEditModal.vue @@ -37,6 +37,14 @@ +
+ +
+ + GB +
+

{{ t('admin.users.soraStorageQuotaHint') }}

+