diff --git a/backend/ent/group.go b/backend/ent/group.go index 2dce468e..f10b50c3 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -53,21 +53,21 @@ type Group struct { ImagePrice2k *float64 `json:"image_price_2k,omitempty"` // ImagePrice4k holds the value of the "image_price_4k" field. ImagePrice4k *float64 `json:"image_price_4k,omitempty"` - // allow Claude Code client only + // 是否仅允许 Claude Code 客户端 ClaudeCodeOnly bool `json:"claude_code_only,omitempty"` - // fallback group for non-Claude-Code requests + // 非 Claude Code 请求降级使用的分组 ID FallbackGroupID *int64 `json:"fallback_group_id,omitempty"` - // fallback group for invalid request + // 无效请求兜底使用的分组 ID FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"` - // model routing config: pattern -> account ids + // 模型路由配置:模型模式 -> 优先账号ID列表 ModelRouting map[string][]int64 `json:"model_routing,omitempty"` - // whether model routing is enabled + // 是否启用模型路由配置 ModelRoutingEnabled bool `json:"model_routing_enabled,omitempty"` - // whether MCP XML prompt injection is enabled + // 是否注入 MCP XML 调用协议提示词(仅 antigravity 平台) McpXMLInject bool `json:"mcp_xml_inject,omitempty"` - // supported model scopes: claude, gemini_text, gemini_image + // 支持的模型系列:claude, gemini_text, gemini_image SupportedModelScopes []string `json:"supported_model_scopes,omitempty"` - // group display order, lower comes first + // 分组显示排序,数值越小越靠前 SortOrder int `json:"sort_order,omitempty"` // 是否允许 /v1/messages 调度到此 OpenAI 分组 AllowMessagesDispatch bool `json:"allow_messages_dispatch,omitempty"` diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index 00420f9f..d78a6898 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -33,6 +33,8 @@ func (Group) Mixin() []ent.Mixin { func (Group) Fields() []ent.Field { return []ent.Field{ + // 唯一约束通过部分索引实现(WHERE deleted_at IS NULL),支持软删除后重用 + // 见迁移文件 016_soft_delete_partial_unique_indexes.sql field.String("name"). MaxLen(100). NotEmpty(), @@ -49,6 +51,7 @@ func (Group) Fields() []ent.Field { MaxLen(20). Default(domain.StatusActive), + // Subscription-related fields (added by migration 003) field.String("platform"). MaxLen(50). Default(domain.PlatformAnthropic), @@ -70,6 +73,7 @@ func (Group) Fields() []ent.Field { field.Int("default_validity_days"). Default(30), + // 图片生成计费配置(antigravity 和 gemini 平台使用) field.Float("image_price_1k"). Optional(). Nillable(). @@ -86,36 +90,42 @@ func (Group) Fields() []ent.Field { // Claude Code 客户端限制 (added by migration 029) field.Bool("claude_code_only"). Default(false). - Comment("allow Claude Code client only"), + Comment("是否仅允许 Claude Code 客户端"), field.Int64("fallback_group_id"). Optional(). Nillable(). - Comment("fallback group for non-Claude-Code requests"), + Comment("非 Claude Code 请求降级使用的分组 ID"), field.Int64("fallback_group_id_on_invalid_request"). Optional(). Nillable(). - Comment("fallback group for invalid request"), + Comment("无效请求兜底使用的分组 ID"), + // 模型路由配置 (added by migration 040) field.JSON("model_routing", map[string][]int64{}). Optional(). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). - Comment("model routing config: pattern -> account ids"), + Comment("模型路由配置:模型模式 -> 优先账号ID列表"), + + // 模型路由开关 (added by migration 041) field.Bool("model_routing_enabled"). Default(false). - Comment("whether model routing is enabled"), + Comment("是否启用模型路由配置"), + // MCP XML 协议注入开关 (added by migration 042) field.Bool("mcp_xml_inject"). Default(true). - Comment("whether MCP XML prompt injection is enabled"), + Comment("是否注入 MCP XML 调用协议提示词(仅 antigravity 平台)"), + // 支持的模型系列 (added by migration 046) field.JSON("supported_model_scopes", []string{}). Default([]string{"claude", "gemini_text", "gemini_image"}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). - Comment("supported model scopes: claude, gemini_text, gemini_image"), + Comment("支持的模型系列:claude, gemini_text, gemini_image"), + // 分组排序 (added by migration 052) field.Int("sort_order"). Default(0). - Comment("group display order, lower comes first"), + Comment("分组显示排序,数值越小越靠前"), // OpenAI Messages 调度配置 (added by migration 069) field.Bool("allow_messages_dispatch"). @@ -150,11 +160,14 @@ func (Group) Edges() []ent.Edge { edge.From("allowed_users", User.Type). Ref("allowed_groups"). Through("user_allowed_groups", UserAllowedGroup.Type), + // 注意:fallback_group_id 直接作为字段使用,不定义 edge + // 这样允许多个分组指向同一个降级分组(M2O 关系) } } func (Group) Indexes() []ent.Index { return []ent.Index{ + // name 字段已在 Fields() 中声明 Unique(),无需重复索引 index.Fields("status"), index.Fields("platform"), index.Fields("subscription_type"), diff --git a/backend/go.sum b/backend/go.sum index f1c864f5..0f366ee1 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -162,6 +162,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -181,6 +183,8 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -216,6 +220,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -249,6 +255,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -278,6 +286,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -310,6 +320,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/handler/admin/account_handler_mixed_channel_test.go b/backend/internal/handler/admin/account_handler_mixed_channel_test.go index 5b81db2a..24ec5bcf 100644 --- a/backend/internal/handler/admin/account_handler_mixed_channel_test.go +++ b/backend/internal/handler/admin/account_handler_mixed_channel_test.go @@ -111,7 +111,7 @@ func TestAccountHandlerCreateMixedChannelConflictSimplifiedResponse(t *testing.T var resp map[string]any require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, "mixed_channel_warning", resp["error"]) - require.Contains(t, resp["message"], "claude-max") + require.Contains(t, resp["message"], "mixed_channel_warning") _, hasDetails := resp["details"] _, hasRequireConfirmation := resp["require_confirmation"] require.False(t, hasDetails) @@ -140,7 +140,7 @@ func TestAccountHandlerUpdateMixedChannelConflictSimplifiedResponse(t *testing.T var resp map[string]any require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp)) require.Equal(t, "mixed_channel_warning", resp["error"]) - require.Contains(t, resp["message"], "claude-max") + require.Contains(t, resp["message"], "mixed_channel_warning") _, hasDetails := resp["details"] _, hasRequireConfirmation := resp["require_confirmation"] require.False(t, hasDetails) diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 88e0092a..09dc8251 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -235,9 +235,11 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { PaymentCancelRateLimitWindow: paymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: paymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: paymentCfg.CancelRateLimitMode, - ChannelMonitorEnabled: settings.ChannelMonitorEnabled, - ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, - AvailableChannelsEnabled: settings.AvailableChannelsEnabled, + + ChannelMonitorEnabled: settings.ChannelMonitorEnabled, + ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, + + AvailableChannelsEnabled: settings.AvailableChannelsEnabled, } response.Success(c, systemSettingsResponseData(payload, authSourceDefaults)) } @@ -1477,9 +1479,11 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { PaymentCancelRateLimitWindow: updatedPaymentCfg.CancelRateLimitWindow, PaymentCancelRateLimitUnit: updatedPaymentCfg.CancelRateLimitUnit, PaymentCancelRateLimitMode: updatedPaymentCfg.CancelRateLimitMode, - ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled, - ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds, - AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, + + ChannelMonitorEnabled: updatedSettings.ChannelMonitorEnabled, + ChannelMonitorDefaultIntervalSeconds: updatedSettings.ChannelMonitorDefaultIntervalSeconds, + + AvailableChannelsEnabled: updatedSettings.AvailableChannelsEnabled, } response.Success(c, systemSettingsResponseData(payload, updatedAuthSourceDefaults)) } diff --git a/backend/internal/handler/dto/public_settings_injection_schema_test.go b/backend/internal/handler/dto/public_settings_injection_schema_test.go index 24853c7d..428fed3d 100644 --- a/backend/internal/handler/dto/public_settings_injection_schema_test.go +++ b/backend/internal/handler/dto/public_settings_injection_schema_test.go @@ -31,6 +31,8 @@ func TestPublicSettingsInjectionPayload_SchemaDoesNotDrift(t *testing.T) { dtoOnlyFields := map[string]string{ // sora_client_enabled is an upstream-only field the fork does not surface. "sora_client_enabled": "upstream-only field, not used on this fork", + // force_email_on_third_party_signup lives on the DTO but is not injected via SSR. + "force_email_on_third_party_signup": "auth-source default, not a feature flag", } var missing []string diff --git a/backend/internal/payment/provider/alipay.go b/backend/internal/payment/provider/alipay.go index 4a260295..1234b568 100644 --- a/backend/internal/payment/provider/alipay.go +++ b/backend/internal/payment/provider/alipay.go @@ -15,8 +15,9 @@ import ( // Alipay product codes. const ( - alipayProductCodeWapPay = "QUICK_WAP_WAY" - alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" + alipayProductCodePreCreate = "FACE_TO_FACE_PAYMENT" + alipayProductCodeWapPay = "QUICK_WAP_WAY" + alipayProductCodePagePay = "FAST_INSTANT_TRADE_PAY" ) // Alipay response constants. @@ -30,6 +31,9 @@ var ( alipayTradeWapPay = func(client *alipay.Client, param alipay.TradeWapPay) (*url.URL, error) { return client.TradeWapPay(param) } + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + return client.TradePreCreate(ctx, param) + } alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { return client.TradePagePay(param) } @@ -99,13 +103,13 @@ func (a *Alipay) MerchantIdentityMetadata() map[string]string { return map[string]string{"app_id": appID} } -// CreatePayment creates an Alipay payment using redirect-only flow: -// - Mobile (H5): alipay.trade.wap.pay — returns a URL the browser jumps to. -// - PC: alipay.trade.page.pay — returns a gateway URL the browser opens in a -// new window; Alipay's own page then shows login/QR. We intentionally do -// NOT encode the URL into a QR on the client (it isn't a scannable payload -// and would produce an invalid scan result). -func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { +// CreatePayment creates an Alipay payment using the following routing: +// - Mobile (H5): alipay.trade.wap.pay — browser redirect into Alipay. +// - Desktop: prefer alipay.trade.precreate to get a scan payload directly. +// - Desktop fallback: if precreate is unavailable for the merchant, fall back +// to alipay.trade.page.pay and expose both pay_url and qr_code so the +// frontend can render a QR while still allowing direct page open. +func (a *Alipay) CreatePayment(ctx context.Context, req payment.CreatePaymentRequest) (*payment.CreatePaymentResponse, error) { client, err := a.getClient() if err != nil { return nil, err @@ -123,7 +127,7 @@ func (a *Alipay) CreatePayment(_ context.Context, req payment.CreatePaymentReque if req.IsMobile { return a.createWapTrade(client, req, notifyURL, returnURL) } - return a.createPagePayTrade(client, req, notifyURL, returnURL) + return a.createDesktopTrade(ctx, client, req, notifyURL, returnURL) } func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { @@ -145,6 +149,48 @@ func (a *Alipay) createWapTrade(client *alipay.Client, req payment.CreatePayment }, nil } +func (a *Alipay) createDesktopTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { + resp, precreateErr := a.createPrecreateTrade(ctx, client, req, notifyURL) + if precreateErr == nil { + return resp, nil + } + + resp, pagePayErr := a.createPagePayTrade(client, req, notifyURL, returnURL) + if pagePayErr == nil { + return resp, nil + } + + return nil, fmt.Errorf("alipay desktop payment failed: precreate=%v; pagepay=%w", precreateErr, pagePayErr) +} + +func (a *Alipay) createPrecreateTrade(ctx context.Context, client *alipay.Client, req payment.CreatePaymentRequest, notifyURL string) (*payment.CreatePaymentResponse, error) { + param := alipay.TradePreCreate{} + param.OutTradeNo = req.OrderID + param.TotalAmount = req.Amount + param.Subject = req.Subject + param.ProductCode = alipayProductCodePreCreate + param.NotifyURL = notifyURL + + rsp, err := alipayTradePreCreate(ctx, client, param) + if err != nil { + return nil, fmt.Errorf("alipay TradePreCreate: %w", err) + } + if rsp == nil { + return nil, fmt.Errorf("alipay TradePreCreate: empty response") + } + if rsp.IsFailure() { + return nil, fmt.Errorf("alipay TradePreCreate failed: %s", rsp.Error.Error()) + } + if strings.TrimSpace(rsp.QRCode) == "" { + return nil, fmt.Errorf("alipay TradePreCreate: empty qr_code") + } + + return &payment.CreatePaymentResponse{ + TradeNo: req.OrderID, + QRCode: rsp.QRCode, + }, nil +} + func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePaymentRequest, notifyURL, returnURL string) (*payment.CreatePaymentResponse, error) { param := alipay.TradePagePay{} param.OutTradeNo = req.OrderID @@ -161,6 +207,7 @@ func (a *Alipay) createPagePayTrade(client *alipay.Client, req payment.CreatePay return &payment.CreatePaymentResponse{ TradeNo: req.OrderID, PayURL: payURL.String(), + QRCode: payURL.String(), }, nil } @@ -192,7 +239,15 @@ func (a *Alipay) QueryOrder(ctx context.Context, tradeNo string) (*payment.Query amount, err := strconv.ParseFloat(result.TotalAmount, 64) if err != nil { - return nil, fmt.Errorf("alipay parse amount %q: %w", result.TotalAmount, err) + amount, err = parseAlipayAmount( + result.TotalAmount, + result.ReceiptAmount, + result.BuyerPayAmount, + result.InvoiceAmount, + ) + if err != nil { + return nil, fmt.Errorf("alipay parse amount: %w", err) + } } return &payment.QueryOrderResponse{ @@ -228,7 +283,14 @@ func (a *Alipay) VerifyNotification(ctx context.Context, rawBody string, _ map[s amount, err := strconv.ParseFloat(notification.TotalAmount, 64) if err != nil { - return nil, fmt.Errorf("alipay parse notification amount %q: %w", notification.TotalAmount, err) + amount, err = parseAlipayAmount( + notification.TotalAmount, + notification.ReceiptAmount, + notification.BuyerPayAmount, + ) + if err != nil { + return nil, fmt.Errorf("alipay parse notification amount: %w", err) + } } metadata := a.MerchantIdentityMetadata() @@ -306,6 +368,20 @@ func isTradeNotExist(err error) bool { return strings.Contains(err.Error(), alipayErrTradeNotExist) } +func parseAlipayAmount(values ...string) (float64, error) { + for _, raw := range values { + raw = strings.TrimSpace(raw) + if raw == "" { + continue + } + amount, err := strconv.ParseFloat(raw, 64) + if err == nil { + return amount, nil + } + } + return 0, fmt.Errorf("no valid amount field") +} + // Ensure interface compliance. var ( _ payment.Provider = (*Alipay)(nil) diff --git a/backend/internal/payment/provider/alipay_test.go b/backend/internal/payment/provider/alipay_test.go index 8b3ff8ce..fdc8eec1 100644 --- a/backend/internal/payment/provider/alipay_test.go +++ b/backend/internal/payment/provider/alipay_test.go @@ -3,6 +3,7 @@ package provider import ( + "context" "errors" "net/url" "strings" @@ -136,15 +137,22 @@ func TestNewAlipay(t *testing.T) { } func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { + origPreCreate := alipayTradePreCreate origPagePay := alipayTradePagePay origWapPay := alipayTradeWapPay t.Cleanup(func() { + alipayTradePreCreate = origPreCreate alipayTradePagePay = origPagePay alipayTradeWapPay = origWapPay }) + preCreateCalls := 0 pagePayCalls := 0 wapPayCalls := 0 + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + preCreateCalls++ + return nil, errors.New("merchant does not have FACE_TO_FACE_PAYMENT") + } alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { pagePayCalls++ if param.OutTradeNo != "sub2_100" { @@ -161,7 +169,7 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { } provider := &Alipay{} - resp, err := provider.createPagePayTrade(&alipay.Client{}, payment.CreatePaymentRequest{ + resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{ OrderID: "sub2_100", Amount: "88.00", Subject: "Balance recharge", @@ -169,6 +177,9 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if preCreateCalls != 1 { + t.Fatalf("precreate calls = %d, want 1", preCreateCalls) + } if pagePayCalls != 1 { t.Fatalf("page pay calls = %d, want 1", pagePayCalls) } @@ -178,6 +189,9 @@ func TestCreateTradeUsesPagePayForDesktop(t *testing.T) { if resp.PayURL == "" { t.Fatal("expected pay_url for desktop page pay") } + if resp.QRCode != resp.PayURL { + t.Fatalf("qr_code = %q, want same as pay_url %q", resp.QRCode, resp.PayURL) + } } func TestCreateTradeUsesWapPayForMobile(t *testing.T) { @@ -213,6 +227,54 @@ func TestCreateTradeUsesWapPayForMobile(t *testing.T) { } } +func TestCreateTradeUsesPrecreateForDesktopWhenAvailable(t *testing.T) { + origPreCreate := alipayTradePreCreate + origPagePay := alipayTradePagePay + t.Cleanup(func() { + alipayTradePreCreate = origPreCreate + alipayTradePagePay = origPagePay + }) + + preCreateCalls := 0 + pagePayCalls := 0 + alipayTradePreCreate = func(ctx context.Context, client *alipay.Client, param alipay.TradePreCreate) (*alipay.TradePreCreateRsp, error) { + preCreateCalls++ + if param.ProductCode != alipayProductCodePreCreate { + t.Fatalf("product_code = %q, want %q", param.ProductCode, alipayProductCodePreCreate) + } + return &alipay.TradePreCreateRsp{ + Error: alipay.Error{Code: alipay.CodeSuccess}, + QRCode: "https://qr.alipay.example.com/precreate-token", + }, nil + } + alipayTradePagePay = func(client *alipay.Client, param alipay.TradePagePay) (*url.URL, error) { + pagePayCalls++ + return url.Parse("https://openapi.alipay.com/gateway.do?page-pay") + } + + provider := &Alipay{} + resp, err := provider.createDesktopTrade(context.Background(), &alipay.Client{}, payment.CreatePaymentRequest{ + OrderID: "sub2_102", + Amount: "66.00", + Subject: "Balance recharge", + }, "https://merchant.example.com/api/v1/payment/webhook/alipay", "https://merchant.example.com/payment/result") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if preCreateCalls != 1 { + t.Fatalf("precreate calls = %d, want 1", preCreateCalls) + } + if pagePayCalls != 0 { + t.Fatalf("page pay calls = %d, want 0", pagePayCalls) + } + if resp.QRCode != "https://qr.alipay.example.com/precreate-token" { + t.Fatalf("qr_code = %q", resp.QRCode) + } + if resp.PayURL != "" { + t.Fatalf("pay_url = %q, want empty for precreate", resp.PayURL) + } +} + func TestAlipayMerchantIdentityMetadata(t *testing.T) { t.Parallel() @@ -227,3 +289,19 @@ func TestAlipayMerchantIdentityMetadata(t *testing.T) { t.Fatalf("app_id = %q, want %q", metadata["app_id"], "2021001234567890") } } + +func TestParseAlipayAmount(t *testing.T) { + t.Parallel() + + amount, err := parseAlipayAmount("", "88.00", "77.00") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if amount != 88 { + t.Fatalf("amount = %v, want 88", amount) + } + + if _, err := parseAlipayAmount("", "not-a-number"); err == nil { + t.Fatal("expected error when no valid amount field exists") + } +} diff --git a/backend/internal/pkg/ctxkey/ctxkey.go b/backend/internal/pkg/ctxkey/ctxkey.go index f6495de2..25782c55 100644 --- a/backend/internal/pkg/ctxkey/ctxkey.go +++ b/backend/internal/pkg/ctxkey/ctxkey.go @@ -55,9 +55,4 @@ const ( // ClaudeCodeVersion stores the extracted Claude Code version from User-Agent (e.g. "2.1.22") ClaudeCodeVersion Key = "ctx_claude_code_version" - - // IsSignatureRectifyRetry marks a retry request that was produced by the signature rectifier - // (strip or pool-replace). The harvester consults this flag to avoid ingesting signatures - // from retries, which would pollute the pool with signatures we ourselves injected. - IsSignatureRectifyRetry Key = "ctx_is_signature_rectify_retry" ) diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index fb9099bd..78f739ac 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -313,31 +313,6 @@ func (r *accountRepository) ListCRSAccountIDs(ctx context.Context) (map[string]i return result, nil } -// CountByTLSFingerprintProfile 按 TLS 指纹模板 ID 聚合绑定账号数。 -// 走 108_add_tls_fingerprint_profile_id_index.sql 的表达式索引。 -func (r *accountRepository) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) { - rows, err := r.sql.QueryContext(ctx, ` - SELECT (extra->>'tls_fingerprint_profile_id')::bigint AS profile_id, COUNT(*) - FROM accounts - WHERE deleted_at IS NULL AND extra ? 'tls_fingerprint_profile_id' - GROUP BY profile_id`) - if err != nil { - return nil, err - } - defer func() { _ = rows.Close() }() - - counts := make(map[int64]int) - for rows.Next() { - var id int64 - var n int - if err := rows.Scan(&id, &n); err != nil { - return nil, err - } - counts[id] = n - } - return counts, rows.Err() -} - func (r *accountRepository) Update(ctx context.Context, account *service.Account) error { if account == nil { return nil diff --git a/backend/internal/repository/gateway_cache.go b/backend/internal/repository/gateway_cache.go index 280629d1..58291b66 100644 --- a/backend/internal/repository/gateway_cache.go +++ b/backend/internal/repository/gateway_cache.go @@ -9,9 +9,7 @@ import ( "github.com/redis/go-redis/v9" ) -const ( - stickySessionPrefix = "sticky_session:" -) +const stickySessionPrefix = "sticky_session:" type gatewayCache struct { rdb *redis.Client @@ -43,6 +41,12 @@ func (c *gatewayCache) RefreshSessionTTL(ctx context.Context, groupID int64, ses } // DeleteSessionAccountID 删除粘性会话与账号的绑定关系。 +// 当检测到绑定的账号不可用(如状态错误、禁用、不可调度等)时调用, +// 以便下次请求能够重新选择可用账号。 +// +// DeleteSessionAccountID removes the sticky session binding for the given session. +// Called when the bound account becomes unavailable (e.g., error status, disabled, +// or unschedulable), allowing subsequent requests to select a new available account. func (c *gatewayCache) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { key := buildSessionKey(groupID, sessionHash) return c.rdb.Del(ctx, key).Err() diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 0c4c0041..f2fb87da 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -3080,7 +3080,7 @@ func (r *usageLogRepository) GetGroupStatsWithFilters(ctx context.Context, start query := ` SELECT COALESCE(ul.group_id, 0) as group_id, - COALESCE(g.name, '(无分组)') as group_name, + COALESCE(g.name, '') as group_name, COUNT(*) as requests, COALESCE(SUM(ul.input_tokens + ul.output_tokens + ul.cache_creation_tokens + ul.cache_read_tokens), 0) as total_tokens, COALESCE(SUM(ul.total_cost), 0) as cost, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index e2f5bd1b..d2b108f5 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -54,13 +54,8 @@ func TestAPIContracts(t *testing.T) { "username": "alice", "role": "user", "balance": 12.5, - "balance_notify_enabled": false, - "balance_notify_extra_emails": null, - "balance_notify_threshold": null, - "balance_notify_threshold_type": "", "concurrency": 5, "status": "active", - "total_recharged": 0, "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z", @@ -769,13 +764,10 @@ func TestAPIContracts(t *testing.T) { "payment_cancel_rate_limit_unit": "", "payment_cancel_rate_limit_window_mode": "", "balance_low_notify_enabled": false, + "account_quota_notify_enabled": false, "balance_low_notify_threshold": 0, "balance_low_notify_recharge_url": "", - "account_quota_notify_enabled": false, "account_quota_notify_emails": [], - "channel_monitor_enabled": true, - "channel_monitor_default_interval_seconds": 60, - "available_channels_enabled": false, "wechat_connect_enabled": false, "wechat_connect_app_id": "", "wechat_connect_app_secret_configured": false, @@ -983,10 +975,7 @@ func TestAPIContracts(t *testing.T) { "auth_source_default_wechat_subscriptions": [], "auth_source_default_wechat_grant_on_signup": false, "auth_source_default_wechat_grant_on_first_bind": false, - "force_email_on_third_party_signup": false, - "channel_monitor_enabled": true, - "channel_monitor_default_interval_seconds": 60, - "available_channels_enabled": false + "force_email_on_third_party_signup": false } }`, }, @@ -1457,10 +1446,6 @@ func (s *stubAccountRepo) FindByExtraField(ctx context.Context, key string, valu return nil, errors.New("not implemented") } -func (s *stubAccountRepo) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) { - return nil, errors.New("not implemented") -} - func (s *stubAccountRepo) Update(ctx context.Context, account *service.Account) error { return errors.New("not implemented") } diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index d41fe7da..3189a729 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -30,10 +30,6 @@ type AccountRepository interface { GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) // FindByExtraField 根据 extra 字段中的键值对查找账号 FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) - // CountByTLSFingerprintProfile 按 TLS 指纹模板 ID 聚合每个模板当前被多少账号绑定。 - // 返回 map[profile_id]count;未绑定任何账号的 profile 不出现在 map 中。 - // 查询走 108_add_tls_fingerprint_profile_id_index.sql 的表达式索引。 - CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) // ListCRSAccountIDs returns a map of crs_account_id -> local account ID // for all accounts that have been synced from CRS. ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index f2e0876e..81169a02 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -58,10 +58,6 @@ func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, valu panic("unexpected FindByExtraField call") } -func (s *accountRepoStub) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) { - panic("unexpected CountByTLSFingerprintProfile call") -} - func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { panic("unexpected ListCRSAccountIDs call") } diff --git a/backend/internal/service/admin_service_bulk_update_test.go b/backend/internal/service/admin_service_bulk_update_test.go index e90ec93a..4845d87c 100644 --- a/backend/internal/service/admin_service_bulk_update_test.go +++ b/backend/internal/service/admin_service_bulk_update_test.go @@ -43,16 +43,6 @@ func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID i return nil } -func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { - if err, ok := s.listByGroupErr[groupID]; ok { - return nil, err - } - if rows, ok := s.listByGroupData[groupID]; ok { - return rows, nil - } - return nil, nil -} - func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) { s.getByIDsCalled = true s.getByIDsIDs = append([]int64{}, ids...) @@ -73,6 +63,16 @@ func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Ac return nil, errors.New("account not found") } +func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) { + if err, ok := s.listByGroupErr[groupID]; ok { + return nil, err + } + if rows, ok := s.listByGroupData[groupID]; ok { + return rows, nil + } + return nil, nil +} + // TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。 func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) { repo := &accountRepoStubForBulkUpdate{} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index 1970cc45..dbd18a20 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -170,11 +170,11 @@ func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, emai return nil } -func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { +func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { return 0, nil } -func (s *emailCacheStub) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { +func (s *emailCacheStub) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { return 0, nil } diff --git a/backend/internal/service/concurrency_service_test.go b/backend/internal/service/concurrency_service_test.go index 65f1c7c5..078ba0dc 100644 --- a/backend/internal/service/concurrency_service_test.go +++ b/backend/internal/service/concurrency_service_test.go @@ -87,7 +87,6 @@ func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) { return c.usersLoadBatch, c.usersLoadErr } - func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error { return c.cleanupErr } diff --git a/backend/internal/service/error_passthrough_runtime_test.go b/backend/internal/service/error_passthrough_runtime_test.go index 2b7bbf60..7032d15b 100644 --- a/backend/internal/service/error_passthrough_runtime_test.go +++ b/backend/internal/service/error_passthrough_runtime_test.go @@ -220,7 +220,7 @@ func TestApplyErrorPassthroughRule_SkipMonitoringSetsContextKey(t *testing.T) { v, exists := c.Get(OpsSkipPassthroughKey) assert.True(t, exists, "OpsSkipPassthroughKey should be set when skip_monitoring=true") boolVal, ok := v.(bool) - assert.True(t, ok, "value should be a bool") + assert.True(t, ok, "value should be bool") assert.True(t, boolVal) } diff --git a/backend/internal/service/error_policy_test.go b/backend/internal/service/error_policy_test.go index 4099399b..297a954c 100644 --- a/backend/internal/service/error_policy_test.go +++ b/backend/internal/service/error_policy_test.go @@ -110,12 +110,13 @@ func TestCheckErrorPolicy(t *testing.T) { expected: ErrorPolicyTempUnscheduled, }, { - // Gemini OAuth 401 second hit 会升级为 error(返回 None,交由默认错误逻辑处理)。 - name: "temp_unschedulable_401_second_hit_gemini_escalates", + // Antigravity 401 不走升级逻辑(由 applyErrorPolicy 的 temp_unschedulable_rules 自行控制), + // second hit 仍然返回 TempUnscheduled。 + name: "temp_unschedulable_401_second_hit_antigravity_stays_temp", account: &Account{ ID: 15, Type: AccountTypeOAuth, - Platform: PlatformGemini, // 非 Antigravity 平台 401 second hit 升级 + Platform: PlatformAntigravity, TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, Credentials: map[string]any{ "temp_unschedulable_enabled": true, @@ -130,29 +131,7 @@ func TestCheckErrorPolicy(t *testing.T) { }, statusCode: 401, body: []byte(`unauthorized`), - expected: ErrorPolicyNone, // Gemini 401 second hit 升级为 error - }, - { - name: "temp_unschedulable_401_antigravity_no_escalation", - account: &Account{ - ID: 16, - Type: AccountTypeOAuth, - Platform: PlatformAntigravity, // Antigravity 跳过 401 升级,由 rules 正常处理 - TempUnschedulableReason: `{"status_code":401,"until_unix":1735689600}`, - Credentials: map[string]any{ - "temp_unschedulable_enabled": true, - "temp_unschedulable_rules": []any{ - map[string]any{ - "error_code": float64(401), - "keywords": []any{"unauthorized"}, - "duration_minutes": float64(10), - }, - }, - }, - }, - statusCode: 401, - body: []byte(`unauthorized`), - expected: ErrorPolicyTempUnscheduled, // Antigravity 不升级,继续走规则匹配 + expected: ErrorPolicyTempUnscheduled, }, { name: "temp_unschedulable_body_miss_returns_none", diff --git a/backend/internal/service/gateway_hotpath_optimization_test.go b/backend/internal/service/gateway_hotpath_optimization_test.go index 86c45743..e5bf49b8 100644 --- a/backend/internal/service/gateway_hotpath_optimization_test.go +++ b/backend/internal/service/gateway_hotpath_optimization_test.go @@ -143,6 +143,7 @@ func (s *stickyGatewayCacheHotpathStub) RefreshSessionTTL(ctx context.Context, g func (s *stickyGatewayCacheHotpathStub) DeleteSessionAccountID(ctx context.Context, groupID int64, sessionHash string) error { return nil } + func (s *modelsListAccountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) { s.listByGroupCalls.Add(1) if s.err != nil { diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index 93a2a583..72832837 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -82,10 +82,6 @@ func (m *mockAccountRepoForPlatform) FindByExtraField(ctx context.Context, key s return nil, nil } -func (m *mockAccountRepoForPlatform) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) { - return nil, nil -} - func (m *mockAccountRepoForPlatform) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 360b2d40..5e09b95a 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -71,10 +71,6 @@ func (m *mockAccountRepoForGemini) FindByExtraField(ctx context.Context, key str return nil, nil } -func (m *mockAccountRepoForGemini) CountByTLSFingerprintProfile(ctx context.Context) (map[int64]int, error) { - return nil, nil -} - func (m *mockAccountRepoForGemini) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) { return nil, nil } diff --git a/backend/internal/service/openai_account_scheduler.go b/backend/internal/service/openai_account_scheduler.go index bb8ee35e..808f1229 100644 --- a/backend/internal/service/openai_account_scheduler.go +++ b/backend/internal/service/openai_account_scheduler.go @@ -781,7 +781,7 @@ func (s *defaultOpenAIAccountScheduler) isAccountRequestCompatible(account *Acco if account == nil { return false } - if req.RequestedModel != "" && !account.IsOpenAIPassthroughEnabled() && !account.IsModelSupported(req.RequestedModel) { + if req.RequestedModel != "" && !account.IsModelSupported(req.RequestedModel) { return false } return account.SupportsOpenAIImageCapability(req.RequiredImageCapability) diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 457309d3..a68c9b67 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -187,9 +187,13 @@ func applyCodexOAuthTransform(reqBody map[string]any, isCodexCLI bool, isCompact } func normalizeCodexModel(model string) string { + model = strings.TrimSpace(model) if model == "" { return "gpt-5.4" } + if isOpenAIImageGenerationModel(model) { + return model + } modelID := model if strings.Contains(modelID, "/") { @@ -231,6 +235,78 @@ func normalizeCodexModel(model string) string { return "gpt-5.4" } +func hasOpenAIImageGenerationTool(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false + } + tools, ok := rawTools.([]any) + if !ok { + return false + } + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok { + continue + } + if strings.TrimSpace(firstNonEmptyString(toolMap["type"])) == "image_generation" { + return true + } + } + return false +} + +func normalizeOpenAIResponsesImageGenerationTools(reqBody map[string]any) bool { + rawTools, ok := reqBody["tools"] + if !ok || rawTools == nil { + return false + } + tools, ok := rawTools.([]any) + if !ok { + return false + } + + modified := false + for _, rawTool := range tools { + toolMap, ok := rawTool.(map[string]any) + if !ok || strings.TrimSpace(firstNonEmptyString(toolMap["type"])) != "image_generation" { + continue + } + if _, ok := toolMap["output_format"]; !ok { + if value := strings.TrimSpace(firstNonEmptyString(toolMap["format"])); value != "" { + toolMap["output_format"] = value + modified = true + } + } + if _, ok := toolMap["output_compression"]; !ok { + if value, exists := toolMap["compression"]; exists && value != nil { + toolMap["output_compression"] = value + modified = true + } + } + if _, ok := toolMap["format"]; ok { + delete(toolMap, "format") + modified = true + } + if _, ok := toolMap["compression"]; ok { + delete(toolMap, "compression") + modified = true + } + } + return modified +} + +func validateOpenAIResponsesImageModel(reqBody map[string]any, model string) error { + if !hasOpenAIImageGenerationTool(reqBody) { + return nil + } + model = strings.TrimSpace(model) + if !isOpenAIImageGenerationModel(model) { + return nil + } + return fmt.Errorf("/v1/responses image_generation requests require a Responses-capable text model; image-only model %q is not allowed", model) +} + func normalizeOpenAIModelForUpstream(account *Account, model string) string { if account == nil || account.Type == AccountTypeOAuth { return normalizeCodexModel(model) diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 22264f5e..f08e4b15 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -217,6 +217,42 @@ func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunction require.Equal(t, "bash", first["name"]) } +func TestNormalizeOpenAIResponsesImageGenerationTools_RewritesLegacyFields(t *testing.T) { + reqBody := map[string]any{ + "tools": []any{ + map[string]any{ + "type": "image_generation", + "format": "png", + "compression": 60, + }, + }, + } + + modified := normalizeOpenAIResponsesImageGenerationTools(reqBody) + require.True(t, modified) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "png", first["output_format"]) + require.Equal(t, 60, first["output_compression"]) + _, hasFormat := first["format"] + require.False(t, hasFormat) + _, hasCompression := first["compression"] + require.False(t, hasCompression) +} + +func TestValidateOpenAIResponsesImageModel_RejectsImageOnlyModel(t *testing.T) { + err := validateOpenAIResponsesImageModel(map[string]any{ + "tools": []any{ + map[string]any{"type": "image_generation"}, + }, + }, "gpt-image-2") + + require.ErrorContains(t, err, `/v1/responses image_generation requests require a Responses-capable text model`) +} + func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 diff --git a/backend/internal/service/openai_gateway_chat_completions.go b/backend/internal/service/openai_gateway_chat_completions.go index 09763ea2..663066a3 100644 --- a/backend/internal/service/openai_gateway_chat_completions.go +++ b/backend/internal/service/openai_gateway_chat_completions.go @@ -151,38 +151,23 @@ func (s *OpenAIGatewayService) ForwardAsChatCompletions( } logger.L().Debug("openai chat_completions: model mapping applied", logFields...) - { + if account.Type == AccountTypeOAuth { var reqBody map[string]any if err := json.Unmarshal(responsesBody, &reqBody); err != nil { return nil, fmt.Errorf("unmarshal for codex transform: %w", err) } - modified := false - if account.Type == AccountTypeOAuth { - codexResult := applyCodexOAuthTransform(reqBody, false, false) - modified = codexResult.Modified - if codexResult.NormalizedModel != "" { - upstreamModel = codexResult.NormalizedModel - } - if codexResult.PromptCacheKey != "" { - promptCacheKey = codexResult.PromptCacheKey - } else if promptCacheKey != "" { - reqBody["prompt_cache_key"] = promptCacheKey - } - } else { - // 非 OAuth 账号也需要提取 system 消息并注入 instructions, - // 否则上游 GPT-5/Codex 等模型会报 "Instructions are required"。 - if extractSystemMessagesFromInput(reqBody) { - modified = true - } - if applyInstructions(reqBody, false) { - modified = true - } + codexResult := applyCodexOAuthTransform(reqBody, false, false) + if codexResult.NormalizedModel != "" { + upstreamModel = codexResult.NormalizedModel } - if modified { - responsesBody, err = json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("remarshal after codex transform: %w", err) - } + if codexResult.PromptCacheKey != "" { + promptCacheKey = codexResult.PromptCacheKey + } else if promptCacheKey != "" { + reqBody["prompt_cache_key"] = promptCacheKey + } + responsesBody, err = json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("remarshal after codex transform: %w", err) } } diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index f3c12048..534ffeee 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -1503,7 +1503,7 @@ func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Contex if !acc.IsSchedulable() { continue } - if requestedModel != "" && !acc.IsOpenAIPassthroughEnabled() && !acc.IsModelSupported(requestedModel) { + if requestedModel != "" && !acc.IsModelSupported(requestedModel) { continue } if needsUpstreamCheck && s.isUpstreamModelRestrictedByChannel(ctx, *groupID, acc, requestedModel) { @@ -1665,7 +1665,7 @@ func (s *OpenAIGatewayService) resolveFreshSchedulableOpenAIAccount(ctx context. if !fresh.IsSchedulable() || !fresh.IsOpenAI() { return nil } - if requestedModel != "" && !fresh.IsOpenAIPassthroughEnabled() && !fresh.IsModelSupported(requestedModel) { + if requestedModel != "" && !fresh.IsModelSupported(requestedModel) { return nil } return fresh @@ -1935,6 +1935,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("instructions", "You are a helpful coding assistant.") } + if normalizeOpenAIResponsesImageGenerationTools(reqBody) { + bodyModified = true + disablePatch() + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Normalized /responses image_generation tool payload") + } + // 对所有请求执行模型映射(包含 Codex CLI)。 billingModel := account.GetMappedModel(reqModel) if billingModel != reqModel { @@ -1944,6 +1950,26 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco markPatchSet("model", billingModel) } upstreamModel := billingModel + if err := validateOpenAIResponsesImageModel(reqBody, upstreamModel); err != nil { + setOpsUpstreamError(c, http.StatusBadRequest, err.Error(), "") + c.JSON(http.StatusBadRequest, gin.H{ + "error": gin.H{ + "type": "invalid_request_error", + "message": err.Error(), + "param": "model", + }, + }) + return nil, err + } + if hasOpenAIImageGenerationTool(reqBody) { + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] /responses image_generation request inbound_model=%s mapped_model=%s account_type=%s", + reqModel, + upstreamModel, + account.Type, + ) + } // OpenAI OAuth 账号走 ChatGPT internal Codex endpoint,需要将模型名规范化为 // 上游可识别的 Codex/GPT 系列。API Key 账号则应保留原始/映射后的模型名, diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index fb6bdc7f..7935376b 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -45,8 +45,11 @@ const ( openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare" openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements" - openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" - openAIImageRequirementsDiff = "0fffff" + openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" + openAIImageRequirementsDiff = "0fffff" + openAIImageLifecycleTimeout = 2 * time.Minute + openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download + openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part ) type OpenAIImagesCapability string @@ -148,6 +151,9 @@ func (s *OpenAIGatewayService) ParseOpenAIImagesRequest(c *gin.Context, body []b } applyOpenAIImagesDefaults(req) + if err := validateOpenAIImagesModel(req.Model); err != nil { + return nil, err + } req.SizeTier = normalizeOpenAIImageSizeTier(req.Size) req.RequiredCapability = classifyOpenAIImagesCapability(req) return req, nil @@ -214,7 +220,7 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope continue } - data, err := io.ReadAll(part) + data, err := io.ReadAll(io.LimitReader(part, openAIImageMaxUploadPartSize)) _ = part.Close() if err != nil { return fmt.Errorf("read multipart field %s: %w", name, err) @@ -295,6 +301,21 @@ func applyOpenAIImagesDefaults(req *OpenAIImagesRequest) { req.Model = "gpt-image-2" } +func isOpenAIImageGenerationModel(model string) bool { + return strings.HasPrefix(strings.ToLower(strings.TrimSpace(model)), "gpt-image-") +} + +func validateOpenAIImagesModel(model string) error { + model = strings.TrimSpace(model) + if isOpenAIImageGenerationModel(model) { + return nil + } + if model == "" { + return fmt.Errorf("images endpoint requires an image model") + } + return fmt.Errorf("images endpoint requires an image model, got %q", model) +} + func normalizeOpenAIImagesEndpointPath(path string) string { trimmed := strings.TrimSpace(path) switch { @@ -400,7 +421,21 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesAPIKey( if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { requestModel = mapped } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } upstreamModel := account.GetMappedModel(requestModel) + if err := validateOpenAIImagesModel(upstreamModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s upstream_model=%s endpoint=%s account_type=%s", + strings.TrimSpace(parsed.Model), + upstreamModel, + parsed.Endpoint, + account.Type, + ) forwardBody, forwardContentType, err := rewriteOpenAIImagesModel(body, parsed.ContentType, upstreamModel) if err != nil { return nil, err @@ -759,6 +794,17 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { requestModel = mapped } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d", + requestModel, + parsed.Endpoint, + account.Type, + len(parsed.Uploads), + ) token, _, err := s.GetAccessToken(ctx, account) if err != nil { @@ -844,8 +890,18 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( return nil, err } pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d", + conversationID, + len(pointerInfos), + countOpenAIFileServicePointerInfos(pointerInfos), + countOpenAIDirectImageAssets(pointerInfos), + ) + lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout) + defer releaseLifecycleCtx() if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { - polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID) + polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID) if pollErr != nil { return nil, s.wrapOpenAIImageBackendError(ctx, c, account, pollErr) } @@ -853,10 +909,11 @@ func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( } pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) if len(pointerInfos) == 0 { + logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID) return nil, fmt.Errorf("openai image conversation returned no downloadable images") } - responseBody, imageCount, err := buildOpenAIImageResponse(ctx, client, headers, conversationID, pointerInfos) + responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos) if err != nil { return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) } @@ -1283,8 +1340,11 @@ func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMess } type openAIImagePointerInfo struct { - Pointer string - Prompt string + Pointer string + DownloadURL string + B64JSON string + MimeType string + Prompt string } type openAIImageToolMessage struct { @@ -1336,10 +1396,6 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { if len(body) == 0 { return nil } - matches := openAIImagePointerMatches(body) - if len(matches) == 0 { - return nil - } prompt := "" for _, path := range []string{ "message.metadata.dalle.prompt", @@ -1351,11 +1407,12 @@ func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { break } } + matches := openAIImagePointerMatches(body) out := make([]openAIImagePointerInfo, 0, len(matches)) for _, pointer := range matches { out = append(out, openAIImagePointerInfo{Pointer: pointer, Prompt: prompt}) } - return out + return mergeOpenAIImagePointerInfos(out, collectOpenAIImageInlineAssets(body, prompt)) } func openAIImagePointerMatches(body []byte) []string { @@ -1394,27 +1451,72 @@ func mergeOpenAIImagePointerInfos(existing []openAIImagePointerInfo, next []open seen := make(map[string]openAIImagePointerInfo, len(existing)+len(next)) out := make([]openAIImagePointerInfo, 0, len(existing)+len(next)) for _, item := range existing { - seen[item.Pointer] = item + if key := item.identityKey(); key != "" { + seen[key] = item + } out = append(out, item) } for _, item := range next { - if existingItem, ok := seen[item.Pointer]; ok { - if existingItem.Prompt == "" && item.Prompt != "" { + key := item.identityKey() + if key == "" { + continue + } + if existingItem, ok := seen[key]; ok { + merged := mergeOpenAIImagePointerInfo(existingItem, item) + if merged != existingItem { for i := range out { - if out[i].Pointer == item.Pointer { - out[i].Prompt = item.Prompt + if out[i].identityKey() == key { + out[i] = merged break } } + seen[key] = merged } continue } - seen[item.Pointer] = item + seen[key] = item out = append(out, item) } return out } +func (i openAIImagePointerInfo) identityKey() string { + switch { + case strings.TrimSpace(i.Pointer) != "": + return "pointer:" + strings.TrimSpace(i.Pointer) + case strings.TrimSpace(i.DownloadURL) != "": + return "download:" + strings.TrimSpace(i.DownloadURL) + case strings.TrimSpace(i.B64JSON) != "": + b64 := strings.TrimSpace(i.B64JSON) + if len(b64) > 64 { + b64 = b64[:64] + } + return "b64:" + b64 + default: + return "" + } +} + +func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIImagePointerInfo { + merged := existing + if strings.TrimSpace(merged.Pointer) == "" { + merged.Pointer = next.Pointer + } + if strings.TrimSpace(merged.DownloadURL) == "" { + merged.DownloadURL = next.DownloadURL + } + if strings.TrimSpace(merged.B64JSON) == "" { + merged.B64JSON = next.B64JSON + } + if strings.TrimSpace(merged.MimeType) == "" { + merged.MimeType = next.MimeType + } + if strings.TrimSpace(merged.Prompt) == "" { + merged.Prompt = next.Prompt + } + return merged +} + func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool { for _, item := range items { if strings.HasPrefix(item.Pointer, "file-service://") { @@ -1424,6 +1526,26 @@ func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool { return false } +func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int { + count := 0 + for _, item := range items { + if strings.HasPrefix(item.Pointer, "file-service://") { + count++ + } + } + return count +} + +func countOpenAIDirectImageAssets(items []openAIImagePointerInfo) int { + count := 0 + for _, item := range items { + if strings.TrimSpace(item.DownloadURL) != "" || strings.TrimSpace(item.B64JSON) != "" { + count++ + } + } + return count +} + func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo { if !hasOpenAIFileServicePointerInfos(items) { return items @@ -1591,11 +1713,7 @@ func buildOpenAIImageResponse( } items := make([]responseItem, 0, len(pointers)) for _, pointer := range pointers { - downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) - if err != nil { - return nil, 0, err - } - data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + data, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, pointer) if err != nil { return nil, 0, err } @@ -1615,6 +1733,136 @@ func buildOpenAIImageResponse( return body, len(items), nil } +func resolveOpenAIImageBytes( + ctx context.Context, + client *req.Client, + headers http.Header, + conversationID string, + pointer openAIImagePointerInfo, +) ([]byte, error) { + if normalized := normalizeOpenAIImageBase64(pointer.B64JSON); normalized != "" { + return base64.StdEncoding.DecodeString(normalized) + } + if downloadURL := strings.TrimSpace(pointer.DownloadURL); downloadURL != "" { + return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) + } + if strings.TrimSpace(pointer.Pointer) == "" { + return nil, fmt.Errorf("image asset is missing pointer, url, and base64 data") + } + downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) + if err != nil { + return nil, err + } + return downloadOpenAIImageBytes(ctx, client, headers, downloadURL) +} + +func normalizeOpenAIImageBase64(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + if strings.HasPrefix(strings.ToLower(raw), "data:") { + if idx := strings.Index(raw, ","); idx >= 0 && idx+1 < len(raw) { + raw = raw[idx+1:] + } + } + raw = strings.TrimSpace(raw) + raw = strings.TrimRight(raw, "=") + strings.Repeat("=", (4-len(raw)%4)%4) + if raw == "" { + return "" + } + if _, err := base64.StdEncoding.DecodeString(raw); err != nil { + return "" + } + return raw +} + +func collectOpenAIImageInlineAssets(body []byte, fallbackPrompt string) []openAIImagePointerInfo { + if len(body) == 0 || !gjson.ValidBytes(body) { + return nil + } + var decoded any + if err := json.Unmarshal(body, &decoded); err != nil { + return nil + } + var out []openAIImagePointerInfo + walkOpenAIImageInlineAssets(decoded, strings.TrimSpace(fallbackPrompt), &out) + return out +} + +func walkOpenAIImageInlineAssets(node any, prompt string, out *[]openAIImagePointerInfo) { + switch value := node.(type) { + case map[string]any: + localPrompt := prompt + for _, key := range []string{"revised_prompt", "image_gen_title", "prompt"} { + if v, ok := value[key].(string); ok && strings.TrimSpace(v) != "" { + localPrompt = strings.TrimSpace(v) + break + } + } + item := openAIImagePointerInfo{ + Prompt: localPrompt, + Pointer: firstNonEmptyString(value["asset_pointer"], value["pointer"]), + DownloadURL: firstNonEmptyString(value["download_url"], value["url"], value["image_url"]), + B64JSON: firstNonEmptyString(value["b64_json"], value["base64"], value["image_base64"]), + MimeType: firstNonEmptyString(value["mime_type"], value["mimeType"], value["content_type"]), + } + switch { + case strings.HasPrefix(strings.TrimSpace(item.Pointer), "file-service://"), + strings.HasPrefix(strings.TrimSpace(item.Pointer), "sediment://"), + isLikelyOpenAIImageDownloadURL(item.DownloadURL), + normalizeOpenAIImageBase64(item.B64JSON) != "": + *out = append(*out, item) + } + for _, child := range value { + walkOpenAIImageInlineAssets(child, localPrompt, out) + } + case []any: + for _, child := range value { + walkOpenAIImageInlineAssets(child, prompt, out) + } + } +} + +func firstNonEmptyString(values ...any) string { + for _, value := range values { + if s, ok := value.(string); ok && strings.TrimSpace(s) != "" { + return strings.TrimSpace(s) + } + } + return "" +} + +func isLikelyOpenAIImageDownloadURL(raw string) bool { + raw = strings.TrimSpace(raw) + if raw == "" { + return false + } + if strings.HasPrefix(strings.ToLower(raw), "data:image/") { + return true + } + if !strings.HasPrefix(strings.ToLower(raw), "http://") && !strings.HasPrefix(strings.ToLower(raw), "https://") { + return false + } + lower := strings.ToLower(raw) + return strings.Contains(lower, "/download") || + strings.Contains(lower, ".png") || + strings.Contains(lower, ".jpg") || + strings.Contains(lower, ".jpeg") || + strings.Contains(lower, ".webp") +} + +func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + base := context.Background() + if ctx != nil { + base = context.WithoutCancel(ctx) + } + if timeout <= 0 { + return base, func() {} + } + return context.WithTimeout(base, timeout) +} + func fetchOpenAIImageDownloadURL( ctx context.Context, client *req.Client, @@ -1706,7 +1954,7 @@ func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers h if resp.StatusCode < 200 || resp.StatusCode >= 300 { return nil, newOpenAIImageStatusError(resp, "download image bytes failed") } - return io.ReadAll(resp.Body) + return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes)) } func handleOpenAIImageBackendError(resp *req.Response) error { diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 173d69ba..6aa1d5e5 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -2,6 +2,7 @@ package service import ( "bytes" + "context" "mime/multipart" "net/http" "net/http/httptest" @@ -103,3 +104,56 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_ExplicitSizeRequiresNative require.NotNil(t, parsed) require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } + +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-5.4","prompt":"draw a cat"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.Nil(t, parsed) + require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`) +} + +func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) { + items := collectOpenAIImagePointers([]byte(`{ + "revised_prompt": "cat astronaut", + "parts": [ + {"b64_json":"QUJD"}, + {"download_url":"https://files.example.com/image.png?sig=1"}, + {"asset_pointer":"file-service://file_123"} + ] + }`)) + + require.Len(t, items, 3) + var sawBase64, sawURL, sawPointer bool + for _, item := range items { + if item.B64JSON == "QUJD" { + sawBase64 = true + require.Equal(t, "cat astronaut", item.Prompt) + } + if item.DownloadURL == "https://files.example.com/image.png?sig=1" { + sawURL = true + } + if item.Pointer == "file-service://file_123" { + sawPointer = true + } + } + require.True(t, sawBase64) + require.True(t, sawURL) + require.True(t, sawPointer) +} + +func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) { + data, err := resolveOpenAIImageBytes(context.Background(), nil, nil, "", openAIImagePointerInfo{ + B64JSON: "data:image/png;base64,QUJD", + }) + require.NoError(t, err) + require.Equal(t, []byte("ABC"), data) +} diff --git a/backend/internal/service/openai_model_mapping_test.go b/backend/internal/service/openai_model_mapping_test.go index 35e7c250..f25863a8 100644 --- a/backend/internal/service/openai_model_mapping_test.go +++ b/backend/internal/service/openai_model_mapping_test.go @@ -91,6 +91,7 @@ func TestNormalizeCodexModel(t *testing.T) { "gpt-5.3-codex-spark-high": "gpt-5.3-codex-spark", "gpt-5.3-codex-spark-xhigh": "gpt-5.3-codex-spark", "gpt-5.3": "gpt-5.3-codex", + "gpt-image-2": "gpt-image-2", } for input, expected := range cases { diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 2bf48702..106ec9f7 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -812,6 +812,16 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { return openAIGPT54FallbackPricing } + if isOpenAIImageGenerationModel(model) { + for _, candidate := range []string{"gpt-image-2", "gpt-image-1.5", "gpt-image-1"} { + if pricing, ok := s.pricingData[candidate]; ok { + logger.LegacyPrintf("service.pricing", "[Pricing] OpenAI image fallback matched %s -> %s", model, candidate) + return pricing + } + } + return nil + } + // 最终回退到 DefaultTestModel defaultModel := strings.ToLower(openai.DefaultTestModel) if pricing, ok := s.pricingData[defaultModel]; ok { diff --git a/backend/internal/service/pricing_service_test.go b/backend/internal/service/pricing_service_test.go index 13a5c70c..e2bd7cf3 100644 --- a/backend/internal/service/pricing_service_test.go +++ b/backend/internal/service/pricing_service_test.go @@ -128,6 +128,21 @@ func TestGetModelPricing_Gpt54NanoUsesDedicatedStaticFallbackWhenRemoteMissing(t require.Zero(t, got.LongContextInputTokenThreshold) } +func TestGetModelPricing_ImageModelDoesNotFallbackToTextModel(t *testing.T) { + imagePricing := &LiteLLMModelPricing{InputCostPerToken: 3} + textPricing := &LiteLLMModelPricing{InputCostPerToken: 9} + + svc := &PricingService{ + pricingData: map[string]*LiteLLMModelPricing{ + "gpt-image-2": imagePricing, + "gpt-5.4": textPricing, + }, + } + + got := svc.GetModelPricing("gpt-image-3") + require.Same(t, imagePricing, got) +} + func TestParsePricingData_PreservesPriorityAndServiceTierFields(t *testing.T) { raw := map[string]any{ "gpt-5.4": map[string]any{ diff --git a/backend/internal/service/ratelimit_session_window_test.go b/backend/internal/service/ratelimit_session_window_test.go index 77f36ae9..7796a85e 100644 --- a/backend/internal/service/ratelimit_session_window_test.go +++ b/backend/internal/service/ratelimit_session_window_test.go @@ -73,9 +73,6 @@ func (m *sessionWindowMockRepo) GetByCRSAccountID(context.Context, string) (*Acc func (m *sessionWindowMockRepo) FindByExtraField(context.Context, string, any) ([]Account, error) { panic("unexpected") } -func (m *sessionWindowMockRepo) CountByTLSFingerprintProfile(context.Context) (map[int64]int, error) { - panic("unexpected") -} func (m *sessionWindowMockRepo) ListCRSAccountIDs(context.Context) (map[string]int64, error) { panic("unexpected") } diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 0da0eb02..757c4025 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -683,7 +683,6 @@ type PublicSettingsInjectionPayload struct { // Feature flags — MUST match the opt-in/opt-out registry in // frontend/src/utils/featureFlags.ts. Missing a field here is the bug // that hid the "可用渠道" menu on page refresh. - ForceEmailOnThirdPartySignup bool `json:"force_email_on_third_party_signup"` ChannelMonitorEnabled bool `json:"channel_monitor_enabled"` ChannelMonitorDefaultIntervalSeconds int `json:"channel_monitor_default_interval_seconds"` AvailableChannelsEnabled bool `json:"available_channels_enabled"` @@ -736,7 +735,6 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any BalanceLowNotifyThreshold: settings.BalanceLowNotifyThreshold, BalanceLowNotifyRechargeURL: settings.BalanceLowNotifyRechargeURL, - ForceEmailOnThirdPartySignup: settings.ForceEmailOnThirdPartySignup, ChannelMonitorEnabled: settings.ChannelMonitorEnabled, ChannelMonitorDefaultIntervalSeconds: settings.ChannelMonitorDefaultIntervalSeconds, AvailableChannelsEnabled: settings.AvailableChannelsEnabled, diff --git a/backend/internal/service/sticky_session_test.go b/backend/internal/service/sticky_session_test.go index 02369b19..11ace7bd 100644 --- a/backend/internal/service/sticky_session_test.go +++ b/backend/internal/service/sticky_session_test.go @@ -122,8 +122,8 @@ func TestShouldClearStickySession(t *testing.T) { { name: "overloaded account", account: &Account{ - Status: StatusActive, - Schedulable: true, + Status: StatusActive, + Schedulable: true, OverloadUntil: &future, }, requestedModel: "", diff --git a/backend/migrations/130_fix_claude_code_template_userid.sql b/backend/migrations/130_fix_claude_code_template_userid.sql deleted file mode 100644 index 1591c566..00000000 --- a/backend/migrations/130_fix_claude_code_template_userid.sql +++ /dev/null @@ -1,36 +0,0 @@ --- Migration: 114_fix_claude_code_template_userid --- 113 的 seed 使用 legacy 格式的 metadata.user_id,但已部署环境此前是手工建的 --- 「Claude Code 伪装」模板(用新版 JSON-string 格式 user_id),113 的 ON CONFLICT --- DO NOTHING 不会覆盖。本 migration 定向修复这一条历史记录及其下游监控快照。 --- --- 安全性:WHERE 条件同时匹配 (provider, name) + user_id 以 '{' 开头, --- 所以: --- - 用户自己改过 user_id(或者 seed 本来就是 legacy)→ LIKE 不中,保持原状 --- - 用户改过 template name / provider → WHERE 不中,完全跳过 --- 幂等:第二次跑时 user_id 已经是 legacy 格式,LIKE '{%' 不中,UPDATE 0 行。 - -UPDATE channel_monitor_request_templates -SET body_override = jsonb_set( - body_override, - '{metadata,user_id}', - '"user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"'::jsonb, - false - ), - updated_at = NOW() -WHERE provider = 'anthropic' - AND name = 'Claude Code 伪装' - AND body_override #>> '{metadata,user_id}' LIKE '{%'; - --- 同步已应用此模板的监控快照(监控采用 snapshot 语义,只更新那些明显还是 seed 原样的)。 -UPDATE channel_monitors m -SET body_override = jsonb_set( - m.body_override, - '{metadata,user_id}', - '"user_0000000000000000000000000000000000000000000000000000000000000000_account_00000000-0000-0000-0000-000000000000_session_00000000-0000-0000-0000-000000000000"'::jsonb, - false - ) -FROM channel_monitor_request_templates t -WHERE m.template_id = t.id - AND t.provider = 'anthropic' - AND t.name = 'Claude Code 伪装' - AND m.body_override #>> '{metadata,user_id}' LIKE '{%'; diff --git a/backend/migrations/131_cleanup_claude_code_mimicry_fields.sql b/backend/migrations/131_cleanup_claude_code_mimicry_fields.sql deleted file mode 100644 index 2aab05df..00000000 --- a/backend/migrations/131_cleanup_claude_code_mimicry_fields.sql +++ /dev/null @@ -1,40 +0,0 @@ --- Migration: 115_cleanup_claude_code_mimicry_fields --- 清理 "Claude Code CLI 模拟套件 (A)" + "Signature Pool (B)" 回滚后遗留的 DB 状态。 --- --- 涉及回滚的功能: --- - 6d0e0562 feat(fingerprint): Claude Code CLI fingerprint mimicry suite --- - cfd95669 feat(tls-fingerprint): show binding count + fix randomized fingerprint visibility --- - 2df77c16/78de54b6/89d14a2 等 Signature Pool 相关 commits --- --- 需要清理的字段: --- 1. accounts.extra->>'tls_fingerprint_randomized' — cfd95669 引入的随机指纹标记 --- 2. accounts.extra->>'metadata' (内含 user_id) — sticky session UUID per Claude OAuth account --- 3. accounts.extra->>'sticky_session_user_id' — sticky session 备用键名(保险) --- --- 需要清理的索引: --- - idx_accounts_tls_fp_profile_id — 来自 migration 108,加速绑定数聚合查询。 --- 回滚后绑定数 UI 已移除,索引不再被任何查询使用,删除以释放空间。 --- --- 注意:上游已存在的 tls_fingerprint_profile_id / enable_tls_fingerprint 字段保留, --- 这些是上游 TLS fingerprint profile 功能本身的一部分,不在回滚范围内。 - --- 1) 删除 cfd95669 引入的索引 -DROP INDEX IF EXISTS idx_accounts_tls_fp_profile_id; - --- 2) 清理 sticky session UUID(仅 Claude/Anthropic OAuth/SetupToken 账号会写入此字段) -UPDATE accounts -SET extra = extra - 'metadata' -WHERE deleted_at IS NULL - AND extra ? 'metadata'; - --- 3) 清理随机指纹标记 -UPDATE accounts -SET extra = extra - 'tls_fingerprint_randomized' -WHERE deleted_at IS NULL - AND extra ? 'tls_fingerprint_randomized'; - --- 4) 清理可能残留的 sticky session 备用字段 -UPDATE accounts -SET extra = extra - 'sticky_session_user_id' -WHERE deleted_at IS NULL - AND extra ? 'sticky_session_user_id'; diff --git a/frontend/public/wechat-qr.jpg b/frontend/public/wechat-qr.jpg deleted file mode 100644 index 659068d8..00000000 Binary files a/frontend/public/wechat-qr.jpg and /dev/null differ diff --git a/frontend/src/__tests__/setup.ts b/frontend/src/__tests__/setup.ts index 0cb49219..decb2a37 100644 --- a/frontend/src/__tests__/setup.ts +++ b/frontend/src/__tests__/setup.ts @@ -36,22 +36,6 @@ class MockResizeObserver { globalThis.ResizeObserver = MockResizeObserver as unknown as typeof ResizeObserver -// Mock matchMedia (jsdom doesn't implement it). -// Default matches=true so desktop viewport queries pass and components that -// only lazy-load on mobile render content immediately in tests. -if (typeof window !== 'undefined' && !window.matchMedia) { - window.matchMedia = (query: string): MediaQueryList => ({ - matches: true, - media: query, - onchange: null, - addListener: vi.fn(), - removeListener: vi.fn(), - addEventListener: vi.fn(), - removeEventListener: vi.fn(), - dispatchEvent: vi.fn() - }) as MediaQueryList -} - // Vue Test Utils 全局配置 config.global.stubs = { // 可以在这里添加全局 stub diff --git a/frontend/src/api/admin/accounts.ts b/frontend/src/api/admin/accounts.ts index 9f476868..a146f1f7 100644 --- a/frontend/src/api/admin/accounts.ts +++ b/frontend/src/api/admin/accounts.ts @@ -17,7 +17,7 @@ import type { AdminDataPayload, AdminDataImportResult, CheckMixedChannelRequest, - CheckMixedChannelResponse, + CheckMixedChannelResponse } from '@/types' /** @@ -663,7 +663,7 @@ export const accountsAPI = { getAntigravityDefaultModelMapping, batchClearError, batchRefresh, - setPrivacy, + setPrivacy } export default accountsAPI diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index c8380523..13c30cf9 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -698,48 +698,6 @@ - -
-
-
- -

- {{ t('admin.accounts.allowOveragesTooltip') }} -

-
- -
-
- -
-
-
@@ -1009,11 +967,6 @@ const allOpenAIOAuth = computed(() => { ) }) -// 是否全部为 Antigravity 平台(allow_overages 仅在此条件下显示) -const allAntigravity = computed(() => - props.selectedPlatforms.length === 1 && props.selectedPlatforms[0] === 'antigravity' -) - // 是否全部为 Anthropic OAuth/SetupToken(RPM 配置仅在此条件下显示) const allAnthropicOAuthOrSetupToken = computed(() => { return ( @@ -1060,7 +1013,6 @@ const enableGroups = ref(false) const enableOpenAIPassthrough = ref(false) const enableOpenAIWSMode = ref(false) const enableRpmLimit = ref(false) -const enableAllowOverages = ref(false) // State - field values const submitting = ref(false) @@ -1088,7 +1040,6 @@ const bulkBaseRpm = ref(null) const bulkRpmStrategy = ref<'tiered' | 'sticky_exempt'>('tiered') const bulkRpmStickyBuffer = ref(null) const userMsgQueueMode = ref(null) -const allowOverages = ref(false) const umqModeOptions = computed(() => [ { value: '', label: t('admin.accounts.quotaControl.rpmLimit.umqModeOff') }, { value: 'throttle', label: t('admin.accounts.quotaControl.rpmLimit.umqModeThrottle') }, @@ -1330,13 +1281,6 @@ const buildUpdatePayload = (): Record | null => { umqExtra.user_msg_queue_enabled = false // 清理旧字段(JSONB merge) } - // Allow overages (Antigravity only) - if (enableAllowOverages.value) { - if (!updates.extra) updates.extra = {} - const overagesExtra = updates.extra as Record - overagesExtra.allow_overages = allowOverages.value - } - return Object.keys(updates).length > 0 ? updates : null } @@ -1401,7 +1345,6 @@ const handleSubmit = async () => { enableGroups.value || enableOpenAIWSMode.value || enableRpmLimit.value || - enableAllowOverages.value || userMsgQueueMode.value !== null if (!hasAnyFieldEnabled) { @@ -1495,7 +1438,6 @@ watch( enableOpenAIPassthrough.value = false enableOpenAIWSMode.value = false enableRpmLimit.value = false - enableAllowOverages.value = false // Reset all values baseUrl.value = '' @@ -1519,7 +1461,6 @@ watch( bulkRpmStrategy.value = 'tiered' bulkRpmStickyBuffer.value = null userMsgQueueMode.value = null - allowOverages.value = false // Reset mixed channel warning state showMixedChannelWarning.value = false diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index f910ad07..2130c9ab 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -3322,12 +3322,7 @@ watch( if (newVal) { // Load TLS fingerprint profiles adminAPI.tlsFingerprintProfiles.list() - .then(profiles => { - tlsFingerprintProfiles.value = profiles.map(p => ({ - id: p.id, - name: p.name, - })) - }) + .then(profiles => { tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name })) }) .catch(() => { tlsFingerprintProfiles.value = [] }) // Modal opened - fill related models allowedModels.value = [...getModelsByPlatform(form.platform)] diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index 17df51bb..59ca0b9c 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -2440,10 +2440,7 @@ watch( const loadTLSProfiles = async () => { try { const profiles = await adminAPI.tlsFingerprintProfiles.list() - tlsFingerprintProfiles.value = profiles.map(p => ({ - id: p.id, - name: p.name, - })) + tlsFingerprintProfiles.value = profiles.map(p => ({ id: p.id, name: p.name })) } catch { tlsFingerprintProfiles.value = [] } @@ -3312,5 +3309,4 @@ const handleMixedChannelConfirm = async () => { const handleMixedChannelCancel = () => { clearMixedChannelDialog() } - diff --git a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts index f758e6b0..7cdf7999 100644 --- a/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts +++ b/frontend/src/components/account/__tests__/AccountStatusIndicator.spec.ts @@ -122,7 +122,7 @@ describe('AccountStatusIndicator', () => { } }) - expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted') + expect(wrapper.text()).toContain('account.creditsExhausted') }) it('模型限流 + overages 启用 + AICredits key 生效 → 普通限流样式(积分耗尽,无 ⚡)', () => { @@ -157,6 +157,6 @@ describe('AccountStatusIndicator', () => { expect(wrapper.text()).toContain('CSon45') expect(wrapper.text()).not.toContain('⚡') // AICredits 积分耗尽状态应显示 - expect(wrapper.text()).toContain('admin.accounts.status.creditsExhausted') + expect(wrapper.text()).toContain('account.creditsExhausted') }) }) diff --git a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts index 7a739d44..9158da64 100644 --- a/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts +++ b/frontend/src/components/account/__tests__/AccountUsageCell.spec.ts @@ -15,10 +15,6 @@ vi.mock('@/api/admin', () => ({ } })) -vi.mock('@/utils/usageLoadQueue', () => ({ - enqueueUsageRequest: (_account: unknown, fn: () => Promise) => fn() -})) - vi.mock('vue-i18n', async () => { const actual = await vi.importActual('vue-i18n') return { @@ -389,117 +385,6 @@ describe('AccountUsageCell', () => { expect(wrapper.text()).toContain('7d|0|27700') }) - it('OpenAI OAuth 在 usage 请求失败时仍回退显示本地 codex 快照', async () => { - getUsage.mockRejectedValue(new Error('network error')) - const errorSpy = vi.spyOn(console, 'error').mockImplementation(() => {}) - - const wrapper = mount(AccountUsageCell, { - props: { - account: makeAccount({ - id: 2004, - platform: 'openai', - type: 'oauth', - extra: { - codex_usage_updated_at: '2099-03-07T10:00:00Z', - codex_5h_used_percent: 12, - codex_5h_reset_at: '2099-03-07T12:00:00Z', - codex_7d_used_percent: 34, - codex_7d_reset_at: '2099-03-13T12:00:00Z' - } - }) - }, - global: { - stubs: { - UsageProgressBar: { - props: ['label', 'utilization', 'resetsAt', 'windowStats', 'color'], - template: '
{{ label }}|{{ utilization }}|{{ resetsAt }}
' - }, - AccountQuotaInfo: true - } - } - }) - - await flushPromises() - - expect(getUsage).toHaveBeenCalledWith(2004) - expect(wrapper.text()).toContain('5h|12|2099-03-07T12:00:00.000Z') - expect(wrapper.text()).toContain('7d|34|2099-03-13T12:00:00.000Z') - errorSpy.mockRestore() - }) - - it('OpenAI OAuth 已限额时首屏优先等待重新查询的 usage,而不是先显示旧 codex 快照', async () => { - let resolveUsage: ((value: any) => void) | null = null - getUsage.mockReturnValue( - new Promise((resolve) => { - resolveUsage = resolve - }) - ) - - const wrapper = mount(AccountUsageCell, { - props: { - account: makeAccount({ - id: 2005, - platform: 'openai', - type: 'oauth', - rate_limit_reset_at: '2099-03-07T12:00:00Z', - extra: { - codex_5h_used_percent: 0, - codex_5h_reset_at: '2099-03-07T12:00:00Z', - codex_7d_used_percent: 0, - codex_7d_reset_at: '2099-03-13T12:00:00Z' - } - }) - }, - global: { - stubs: { - UsageProgressBar: { - props: ['label', 'utilization', 'resetsAt', 'windowStats', 'color'], - template: '
{{ label }}|{{ utilization }}|{{ windowStats?.tokens }}
' - }, - AccountQuotaInfo: true - } - } - }) - - await Promise.resolve() - - expect(getUsage).toHaveBeenCalledWith(2005) - expect(wrapper.text()).not.toContain('5h|0|') - expect(wrapper.text()).not.toContain('7d|0|') - - resolveUsage?.({ - five_hour: { - utilization: 100, - resets_at: '2026-03-07T12:00:00Z', - remaining_seconds: 3600, - window_stats: { - requests: 211, - tokens: 106540000, - cost: 38.13, - standard_cost: 38.13, - user_cost: 38.13 - } - }, - seven_day: { - utilization: 100, - resets_at: '2026-03-13T12:00:00Z', - remaining_seconds: 3600, - window_stats: { - requests: 211, - tokens: 106540000, - cost: 38.13, - standard_cost: 38.13, - user_cost: 38.13 - } - } - }) - - await flushPromises() - - expect(wrapper.text()).toContain('5h|100|106540000') - expect(wrapper.text()).toContain('7d|100|106540000') - }) - it('OpenAI OAuth 在行数据刷新但仍无 codex 快照时会重新拉取 usage', async () => { getUsage .mockResolvedValueOnce({ diff --git a/frontend/src/components/common/HelpTooltip.vue b/frontend/src/components/common/HelpTooltip.vue index e95052da..d2a2e48f 100644 --- a/frontend/src/components/common/HelpTooltip.vue +++ b/frontend/src/components/common/HelpTooltip.vue @@ -1,23 +1,69 @@ @@ -534,7 +410,6 @@ import { useI18n } from 'vue-i18n' import { useAuthStore, useAppStore } from '@/stores' import LocaleSwitcher from '@/components/common/LocaleSwitcher.vue' import Icon from '@/components/icons/Icon.vue' -import WechatServiceButton from '@/components/common/WechatServiceButton.vue' const { t } = useI18n() @@ -544,6 +419,7 @@ const appStore = useAppStore() // Site settings - directly from appStore (already initialized from injected config) const siteName = computed(() => appStore.cachedPublicSettings?.site_name || appStore.siteName || 'Sub2API') const siteLogo = computed(() => appStore.cachedPublicSettings?.site_logo || appStore.siteLogo || '') +const siteSubtitle = computed(() => appStore.cachedPublicSettings?.site_subtitle || 'AI API Gateway Platform') const docUrl = computed(() => appStore.cachedPublicSettings?.doc_url || appStore.docUrl || '') const homeContent = computed(() => appStore.cachedPublicSettings?.home_content || '') @@ -556,6 +432,9 @@ const isHomeContentUrl = computed(() => { // Theme const isDark = ref(document.documentElement.classList.contains('dark')) +// GitHub URL +const githubUrl = 'https://github.com/Wei-Shaw/sub2api' + // Auth state const isAuthenticated = computed(() => authStore.isAuthenticated) const isAdmin = computed(() => authStore.isAdmin) diff --git a/frontend/src/views/admin/SettingsView.vue b/frontend/src/views/admin/SettingsView.vue index 5d562eda..93beacc5 100644 --- a/frontend/src/views/admin/SettingsView.vue +++ b/frontend/src/views/admin/SettingsView.vue @@ -3751,90 +3751,91 @@
-
-
-

- {{ t("admin.settings.features.channelMonitor.title") }} -

-

- {{ t("admin.settings.features.channelMonitor.description") }} -

-

- - {{ t("admin.settings.features.channelMonitor.configureLink") }} - - -

-
-
-
-
- -

- {{ t("admin.settings.features.channelMonitor.enabledHint") }} -

-
- -
-
- { "admin.settings.payment.findProvider": "查看支持的支付方式", "admin.settings.openaiExperimentalScheduler.title": "OpenAI 实验调度策略", "admin.settings.openaiExperimentalScheduler.description": "默认关闭。开启后仅影响本网关在 OpenAI 账号间的实验性调度选择逻辑,不代表上游 OpenAI 官方能力。", + "admin.settings.site.uploadImage": "上传图片", + "admin.settings.site.remove": "移除", }; return { ...actual, @@ -240,6 +242,37 @@ const SelectStub = defineComponent({ }, }); +const ImageUploadStub = defineComponent({ + props: { + modelValue: { + type: String, + default: "", + }, + uploadLabel: { + type: String, + default: "", + }, + removeLabel: { + type: String, + default: "", + }, + placeholder: { + type: String, + default: "", + }, + }, + setup(props) { + return () => + h("div", { + class: "image-upload-stub", + "data-model-value": props.modelValue, + "data-upload-label": props.uploadLabel, + "data-remove-label": props.removeLabel, + "data-placeholder": props.placeholder, + }); + }, +}); + const baseSettingsResponse = { registration_enabled: true, email_verify_enabled: false, @@ -375,7 +408,7 @@ function mountView() { GroupBadge: true, GroupOptionItem: true, ProxySelector: true, - ImageUpload: true, + ImageUpload: ImageUploadStub, BackupSettings: true, }, }, @@ -582,7 +615,7 @@ describe("admin SettingsView payment visible method controls", () => { GroupBadge: true, GroupOptionItem: true, ProxySelector: true, - ImageUpload: true, + ImageUpload: ImageUploadStub, BackupSettings: true, }, }, @@ -608,6 +641,24 @@ describe("admin SettingsView payment visible method controls", () => { ); expect(wrapper.text()).not.toContain("OpenAI 高级调度器"); }); + + it("passes translated upload and remove labels to the payment help image uploader", async () => { + const wrapper = mountView(); + + await flushPromises(); + await openPaymentTab(wrapper); + + const imageUploads = wrapper.findAll(".image-upload-stub"); + expect(imageUploads.length).toBeGreaterThan(0); + + const paymentHelpImageUpload = imageUploads.find( + (node) => node.attributes("data-placeholder") === "admin.settings.payment.helpImagePlaceholder", + ); + + expect(paymentHelpImageUpload).toBeDefined(); + expect(paymentHelpImageUpload?.attributes("data-upload-label")).toBe("上传图片"); + expect(paymentHelpImageUpload?.attributes("data-remove-label")).toBe("移除"); + }); }); describe("admin SettingsView wechat connect controls", () => { diff --git a/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue b/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue index ca640ade..c7370ab5 100644 --- a/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue +++ b/frontend/src/views/admin/ops/components/OpsConcurrencyCard.vue @@ -122,7 +122,6 @@ const platformRows = computed((): SummaryRow[] => { available_accounts: availableAccounts, rate_limited_accounts: safeNumber(avail.rate_limit_count), - error_accounts: safeNumber(avail.error_count), total_concurrency: totalConcurrency, used_concurrency: usedConcurrency, @@ -162,6 +161,7 @@ const groupRows = computed((): SummaryRow[] => { total_accounts: totalAccounts, available_accounts: availableAccounts, rate_limited_accounts: safeNumber(avail.rate_limit_count), + error_accounts: safeNumber(avail.error_count), total_concurrency: totalConcurrency, used_concurrency: usedConcurrency, @@ -329,7 +329,6 @@ function formatDuration(seconds: number): string { } - watch( () => realtimeEnabled.value, async (enabled) => { diff --git a/frontend/src/views/user/PaymentView.vue b/frontend/src/views/user/PaymentView.vue index 1040d3f6..7cb4343d 100644 --- a/frontend/src/views/user/PaymentView.vue +++ b/frontend/src/views/user/PaymentView.vue @@ -311,6 +311,7 @@ interface CreateOrderOptions { wechatResumeToken?: string paymentType?: string isResume?: boolean + mobileQrFallbackAttempted?: boolean } interface WeixinJSBridgeLike { @@ -666,14 +667,15 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n submitting.value = true errorMessage.value = '' errorHintMessage.value = '' + const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value try { - const requestType = normalizeVisibleMethod(options.paymentType || selectedMethod.value) || options.paymentType || selectedMethod.value const payload = buildCreateOrderPayload({ amount: orderAmount, paymentType: requestType, orderType, planId, origin: typeof window !== 'undefined' ? window.location.origin : '', + isMobile: isMobileDevice(), isWechatBrowser: typeof window !== 'undefined' && /MicroMessenger/i.test(window.navigator.userAgent), }) if (options.openid) { @@ -747,8 +749,20 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n appStore.showInfo(t('payment.qr.cancelled')) resetPayment() } else if (errMsg && !errMsg.includes('ok')) { - applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) resetPayment() + const fallbackApplied = await attemptMobileQrFallback( + { reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, + { + orderAmount, + orderType, + planId, + paymentType: visibleMethod, + attempted: options.mobileQrFallbackAttempted === true, + }, + ) + if (!fallbackApplied) { + applyScenarioError({ reason: 'WECHAT_JSAPI_FAILED', message: errMsg }, visibleMethod) + } } else { const resultState = { ...decision.paymentState } resetPayment() @@ -756,7 +770,16 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } } catch (err: unknown) { resetPayment() - throw err + const fallbackApplied = await attemptMobileQrFallback(err, { + orderAmount, + orderType, + planId, + paymentType: visibleMethod, + attempted: options.mobileQrFallbackAttempted === true, + }) + if (!fallbackApplied) { + throw err + } } return } @@ -776,6 +799,14 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } else if (apiErr.reason === 'CANCEL_RATE_LIMITED') { errorMessage.value = t('payment.errors.cancelRateLimited') errorHintMessage.value = '' + } else if (await attemptMobileQrFallback(err, { + orderAmount, + orderType, + planId, + paymentType: requestType, + attempted: options.mobileQrFallbackAttempted === true, + })) { + return } else { const handled = applyScenarioError( err, @@ -795,6 +826,101 @@ async function createOrder(orderAmount: number, orderType: OrderType, planId?: n } } +interface MobileQrFallbackContext { + orderAmount: number + orderType: OrderType + planId?: number + paymentType: string + attempted: boolean +} + +function shouldFallbackToDesktopQr(err: unknown, paymentMethod: string, attempted: boolean): boolean { + if (attempted || !isMobileDevice()) { + return false + } + + const normalizedMethod = normalizeVisibleMethod(paymentMethod) || paymentMethod + const reason = typeof err === 'object' && err && 'reason' in err && typeof err.reason === 'string' + ? err.reason + : '' + const message = err instanceof Error + ? err.message + : (typeof err === 'object' && err && 'message' in err && typeof err.message === 'string' + ? err.message + : '') + const normalizedMessage = message.toLowerCase() + + if (normalizedMethod === 'wxpay') { + return reason === 'WECHAT_H5_NOT_AUTHORIZED' + || reason === 'WECHAT_PAYMENT_MP_NOT_CONFIGURED' + || reason === 'WECHAT_JSAPI_FAILED' + || reason === 'PAYMENT_GATEWAY_ERROR' + || reason === 'UNHANDLED_PAYMENT_SCENARIO' + || normalizedMessage.includes('weixinjsbridge is unavailable') + || normalizedMessage.includes('wechat_jsapi_unavailable') + } + + if (normalizedMethod === 'alipay') { + return reason === 'PAYMENT_GATEWAY_ERROR' || reason === 'UNHANDLED_PAYMENT_SCENARIO' + } + + return false +} + +async function attemptMobileQrFallback(err: unknown, context: MobileQrFallbackContext): Promise { + if (!shouldFallbackToDesktopQr(err, context.paymentType, context.attempted)) { + return false + } + + try { + const visibleMethod = normalizeVisibleMethod(context.paymentType) || context.paymentType + const payload = buildCreateOrderPayload({ + amount: context.orderAmount, + paymentType: visibleMethod, + orderType: context.orderType, + planId: context.planId, + origin: typeof window !== 'undefined' ? window.location.origin : '', + isMobile: false, + isWechatBrowser: false, + }) + const result = await paymentStore.createOrder(payload) as CreateOrderResult & { resume_token?: string } + const stripeMethod = visibleMethod === 'wxpay' ? 'wechat_pay' : 'alipay' + const stripeRouteUrl = result.client_secret + ? router.resolve({ + path: '/payment/stripe', + query: { + order_id: String(result.order_id), + client_secret: result.client_secret, + method: stripeMethod, + resume_token: result.resume_token || undefined, + }, + }).href + : '' + const decision = decidePaymentLaunch(result, { + visibleMethod, + orderType: context.orderType, + isMobile: false, + isWechatBrowser: false, + stripePopupUrl: stripeRouteUrl, + stripeRouteUrl, + }) + + if (decision.kind !== 'qr_waiting' || !decision.paymentState.qrCode) { + return false + } + + errorMessage.value = '' + errorHintMessage.value = '' + paymentState.value = decision.paymentState + paymentPhase.value = 'paying' + persistRecoverySnapshot(decision.recovery) + appStore.showWarning(t('payment.errors.mobilePaymentFallbackToQr')) + return true + } catch { + return false + } +} + function applyScenarioError(err: unknown, paymentMethod: string): boolean { const descriptor = describePaymentScenarioError(err, { paymentMethod, diff --git a/frontend/src/views/user/__tests__/PaymentView.spec.ts b/frontend/src/views/user/__tests__/PaymentView.spec.ts index d2683161..b4cd2cdd 100644 --- a/frontend/src/views/user/__tests__/PaymentView.spec.ts +++ b/frontend/src/views/user/__tests__/PaymentView.spec.ts @@ -16,6 +16,7 @@ const refreshUser = vi.hoisted(() => vi.fn()) const fetchActiveSubscriptions = vi.hoisted(() => vi.fn().mockResolvedValue(undefined)) const showError = vi.hoisted(() => vi.fn()) const showInfo = vi.hoisted(() => vi.fn()) +const showWarning = vi.hoisted(() => vi.fn()) const getCheckoutInfo = vi.hoisted(() => vi.fn()) const bridgeInvoke = vi.hoisted(() => vi.fn()) @@ -69,6 +70,7 @@ vi.mock('@/stores', () => ({ useAppStore: () => ({ showError, showInfo, + showWarning, }), })) @@ -193,6 +195,7 @@ describe('PaymentView WeChat JSAPI flow', () => { fetchActiveSubscriptions.mockReset().mockResolvedValue(undefined) showError.mockReset() showInfo.mockReset() + showWarning.mockReset() getCheckoutInfo.mockReset().mockResolvedValue(checkoutInfoFixture()) bridgeInvoke.mockReset() window.localStorage.clear() @@ -364,13 +367,24 @@ describe('PaymentView WeChat JSAPI flow', () => { }) }) - it('shows explicit H5 authorization guidance instead of failing silently', async () => { + it('falls back to QR flow when mobile WeChat payment is unavailable', async () => { routeState.query = { wechat_resume: '1', wechat_resume_token: 'resume-token-h5', payment_type: 'wxpay_direct', } - createOrder.mockRejectedValueOnce({ reason: 'WECHAT_H5_NOT_AUTHORIZED' }) + createOrder + .mockRejectedValueOnce({ reason: 'WECHAT_H5_NOT_AUTHORIZED' }) + .mockResolvedValueOnce({ + order_id: 778, + amount: 88, + pay_amount: 88, + fee_rate: 0, + expires_at: '2099-01-01T00:10:00.000Z', + payment_type: 'wxpay', + qr_code: 'weixin://wxpay/bizpayurl?pr=fallback-native', + out_trade_no: 'sub2_qr_778', + }) shallowMount(PaymentView, { global: { @@ -383,8 +397,18 @@ describe('PaymentView WeChat JSAPI flow', () => { await flushPromises() await flushPromises() - expect(showError).toHaveBeenCalledWith( - 'payment.errors.wechatH5NotAuthorized payment.errors.wechatOpenInWeChatHint', - ) + expect(createOrder).toHaveBeenNthCalledWith(1, expect.objectContaining({ + payment_type: 'wxpay', + is_mobile: true, + wechat_resume_token: 'resume-token-h5', + })) + expect(createOrder).toHaveBeenNthCalledWith(2, expect.objectContaining({ + payment_type: 'wxpay', + is_mobile: false, + payment_source: 'hosted_redirect', + })) + expect(showWarning).toHaveBeenCalledWith('payment.errors.mobilePaymentFallbackToQr') + expect(showError).not.toHaveBeenCalled() + expect(window.localStorage.getItem(PAYMENT_RECOVERY_STORAGE_KEY)).toContain('weixin://wxpay/bizpayurl?pr=fallback-native') }) }) diff --git a/frontend/vitest.config.ts b/frontend/vitest.config.ts index 9635c099..39568250 100644 --- a/frontend/vitest.config.ts +++ b/frontend/vitest.config.ts @@ -13,7 +13,6 @@ export default defineConfig({ test: { globals: true, environment: 'jsdom', - setupFiles: ['src/__tests__/setup.ts'], include: ['src/**/*.{test,spec}.{js,ts,jsx,tsx}'], exclude: ['node_modules', 'dist'], coverage: {