From fecfaae8dc510cf22706f1da45fedae0ac57fb59 Mon Sep 17 00:00:00 2001 From: Payne Fu Date: Wed, 4 Feb 2026 15:56:01 +0800 Subject: [PATCH 01/14] fix: remove unsupported safety_identifier and previous_response_id fields from upstream requests Co-Authored-By: Claude Opus 4.5 --- backend/internal/service/openai_gateway_service.go | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 742946d8..4658c694 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -846,10 +846,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } } - // Remove prompt_cache_retention (not supported by upstream OpenAI API) - if _, has := reqBody["prompt_cache_retention"]; has { - delete(reqBody, "prompt_cache_retention") - bodyModified = true + // Remove unsupported fields (not supported by upstream OpenAI API) + for _, unsupportedField := range []string{"prompt_cache_retention", "safety_identifier", "previous_response_id"} { + if _, has := reqBody[unsupportedField]; has { + delete(reqBody, unsupportedField) + bodyModified = true + } } } From 97a5c1ac1d190bc3a8396cba37507999c18ee5e9 Mon Sep 17 00:00:00 2001 From: Lemon Date: Wed, 4 Feb 2026 21:40:25 +0800 Subject: [PATCH 02/14] feat: add support for HTTP/2 Cleartext (h2c) connections --- backend/go.mod | 19 +++----- backend/go.sum | 47 ++++++-------------- backend/internal/config/config.go | 2 + backend/internal/server/http.go | 18 +++++++- backend/internal/server/middleware/logger.go | 8 +++- config.yaml | 3 ++ deploy/.env.example | 4 ++ deploy/config.example.yaml | 3 ++ 8 files changed, 55 insertions(+), 49 deletions(-) diff --git a/backend/go.mod b/backend/go.mod index 9a36a0f1..9234071b 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -25,10 +25,10 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/zeromicro/go-zero v1.9.4 - golang.org/x/crypto v0.46.0 - golang.org/x/net v0.48.0 + golang.org/x/crypto v0.47.0 + golang.org/x/net v0.49.0 golang.org/x/sync v0.19.0 - golang.org/x/term v0.38.0 + golang.org/x/term v0.39.0 gopkg.in/yaml.v3 v3.0.1 modernc.org/sqlite v1.44.3 ) @@ -75,12 +75,10 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect github.com/icholy/digest v1.1.0 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.4 // indirect @@ -89,7 +87,6 @@ require ( github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect github.com/mdelapenya/tlscert v0.2.0 // indirect github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect @@ -104,7 +101,6 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/morikuni/aec v1.0.0 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect - github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect github.com/pelletier/go-toml/v2 v2.2.2 // indirect @@ -114,7 +110,6 @@ require ( github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.57.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/uniseg v0.2.0 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect github.com/sagikazarmark/slog-shim v0.1.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect @@ -122,7 +117,6 @@ require ( github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/spf13/afero v1.11.0 // indirect github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/cobra v1.7.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/testcontainers/testcontainers-go v0.40.0 // indirect @@ -146,10 +140,9 @@ require ( go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect - golang.org/x/mod v0.30.0 // indirect - golang.org/x/sys v0.39.0 // indirect - golang.org/x/text v0.32.0 // indirect - golang.org/x/tools v0.39.0 // indirect + golang.org/x/mod v0.31.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/text v0.33.0 // indirect google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect diff --git a/backend/go.sum b/backend/go.sum index 371623ad..171995c7 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -46,7 +46,6 @@ github.com/containerd/platforms v0.2.1 h1:zvwtM3rz2YHPQsF2CHYM8+KtB5dvhISiXh5ZpS github.com/containerd/platforms v0.2.1/go.mod h1:XHCb+2/hzowdiut9rkudds9bE5yJ7npe7dG/wG+uFPw= github.com/cpuguy83/dockercfg v0.3.2 h1:DlJTyZGBDlXqUZ2Dk2Q3xHs/FtnooJJVaad2S9GKorA= github.com/cpuguy83/dockercfg v0.3.2/go.mod h1:sugsbF4//dDlL/i+S+rtpIWp+5h0BHJHfjj5/jFyUJc= -github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -117,8 +116,6 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= -github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= -github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -138,8 +135,6 @@ github.com/icholy/digest v1.1.0 h1:HfGg9Irj7i+IX1o1QAmPfIBNu/Q5A5Tu3n/MED9k9H4= github.com/icholy/digest v1.1.0/go.mod h1:QNrsSGQ5v7v9cReDI0+eyjsXGUoRSUZQHeQ5C4XLa0Y= github.com/imroc/req/v3 v3.57.0 h1:LMTUjNRUybUkTPn8oJDq8Kg3JRBOBTcnDhKu7mzupKI= github.com/imroc/req/v3 v3.57.0/go.mod h1:JL62ey1nvSLq81HORNcosvlf7SxZStONNqOprg0Pz00= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= @@ -175,9 +170,6 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= -github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -211,8 +203,6 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= -github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= -github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -240,13 +230,10 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= -github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= -github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= @@ -265,8 +252,6 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= -github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= @@ -350,14 +335,14 @@ go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTV golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.47.0 h1:V6e3FRj+n4dbpw86FJ8Fv7XVOql7TEwpHapKoMJ/GO8= +golang.org/x/crypto v0.47.0/go.mod h1:ff3Y9VzzKbwSSEzWqJsJVBnWmRwRSHt/6Op5n9bQc4A= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= -golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= -golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= +golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/net v0.49.0 h1:eeHFmOGUTtaaPSGNmjBKpbng9MulQsJURQUAfUwY++o= +golang.org/x/net v0.49.0/go.mod h1:/ysNB2EvaqvesRkuLAyjI1ycPZlQHM3q01F02UY/MV8= golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -369,20 +354,16 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= -golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.39.0 h1:RclSuaJf32jOqZz74CkPA9qFuVTX7vhLlpfj/IGWlqY= +golang.org/x/term v0.39.0/go.mod h1:yxzUCTP/U+FzoxfdKmLaA0RV1WgE0VY7hXBwKtY/4ww= +golang.org/x/text v0.33.0 h1:B3njUFyqtHDUI5jMn1YIr5B0IE2U0qck04r6d4KPAxE= +golang.org/x/text v0.33.0/go.mod h1:LuMebE6+rBincTi9+xWTY8TztLzKHc/9C1uBCG27+q8= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= -golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= -golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= -golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= -golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= -golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= -golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8= +golang.org/x/tools v0.40.0 h1:yLkxfA+Qnul4cs9QA3KnlFu0lVmd8JJfoq+E41uSutA= +golang.org/x/tools v0.40.0/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/genproto v0.0.0-20231106174013-bbf56f31fb17 h1:wpZ8pe2x1Q3f2KyT5f8oP/fa9rHAKgFPr/HZdNuS+PQ= google.golang.org/genproto/googleapis/api v0.0.0-20250929231259-57b25ae835d4 h1:8XJ4pajGwOlasW+L13MnEGA8W4115jJySQtVfS2/IBU= diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 84be445b..0790ed06 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -150,6 +150,7 @@ type ServerConfig struct { ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) + EnableH2C bool `mapstructure:"enable_h2c"` // 启用 HTTP/2 Cleartext (h2c) } type CORSConfig struct { @@ -687,6 +688,7 @@ func setDefaults() { viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) + viper.SetDefault("server.enable_h2c", false) // 默认关闭 h2c // CORS viper.SetDefault("cors.allowed_origins", []string{}) diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index 52d5c926..f22445b8 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -14,6 +14,8 @@ import ( "github.com/gin-gonic/gin" "github.com/google/wire" "github.com/redis/go-redis/v9" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) // ProviderSet 提供服务器层的依赖 @@ -56,9 +58,23 @@ func ProvideRouter( // ProvideHTTPServer 提供 HTTP 服务器 func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { + httpHandler := http.Handler(router) + + // 根据配置决定是否启用 H2C + if cfg.Server.EnableH2C { + httpHandler = h2c.NewHandler(router, &http2.Server{ + MaxConcurrentStreams: 250, // 最大并发流数量 + IdleTimeout: 300 * time.Second, + MaxReadFrameSize: 4 << 20, // 4MB + MaxUploadBufferPerConnection: 8 << 20, // 8MB + MaxUploadBufferPerStream: 2 << 20, // 2MB + }) + log.Println("HTTP/2 Cleartext (h2c) enabled") + } + return &http.Server{ Addr: cfg.Server.Address(), - Handler: router, + Handler: httpHandler, // ReadHeaderTimeout: 读取请求头的超时时间,防止慢速请求头攻击 ReadHeaderTimeout: time.Duration(cfg.Server.ReadHeaderTimeout) * time.Second, // IdleTimeout: 空闲连接超时时间,释放不活跃的连接资源 diff --git a/backend/internal/server/middleware/logger.go b/backend/internal/server/middleware/logger.go index a9beeb40..842efda9 100644 --- a/backend/internal/server/middleware/logger.go +++ b/backend/internal/server/middleware/logger.go @@ -34,12 +34,16 @@ func Logger() gin.HandlerFunc { // 客户端IP clientIP := c.ClientIP() - // 日志格式: [时间] 状态码 | 延迟 | IP | 方法 路径 - log.Printf("[GIN] %v | %3d | %13v | %15s | %-7s %s", + // 协议版本 + protocol := c.Request.Proto + + // 日志格式: [时间] 状态码 | 延迟 | IP | 协议 | 方法 路径 + log.Printf("[GIN] %v | %3d | %13v | %15s | %-6s | %-7s %s", endTime.Format("2006/01/02 - 15:04:05"), statusCode, latency, clientIP, + protocol, method, path, ) diff --git a/config.yaml b/config.yaml index 19f77221..e79500b7 100644 --- a/config.yaml +++ b/config.yaml @@ -23,6 +23,9 @@ server: # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] + # Enable HTTP/2 Cleartext (h2c) for client connections + # 启用 HTTP/2 Cleartext (h2c) 客户端连接 + enable_h2c: true # ============================================================================= # Run Mode Configuration diff --git a/deploy/.env.example b/deploy/.env.example index 25096c3d..a0c98bec 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -20,6 +20,10 @@ SERVER_PORT=8080 # Server mode: release or debug SERVER_MODE=release +# Enable HTTP/2 Cleartext (h2c) for client connections +# 启用 HTTP/2 Cleartext (h2c) 客户端连接 +SERVER_ENABLE_H2C=true + # 运行模式: standard (默认) 或 simple (内部自用) # standard: 完整 SaaS 功能,包含计费/余额校验;simple: 隐藏 SaaS 功能并跳过计费/余额校验 RUN_MODE=standard diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 6f5e9744..339b3d89 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -23,6 +23,9 @@ server: # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] + # Enable HTTP/2 Cleartext (h2c) for client connections + # 启用 HTTP/2 Cleartext (h2c) 客户端连接 + enable_h2c: true # ============================================================================= # Run Mode Configuration From 05af95dade5ea1def96d280755646f961a46fe53 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 09:53:20 +0800 Subject: [PATCH 03/14] =?UTF-8?q?fix(gateway):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=B7=A5=E5=85=B7=E5=90=8D=E8=BD=AC=E6=8D=A2=E7=A0=B4=E5=9D=8F?= =?UTF-8?q?=20Anthropic=20=E7=89=B9=E6=AE=8A=E5=B7=A5=E5=85=B7=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 未知工具名不再进行 PascalCase/snake_case 转换,保持原样透传。 修复 text_editor_20250728 等 Anthropic 特殊工具被错误转换的问题。 --- backend/internal/service/gateway_service.go | 46 ++++----------------- 1 file changed, 9 insertions(+), 37 deletions(-) diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index b2ac1efa..8c88c0a9 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -20,7 +20,6 @@ import ( "strings" "sync/atomic" "time" - "unicode" "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/claude" @@ -620,35 +619,6 @@ func stripToolPrefix(value string) string { return toolPrefixRe.ReplaceAllString(value, "") } -func toPascalCase(value string) string { - if value == "" { - return value - } - normalized := toolNameBoundaryRe.ReplaceAllString(value, " ") - tokens := make([]string, 0) - for _, token := range strings.Fields(normalized) { - expanded := toolNameCamelRe.ReplaceAllString(token, "$1 $2") - parts := strings.Fields(expanded) - if len(parts) > 0 { - tokens = append(tokens, parts...) - } - } - if len(tokens) == 0 { - return value - } - var builder strings.Builder - for _, token := range tokens { - lower := strings.ToLower(token) - if lower == "" { - continue - } - runes := []rune(lower) - runes[0] = unicode.ToUpper(runes[0]) - _, _ = builder.WriteString(string(runes)) - } - return builder.String() -} - func toSnakeCase(value string) string { if value == "" { return value @@ -664,16 +634,15 @@ func normalizeToolNameForClaude(name string, cache map[string]string) string { return name } stripped := stripToolPrefix(name) + // 只对已知的工具名进行映射,未知工具名保持原样 + // 避免破坏 Anthropic 特殊工具(如 text_editor_20250728) mapped, ok := claudeToolNameOverrides[strings.ToLower(stripped)] if !ok { - mapped = toPascalCase(stripped) - } - if mapped != "" && cache != nil && mapped != stripped { - cache[mapped] = stripped - } - if mapped == "" { return stripped } + if cache != nil && mapped != stripped { + cache[mapped] = stripped + } return mapped } @@ -682,15 +651,18 @@ func normalizeToolNameForOpenCode(name string, cache map[string]string) string { return name } stripped := stripToolPrefix(name) + // 优先从请求时建立的映射中查找 if cache != nil { if mapped, ok := cache[stripped]; ok { return mapped } } + // 已知工具名的硬编码映射 if mapped, ok := openCodeToolOverrides[stripped]; ok { return mapped } - return toSnakeCase(stripped) + // 未知工具名保持原样,避免破坏 Anthropic 特殊工具 + return stripped } func normalizeParamNameForOpenCode(name string, cache map[string]string) string { From fa3ea5ee4dec28294c40b162638c1dafb2d2e520 Mon Sep 17 00:00:00 2001 From: JIA-ss <627723154@qq.com> Date: Thu, 5 Feb 2026 11:41:25 +0800 Subject: [PATCH 04/14] feat(gateway): filter /v1/usage stats by API Key instead of UserID Previously the /v1/usage endpoint aggregated usage stats (today/total tokens, cost, RPM/TPM) across all API Keys belonging to the user. This made it impossible to distinguish usage from different API Keys (e.g. balance vs subscription keys). Now the usage stats are filtered by the current request's API Key ID, so each key only sees its own usage data. The balance/remaining fields are unaffected and still reflect the user-level wallet balance. Changes: - Add GetAPIKeyDashboardStats to repository interface and implementation - Add getPerformanceStatsByAPIKey helper (also fixes TPM to include cache_creation_tokens and cache_read_tokens) - Add GetAPIKeyDashboardStats to UsageService - Update Usage handler to call GetAPIKeyDashboardStats(apiKey.ID) Co-Authored-By: Claude Opus 4.5 --- backend/internal/handler/gateway_handler.go | 4 +- backend/internal/repository/usage_log_repo.go | 101 ++++++++++++++++++ backend/internal/server/api_contract_test.go | 4 + .../internal/service/account_usage_service.go | 1 + backend/internal/service/usage_service.go | 9 ++ 5 files changed, 117 insertions(+), 2 deletions(-) diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7cbf30eb..9aa6b72c 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -616,10 +616,10 @@ func (h *GatewayHandler) Usage(c *gin.Context) { return } - // Best-effort: 获取用量统计,失败不影响基础响应 + // Best-effort: 获取用量统计(按当前 API Key 过滤),失败不影响基础响应 var usageData gin.H if h.usageService != nil { - dashStats, err := h.usageService.GetUserDashboardStats(c.Request.Context(), subject.UserID) + dashStats, err := h.usageService.GetAPIKeyDashboardStats(c.Request.Context(), apiKey.ID) if err == nil && dashStats != nil { usageData = gin.H{ "today": gin.H{ diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index dc8f1460..2db1764f 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1125,6 +1125,107 @@ func (r *usageLogRepository) GetUserDashboardStats(ctx context.Context, userID i return stats, nil } +// getPerformanceStatsByAPIKey 获取指定 API Key 的 RPM 和 TPM(近5分钟平均值) +func (r *usageLogRepository) getPerformanceStatsByAPIKey(ctx context.Context, apiKeyID int64) (rpm, tpm int64, err error) { + fiveMinutesAgo := time.Now().Add(-5 * time.Minute) + query := ` + SELECT + COUNT(*) as request_count, + COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as token_count + FROM usage_logs + WHERE created_at >= $1 AND api_key_id = $2` + args := []any{fiveMinutesAgo, apiKeyID} + + var requestCount int64 + var tokenCount int64 + if err := scanSingleRow(ctx, r.sql, query, args, &requestCount, &tokenCount); err != nil { + return 0, 0, err + } + return requestCount / 5, tokenCount / 5, nil +} + +// GetAPIKeyDashboardStats 获取指定 API Key 的仪表盘统计(按 api_key_id 过滤) +func (r *usageLogRepository) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*UserDashboardStats, error) { + stats := &UserDashboardStats{} + today := timezone.Today() + + // API Key 维度不需要统计 key 数量,设为 1 + stats.TotalAPIKeys = 1 + stats.ActiveAPIKeys = 1 + + // 累计 Token 统计 + totalStatsQuery := ` + SELECT + COUNT(*) as total_requests, + COALESCE(SUM(input_tokens), 0) as total_input_tokens, + COALESCE(SUM(output_tokens), 0) as total_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as total_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as total_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as total_cost, + COALESCE(SUM(actual_cost), 0) as total_actual_cost, + COALESCE(AVG(duration_ms), 0) as avg_duration_ms + FROM usage_logs + WHERE api_key_id = $1 + ` + if err := scanSingleRow( + ctx, + r.sql, + totalStatsQuery, + []any{apiKeyID}, + &stats.TotalRequests, + &stats.TotalInputTokens, + &stats.TotalOutputTokens, + &stats.TotalCacheCreationTokens, + &stats.TotalCacheReadTokens, + &stats.TotalCost, + &stats.TotalActualCost, + &stats.AverageDurationMs, + ); err != nil { + return nil, err + } + stats.TotalTokens = stats.TotalInputTokens + stats.TotalOutputTokens + stats.TotalCacheCreationTokens + stats.TotalCacheReadTokens + + // 今日 Token 统计 + todayStatsQuery := ` + SELECT + COUNT(*) as today_requests, + COALESCE(SUM(input_tokens), 0) as today_input_tokens, + COALESCE(SUM(output_tokens), 0) as today_output_tokens, + COALESCE(SUM(cache_creation_tokens), 0) as today_cache_creation_tokens, + COALESCE(SUM(cache_read_tokens), 0) as today_cache_read_tokens, + COALESCE(SUM(total_cost), 0) as today_cost, + COALESCE(SUM(actual_cost), 0) as today_actual_cost + FROM usage_logs + WHERE api_key_id = $1 AND created_at >= $2 + ` + if err := scanSingleRow( + ctx, + r.sql, + todayStatsQuery, + []any{apiKeyID, today}, + &stats.TodayRequests, + &stats.TodayInputTokens, + &stats.TodayOutputTokens, + &stats.TodayCacheCreationTokens, + &stats.TodayCacheReadTokens, + &stats.TodayCost, + &stats.TodayActualCost, + ); err != nil { + return nil, err + } + stats.TodayTokens = stats.TodayInputTokens + stats.TodayOutputTokens + stats.TodayCacheCreationTokens + stats.TodayCacheReadTokens + + // 性能指标:RPM 和 TPM(最近5分钟,按 API Key 过滤) + rpm, tpm, err := r.getPerformanceStatsByAPIKey(ctx, apiKeyID) + if err != nil { + return nil, err + } + stats.Rpm = rpm + stats.Tpm = tpm + + return stats, nil +} + // GetUserUsageTrendByUserID 获取指定用户的使用趋势 func (r *usageLogRepository) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) (results []TrendDataPoint, err error) { dateFormat := "YYYY-MM-DD" diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 44264e72..e197b776 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -1610,6 +1610,10 @@ func (r *stubUsageLogRepo) GetUserDashboardStats(ctx context.Context, userID int return nil, errors.New("not implemented") } +func (r *stubUsageLogRepo) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + return nil, errors.New("not implemented") +} + func (r *stubUsageLogRepo) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index f3b3e20d..304c5781 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -41,6 +41,7 @@ type UsageLogRepository interface { // User dashboard stats GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error) + GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) GetUserModelStats(ctx context.Context, userID int64, startTime, endTime time.Time) ([]usagestats.ModelStat, error) diff --git a/backend/internal/service/usage_service.go b/backend/internal/service/usage_service.go index aa0a5b87..5594e53f 100644 --- a/backend/internal/service/usage_service.go +++ b/backend/internal/service/usage_service.go @@ -288,6 +288,15 @@ func (s *UsageService) GetUserDashboardStats(ctx context.Context, userID int64) return stats, nil } +// GetAPIKeyDashboardStats returns dashboard summary stats filtered by API Key. +func (s *UsageService) GetAPIKeyDashboardStats(ctx context.Context, apiKeyID int64) (*usagestats.UserDashboardStats, error) { + stats, err := s.usageRepo.GetAPIKeyDashboardStats(ctx, apiKeyID) + if err != nil { + return nil, fmt.Errorf("get api key dashboard stats: %w", err) + } + return stats, nil +} + // GetUserUsageTrendByUserID returns per-user usage trend. func (s *UsageService) GetUserUsageTrendByUserID(ctx context.Context, userID int64, startTime, endTime time.Time, granularity string) ([]usagestats.TrendDataPoint, error) { trend, err := s.usageRepo.GetUserUsageTrendByUserID(ctx, userID, startTime, endTime, granularity) From 49a3c43741be43a84147075f1d9f1e0a7500d84b Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 12:38:48 +0800 Subject: [PATCH 05/14] =?UTF-8?q?feat(auth):=20=E5=AE=9E=E7=8E=B0=20Refres?= =?UTF-8?q?h=20Token=20=E6=9C=BA=E5=88=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 Access Token + Refresh Token 双令牌认证 - 支持 Token 自动刷新和轮转 - 添加登出和撤销所有会话接口 - 前端实现无感刷新和主动刷新定时器 --- backend/cmd/jwtgen/main.go | 2 +- backend/cmd/server/wire_gen.go | 5 +- backend/internal/config/config.go | 26 ++ backend/internal/handler/auth_handler.go | 157 +++++++-- .../internal/handler/auth_linuxdo_oauth.go | 6 +- .../repository/refresh_token_cache.go | 158 +++++++++ backend/internal/repository/wire.go | 1 + backend/internal/server/routes/auth.go | 8 + backend/internal/service/auth_service.go | 317 +++++++++++++++++- .../service/auth_service_register_test.go | 1 + .../internal/service/refresh_token_cache.go | 73 ++++ frontend/src/api/auth.ts | 116 ++++++- frontend/src/api/client.ts | 137 +++++++- frontend/src/components/layout/AppHeader.vue | 7 +- frontend/src/stores/auth.ts | 163 +++++++-- frontend/src/types/index.ts | 2 + .../src/views/auth/LinuxDoCallbackView.vue | 13 + 17 files changed, 1119 insertions(+), 73 deletions(-) create mode 100644 backend/internal/repository/refresh_token_cache.go create mode 100644 backend/internal/service/refresh_token_cache.go diff --git a/backend/cmd/jwtgen/main.go b/backend/cmd/jwtgen/main.go index 139a3a39..ce4718bf 100644 --- a/backend/cmd/jwtgen/main.go +++ b/backend/cmd/jwtgen/main.go @@ -33,7 +33,7 @@ func main() { }() userRepo := repository.NewUserRepository(client, sqlDB) - authService := service.NewAuthService(userRepo, nil, cfg, nil, nil, nil, nil, nil) + authService := service.NewAuthService(userRepo, nil, nil, cfg, nil, nil, nil, nil, nil) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index ab51540f..47b1e8ac 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -44,9 +44,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { } userRepository := repository.NewUserRepository(client, db) redeemCodeRepository := repository.NewRedeemCodeRepository(client) + redisClient := repository.ProvideRedis(configConfig) + refreshTokenCache := repository.NewRefreshTokenCache(redisClient) settingRepository := repository.NewSettingRepository(client) settingService := service.NewSettingService(settingRepository, configConfig) - redisClient := repository.ProvideRedis(configConfig) emailCache := repository.NewEmailCache(redisClient) emailService := service.NewEmailService(settingRepository, emailCache) turnstileVerifier := repository.NewTurnstileVerifier() @@ -62,7 +63,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) - authService := service.NewAuthService(userRepository, redeemCodeRepository, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) + authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) userService := service.NewUserService(userRepository, apiKeyAuthCacheInvalidator) subscriptionService := service.NewSubscriptionService(groupRepository, userSubscriptionRepository, billingCacheService) redeemCache := repository.NewRedeemCache(redisClient) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 84be445b..25258b23 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -467,6 +467,13 @@ type OpsMetricsCollectorCacheConfig struct { type JWTConfig struct { Secret string `mapstructure:"secret"` ExpireHour int `mapstructure:"expire_hour"` + // AccessTokenExpireMinutes: Access Token有效期(分钟),默认15分钟 + // 短有效期减少被盗用风险,配合Refresh Token实现无感续期 + AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` + // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 + RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` + // RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新 + RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"` } // TotpConfig TOTP 双因素认证配置 @@ -783,6 +790,9 @@ func setDefaults() { // JWT viper.SetDefault("jwt.secret", "") viper.SetDefault("jwt.expire_hour", 24) + viper.SetDefault("jwt.access_token_expire_minutes", 360) // 6小时Access Token有效期 + viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 + viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 // TOTP viper.SetDefault("totp.encryption_key", "") @@ -912,6 +922,22 @@ func (c *Config) Validate() error { if c.JWT.ExpireHour > 24 { log.Printf("Warning: jwt.expire_hour is %d hours (> 24). Consider shorter expiration for security.", c.JWT.ExpireHour) } + // JWT Refresh Token配置验证 + if c.JWT.AccessTokenExpireMinutes <= 0 { + return fmt.Errorf("jwt.access_token_expire_minutes must be positive") + } + if c.JWT.AccessTokenExpireMinutes > 720 { + log.Printf("Warning: jwt.access_token_expire_minutes is %d (> 720). Consider shorter expiration for security.", c.JWT.AccessTokenExpireMinutes) + } + if c.JWT.RefreshTokenExpireDays <= 0 { + return fmt.Errorf("jwt.refresh_token_expire_days must be positive") + } + if c.JWT.RefreshTokenExpireDays > 90 { + log.Printf("Warning: jwt.refresh_token_expire_days is %d (> 90). Consider shorter expiration for security.", c.JWT.RefreshTokenExpireDays) + } + if c.JWT.RefreshWindowMinutes < 0 { + return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") + } if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { return fmt.Errorf("security.csp.policy is required when CSP is enabled") } diff --git a/backend/internal/handler/auth_handler.go b/backend/internal/handler/auth_handler.go index 75ea9f08..34ed63bc 100644 --- a/backend/internal/handler/auth_handler.go +++ b/backend/internal/handler/auth_handler.go @@ -68,9 +68,39 @@ type LoginRequest struct { // AuthResponse 认证响应格式(匹配前端期望) type AuthResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - User *dto.User `json:"user"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` // 新增:Refresh Token + ExpiresIn int `json:"expires_in,omitempty"` // 新增:Access Token有效期(秒) + TokenType string `json:"token_type"` + User *dto.User `json:"user"` +} + +// respondWithTokenPair 生成 Token 对并返回认证响应 +// 如果 Token 对生成失败,回退到只返回 Access Token(向后兼容) +func (h *AuthHandler) respondWithTokenPair(c *gin.Context, user *service.User) { + tokenPair, err := h.authService.GenerateTokenPair(c.Request.Context(), user, "") + if err != nil { + slog.Error("failed to generate token pair", "error", err, "user_id", user.ID) + // 回退到只返回Access Token + token, tokenErr := h.authService.GenerateToken(user) + if tokenErr != nil { + response.InternalError(c, "Failed to generate token") + return + } + response.Success(c, AuthResponse{ + AccessToken: token, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) + return + } + response.Success(c, AuthResponse{ + AccessToken: tokenPair.AccessToken, + RefreshToken: tokenPair.RefreshToken, + ExpiresIn: tokenPair.ExpiresIn, + TokenType: "Bearer", + User: dto.UserFromService(user), + }) } // Register handles user registration @@ -90,17 +120,13 @@ func (h *AuthHandler) Register(c *gin.Context) { } } - token, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) + _, user, err := h.authService.RegisterWithVerification(c.Request.Context(), req.Email, req.Password, req.VerifyCode, req.PromoCode, req.InvitationCode) if err != nil { response.ErrorFrom(c, err) return } - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // SendVerifyCode 发送邮箱验证码 @@ -150,6 +176,7 @@ func (h *AuthHandler) Login(c *gin.Context) { response.ErrorFrom(c, err) return } + _ = token // token 由 authService.Login 返回但此处由 respondWithTokenPair 重新生成 // Check if TOTP 2FA is enabled for this user if h.totpService != nil && h.settingSvc.IsTotpEnabled(c.Request.Context()) && user.TotpEnabled { @@ -168,11 +195,7 @@ func (h *AuthHandler) Login(c *gin.Context) { return } - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // TotpLoginResponse represents the response when 2FA is required @@ -238,18 +261,7 @@ func (h *AuthHandler) Login2FA(c *gin.Context) { return } - // Generate the JWT token - token, err := h.authService.GenerateToken(user) - if err != nil { - response.InternalError(c, "Failed to generate token") - return - } - - response.Success(c, AuthResponse{ - AccessToken: token, - TokenType: "Bearer", - User: dto.UserFromService(user), - }) + h.respondWithTokenPair(c, user) } // GetCurrentUser handles getting current authenticated user @@ -491,3 +503,96 @@ func (h *AuthHandler) ResetPassword(c *gin.Context) { Message: "Your password has been reset successfully. You can now log in with your new password.", }) } + +// ==================== Token Refresh Endpoints ==================== + +// RefreshTokenRequest 刷新Token请求 +type RefreshTokenRequest struct { + RefreshToken string `json:"refresh_token" binding:"required"` +} + +// RefreshTokenResponse 刷新Token响应 +type RefreshTokenResponse struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) + TokenType string `json:"token_type"` +} + +// RefreshToken 刷新Token +// POST /api/v1/auth/refresh +func (h *AuthHandler) RefreshToken(c *gin.Context) { + var req RefreshTokenRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + tokenPair, err := h.authService.RefreshTokenPair(c.Request.Context(), req.RefreshToken) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, RefreshTokenResponse{ + AccessToken: tokenPair.AccessToken, + RefreshToken: tokenPair.RefreshToken, + ExpiresIn: tokenPair.ExpiresIn, + TokenType: "Bearer", + }) +} + +// LogoutRequest 登出请求 +type LogoutRequest struct { + RefreshToken string `json:"refresh_token,omitempty"` // 可选:撤销指定的Refresh Token +} + +// LogoutResponse 登出响应 +type LogoutResponse struct { + Message string `json:"message"` +} + +// Logout 用户登出 +// POST /api/v1/auth/logout +func (h *AuthHandler) Logout(c *gin.Context) { + var req LogoutRequest + // 允许空请求体(向后兼容) + _ = c.ShouldBindJSON(&req) + + // 如果提供了Refresh Token,撤销它 + if req.RefreshToken != "" { + if err := h.authService.RevokeRefreshToken(c.Request.Context(), req.RefreshToken); err != nil { + slog.Debug("failed to revoke refresh token", "error", err) + // 不影响登出流程 + } + } + + response.Success(c, LogoutResponse{ + Message: "Logged out successfully", + }) +} + +// RevokeAllSessionsResponse 撤销所有会话响应 +type RevokeAllSessionsResponse struct { + Message string `json:"message"` +} + +// RevokeAllSessions 撤销当前用户的所有会话 +// POST /api/v1/auth/revoke-all-sessions +func (h *AuthHandler) RevokeAllSessions(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + if err := h.authService.RevokeAllUserSessions(c.Request.Context(), subject.UserID); err != nil { + slog.Error("failed to revoke all sessions", "user_id", subject.UserID, "error", err) + response.InternalError(c, "Failed to revoke sessions") + return + } + + response.Success(c, RevokeAllSessionsResponse{ + Message: "All sessions have been revoked. Please log in again.", + }) +} diff --git a/backend/internal/handler/auth_linuxdo_oauth.go b/backend/internal/handler/auth_linuxdo_oauth.go index a16c4cc7..0ccf47e4 100644 --- a/backend/internal/handler/auth_linuxdo_oauth.go +++ b/backend/internal/handler/auth_linuxdo_oauth.go @@ -211,7 +211,7 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { email = linuxDoSyntheticEmail(subject) } - jwtToken, _, err := h.authService.LoginOrRegisterOAuth(c.Request.Context(), email, username) + tokenPair, _, err := h.authService.LoginOrRegisterOAuthWithTokenPair(c.Request.Context(), email, username) if err != nil { // 避免把内部细节泄露给客户端;给前端保留结构化原因与提示信息即可。 redirectOAuthError(c, frontendCallback, "login_failed", infraerrors.Reason(err), infraerrors.Message(err)) @@ -219,7 +219,9 @@ func (h *AuthHandler) LinuxDoOAuthCallback(c *gin.Context) { } fragment := url.Values{} - fragment.Set("access_token", jwtToken) + fragment.Set("access_token", tokenPair.AccessToken) + fragment.Set("refresh_token", tokenPair.RefreshToken) + fragment.Set("expires_in", fmt.Sprintf("%d", tokenPair.ExpiresIn)) fragment.Set("token_type", "Bearer") fragment.Set("redirect", redirectTo) redirectWithFragment(c, frontendCallback, fragment) diff --git a/backend/internal/repository/refresh_token_cache.go b/backend/internal/repository/refresh_token_cache.go new file mode 100644 index 00000000..b01bd476 --- /dev/null +++ b/backend/internal/repository/refresh_token_cache.go @@ -0,0 +1,158 @@ +package repository + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + refreshTokenKeyPrefix = "refresh_token:" + userRefreshTokensPrefix = "user_refresh_tokens:" + tokenFamilyPrefix = "token_family:" +) + +// refreshTokenKey generates the Redis key for a refresh token. +func refreshTokenKey(tokenHash string) string { + return refreshTokenKeyPrefix + tokenHash +} + +// userRefreshTokensKey generates the Redis key for user's token set. +func userRefreshTokensKey(userID int64) string { + return fmt.Sprintf("%s%d", userRefreshTokensPrefix, userID) +} + +// tokenFamilyKey generates the Redis key for token family set. +func tokenFamilyKey(familyID string) string { + return tokenFamilyPrefix + familyID +} + +type refreshTokenCache struct { + rdb *redis.Client +} + +// NewRefreshTokenCache creates a new RefreshTokenCache implementation. +func NewRefreshTokenCache(rdb *redis.Client) service.RefreshTokenCache { + return &refreshTokenCache{rdb: rdb} +} + +func (c *refreshTokenCache) StoreRefreshToken(ctx context.Context, tokenHash string, data *service.RefreshTokenData, ttl time.Duration) error { + key := refreshTokenKey(tokenHash) + val, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("marshal refresh token data: %w", err) + } + return c.rdb.Set(ctx, key, val, ttl).Err() +} + +func (c *refreshTokenCache) GetRefreshToken(ctx context.Context, tokenHash string) (*service.RefreshTokenData, error) { + key := refreshTokenKey(tokenHash) + val, err := c.rdb.Get(ctx, key).Result() + if err != nil { + if err == redis.Nil { + return nil, service.ErrRefreshTokenNotFound + } + return nil, err + } + var data service.RefreshTokenData + if err := json.Unmarshal([]byte(val), &data); err != nil { + return nil, fmt.Errorf("unmarshal refresh token data: %w", err) + } + return &data, nil +} + +func (c *refreshTokenCache) DeleteRefreshToken(ctx context.Context, tokenHash string) error { + key := refreshTokenKey(tokenHash) + return c.rdb.Del(ctx, key).Err() +} + +func (c *refreshTokenCache) DeleteUserRefreshTokens(ctx context.Context, userID int64) error { + // Get all token hashes for this user + tokenHashes, err := c.GetUserTokenHashes(ctx, userID) + if err != nil && err != redis.Nil { + return fmt.Errorf("get user token hashes: %w", err) + } + + if len(tokenHashes) == 0 { + return nil + } + + // Build keys to delete + keys := make([]string, 0, len(tokenHashes)+1) + for _, hash := range tokenHashes { + keys = append(keys, refreshTokenKey(hash)) + } + keys = append(keys, userRefreshTokensKey(userID)) + + // Delete all keys in a pipeline + pipe := c.rdb.Pipeline() + for _, key := range keys { + pipe.Del(ctx, key) + } + _, err = pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) DeleteTokenFamily(ctx context.Context, familyID string) error { + // Get all token hashes in this family + tokenHashes, err := c.GetFamilyTokenHashes(ctx, familyID) + if err != nil && err != redis.Nil { + return fmt.Errorf("get family token hashes: %w", err) + } + + if len(tokenHashes) == 0 { + return nil + } + + // Build keys to delete + keys := make([]string, 0, len(tokenHashes)+1) + for _, hash := range tokenHashes { + keys = append(keys, refreshTokenKey(hash)) + } + keys = append(keys, tokenFamilyKey(familyID)) + + // Delete all keys in a pipeline + pipe := c.rdb.Pipeline() + for _, key := range keys { + pipe.Del(ctx, key) + } + _, err = pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error { + key := userRefreshTokensKey(userID) + pipe := c.rdb.Pipeline() + pipe.SAdd(ctx, key, tokenHash) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error { + key := tokenFamilyKey(familyID) + pipe := c.rdb.Pipeline() + pipe.SAdd(ctx, key, tokenHash) + pipe.Expire(ctx, key, ttl) + _, err := pipe.Exec(ctx) + return err +} + +func (c *refreshTokenCache) GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) { + key := userRefreshTokensKey(userID) + return c.rdb.SMembers(ctx, key).Result() +} + +func (c *refreshTokenCache) GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) { + key := tokenFamilyKey(familyID) + return c.rdb.SMembers(ctx, key).Result() +} + +func (c *refreshTokenCache) IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) { + key := tokenFamilyKey(familyID) + return c.rdb.SIsMember(ctx, key, tokenHash).Result() +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index e3394361..857ce3e8 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -85,6 +85,7 @@ var ProviderSet = wire.NewSet( NewSchedulerOutboxRepository, NewProxyLatencyCache, NewTotpCache, + NewRefreshTokenCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/routes/auth.go b/backend/internal/server/routes/auth.go index 24f6d549..26d79605 100644 --- a/backend/internal/server/routes/auth.go +++ b/backend/internal/server/routes/auth.go @@ -28,6 +28,12 @@ func RegisterAuthRoutes( auth.POST("/login", h.Auth.Login) auth.POST("/login/2fa", h.Auth.Login2FA) auth.POST("/send-verify-code", h.Auth.SendVerifyCode) + // Token刷新接口添加速率限制:每分钟最多 30 次(Redis 故障时 fail-close) + auth.POST("/refresh", rateLimiter.LimitWithOptions("refresh-token", 30, time.Minute, middleware.RateLimitOptions{ + FailureMode: middleware.RateLimitFailClose, + }), h.Auth.RefreshToken) + // 登出接口(公开,允许未认证用户调用以撤销Refresh Token) + auth.POST("/logout", h.Auth.Logout) // 优惠码验证接口添加速率限制:每分钟最多 10 次(Redis 故障时 fail-close) auth.POST("/validate-promo-code", rateLimiter.LimitWithOptions("validate-promo", 10, time.Minute, middleware.RateLimitOptions{ FailureMode: middleware.RateLimitFailClose, @@ -59,5 +65,7 @@ func RegisterAuthRoutes( authenticated.Use(gin.HandlerFunc(jwtAuth)) { authenticated.GET("/auth/me", h.Auth.GetCurrentUser) + // 撤销所有会话(需要认证) + authenticated.POST("/auth/revoke-all-sessions", h.Auth.RevokeAllSessions) } } diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 25604d2c..fb8aaf9c 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -3,6 +3,7 @@ package service import ( "context" "crypto/rand" + "crypto/sha256" "encoding/hex" "errors" "fmt" @@ -25,8 +26,12 @@ var ( ErrEmailReserved = infraerrors.BadRequest("EMAIL_RESERVED", "email is reserved") ErrInvalidToken = infraerrors.Unauthorized("INVALID_TOKEN", "invalid token") ErrTokenExpired = infraerrors.Unauthorized("TOKEN_EXPIRED", "token has expired") + ErrAccessTokenExpired = infraerrors.Unauthorized("ACCESS_TOKEN_EXPIRED", "access token has expired") ErrTokenTooLarge = infraerrors.BadRequest("TOKEN_TOO_LARGE", "token too large") ErrTokenRevoked = infraerrors.Unauthorized("TOKEN_REVOKED", "token has been revoked") + ErrRefreshTokenInvalid = infraerrors.Unauthorized("REFRESH_TOKEN_INVALID", "invalid refresh token") + ErrRefreshTokenExpired = infraerrors.Unauthorized("REFRESH_TOKEN_EXPIRED", "refresh token has expired") + ErrRefreshTokenReused = infraerrors.Unauthorized("REFRESH_TOKEN_REUSED", "refresh token has been reused") ErrEmailVerifyRequired = infraerrors.BadRequest("EMAIL_VERIFY_REQUIRED", "email verification is required") ErrRegDisabled = infraerrors.Forbidden("REGISTRATION_DISABLED", "registration is currently disabled") ErrServiceUnavailable = infraerrors.ServiceUnavailable("SERVICE_UNAVAILABLE", "service temporarily unavailable") @@ -37,6 +42,9 @@ var ( // maxTokenLength 限制 token 大小,避免超长 header 触发解析时的异常内存分配。 const maxTokenLength = 8192 +// refreshTokenPrefix is the prefix for refresh tokens to distinguish them from access tokens. +const refreshTokenPrefix = "rt_" + // JWTClaims JWT载荷数据 type JWTClaims struct { UserID int64 `json:"user_id"` @@ -50,6 +58,7 @@ type JWTClaims struct { type AuthService struct { userRepo UserRepository redeemRepo RedeemCodeRepository + refreshTokenCache RefreshTokenCache cfg *config.Config settingService *SettingService emailService *EmailService @@ -62,6 +71,7 @@ type AuthService struct { func NewAuthService( userRepo UserRepository, redeemRepo RedeemCodeRepository, + refreshTokenCache RefreshTokenCache, cfg *config.Config, settingService *SettingService, emailService *EmailService, @@ -72,6 +82,7 @@ func NewAuthService( return &AuthService{ userRepo: userRepo, redeemRepo: redeemRepo, + refreshTokenCache: refreshTokenCache, cfg: cfg, settingService: settingService, emailService: emailService, @@ -481,6 +492,100 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username return token, user, nil } +// LoginOrRegisterOAuthWithTokenPair 用于第三方 OAuth/SSO 登录,返回完整的 TokenPair +// 与 LoginOrRegisterOAuth 功能相同,但返回 TokenPair 而非单个 token +func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, email, username string) (*TokenPair, *User, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, nil, errors.New("refresh token cache not configured") + } + + email = strings.TrimSpace(email) + if email == "" || len(email) > 255 { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + if _, err := mail.ParseAddress(email); err != nil { + return nil, nil, infraerrors.BadRequest("INVALID_EMAIL", "invalid email") + } + + username = strings.TrimSpace(username) + if len([]rune(username)) > 100 { + username = string([]rune(username)[:100]) + } + + user, err := s.userRepo.GetByEmail(ctx, email) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // OAuth 首次登录视为注册 + if s.settingService == nil || !s.settingService.IsRegistrationEnabled(ctx) { + return nil, nil, ErrRegDisabled + } + + randomPassword, err := randomHexString(32) + if err != nil { + log.Printf("[Auth] Failed to generate random password for oauth signup: %v", err) + return nil, nil, ErrServiceUnavailable + } + hashedPassword, err := s.HashPassword(randomPassword) + if err != nil { + return nil, nil, fmt.Errorf("hash password: %w", err) + } + + defaultBalance := s.cfg.Default.UserBalance + defaultConcurrency := s.cfg.Default.UserConcurrency + if s.settingService != nil { + defaultBalance = s.settingService.GetDefaultBalance(ctx) + defaultConcurrency = s.settingService.GetDefaultConcurrency(ctx) + } + + newUser := &User{ + Email: email, + Username: username, + PasswordHash: hashedPassword, + Role: RoleUser, + Balance: defaultBalance, + Concurrency: defaultConcurrency, + Status: StatusActive, + } + + if err := s.userRepo.Create(ctx, newUser); err != nil { + if errors.Is(err, ErrEmailExists) { + user, err = s.userRepo.GetByEmail(ctx, email) + if err != nil { + log.Printf("[Auth] Database error getting user after conflict: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + log.Printf("[Auth] Database error creating oauth user: %v", err) + return nil, nil, ErrServiceUnavailable + } + } else { + user = newUser + } + } else { + log.Printf("[Auth] Database error during oauth login: %v", err) + return nil, nil, ErrServiceUnavailable + } + } + + if !user.IsActive() { + return nil, nil, ErrUserNotActive + } + + if user.Username == "" && username != "" { + user.Username = username + if err := s.userRepo.Update(ctx, user); err != nil { + log.Printf("[Auth] Failed to update username after oauth login: %v", err) + } + } + + tokenPair, err := s.GenerateTokenPair(ctx, user, "") + if err != nil { + return nil, nil, fmt.Errorf("generate token pair: %w", err) + } + return tokenPair, user, nil +} + // ValidateToken 验证JWT token并返回用户声明 func (s *AuthService) ValidateToken(tokenString string) (*JWTClaims, error) { // 先做长度校验,尽早拒绝异常超长 token,降低 DoS 风险。 @@ -539,10 +644,17 @@ func isReservedEmail(email string) bool { return strings.HasSuffix(normalized, LinuxDoConnectSyntheticEmailDomain) } -// GenerateToken 生成JWT token +// GenerateToken 生成JWT access token +// 使用新的access_token_expire_minutes配置项(如果配置了),否则回退到expire_hour func (s *AuthService) GenerateToken(user *User) (string, error) { now := time.Now() - expiresAt := now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) + var expiresAt time.Time + if s.cfg.JWT.AccessTokenExpireMinutes > 0 { + expiresAt = now.Add(time.Duration(s.cfg.JWT.AccessTokenExpireMinutes) * time.Minute) + } else { + // 向后兼容:使用旧的expire_hour配置 + expiresAt = now.Add(time.Duration(s.cfg.JWT.ExpireHour) * time.Hour) + } claims := &JWTClaims{ UserID: user.ID, @@ -565,6 +677,15 @@ func (s *AuthService) GenerateToken(user *User) (string, error) { return tokenString, nil } +// GetAccessTokenExpiresIn 返回Access Token的有效期(秒) +// 用于前端设置刷新定时器 +func (s *AuthService) GetAccessTokenExpiresIn() int { + if s.cfg.JWT.AccessTokenExpireMinutes > 0 { + return s.cfg.JWT.AccessTokenExpireMinutes * 60 + } + return s.cfg.JWT.ExpireHour * 3600 +} + // HashPassword 使用bcrypt加密密码 func (s *AuthService) HashPassword(password string) (string, error) { hashedBytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) @@ -755,6 +876,198 @@ func (s *AuthService) ResetPassword(ctx context.Context, email, token, newPasswo return ErrServiceUnavailable } + // Also revoke all refresh tokens for this user + if err := s.RevokeAllUserSessions(ctx, user.ID); err != nil { + log.Printf("[Auth] Failed to revoke refresh tokens for user %d: %v", user.ID, err) + // Don't return error - password was already changed successfully + } + log.Printf("[Auth] Password reset successful for user: %s", email) return nil } + +// ==================== Refresh Token Methods ==================== + +// TokenPair 包含Access Token和Refresh Token +type TokenPair struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` // Access Token有效期(秒) +} + +// GenerateTokenPair 生成Access Token和Refresh Token对 +// familyID: 可选的Token家族ID,用于Token轮转时保持家族关系 +func (s *AuthService) GenerateTokenPair(ctx context.Context, user *User, familyID string) (*TokenPair, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, errors.New("refresh token cache not configured") + } + + // 生成Access Token + accessToken, err := s.GenerateToken(user) + if err != nil { + return nil, fmt.Errorf("generate access token: %w", err) + } + + // 生成Refresh Token + refreshToken, err := s.generateRefreshToken(ctx, user, familyID) + if err != nil { + return nil, fmt.Errorf("generate refresh token: %w", err) + } + + return &TokenPair{ + AccessToken: accessToken, + RefreshToken: refreshToken, + ExpiresIn: s.GetAccessTokenExpiresIn(), + }, nil +} + +// generateRefreshToken 生成并存储Refresh Token +func (s *AuthService) generateRefreshToken(ctx context.Context, user *User, familyID string) (string, error) { + // 生成随机Token + tokenBytes := make([]byte, 32) + if _, err := rand.Read(tokenBytes); err != nil { + return "", fmt.Errorf("generate random bytes: %w", err) + } + rawToken := refreshTokenPrefix + hex.EncodeToString(tokenBytes) + + // 计算Token哈希(存储哈希而非原始Token) + tokenHash := hashToken(rawToken) + + // 如果没有提供familyID,生成新的 + if familyID == "" { + familyBytes := make([]byte, 16) + if _, err := rand.Read(familyBytes); err != nil { + return "", fmt.Errorf("generate family id: %w", err) + } + familyID = hex.EncodeToString(familyBytes) + } + + now := time.Now() + ttl := time.Duration(s.cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour + + data := &RefreshTokenData{ + UserID: user.ID, + TokenVersion: user.TokenVersion, + FamilyID: familyID, + CreatedAt: now, + ExpiresAt: now.Add(ttl), + } + + // 存储Token数据 + if err := s.refreshTokenCache.StoreRefreshToken(ctx, tokenHash, data, ttl); err != nil { + return "", fmt.Errorf("store refresh token: %w", err) + } + + // 添加到用户Token集合 + if err := s.refreshTokenCache.AddToUserTokenSet(ctx, user.ID, tokenHash, ttl); err != nil { + log.Printf("[Auth] Failed to add token to user set: %v", err) + // 不影响主流程 + } + + // 添加到家族Token集合 + if err := s.refreshTokenCache.AddToFamilyTokenSet(ctx, familyID, tokenHash, ttl); err != nil { + log.Printf("[Auth] Failed to add token to family set: %v", err) + // 不影响主流程 + } + + return rawToken, nil +} + +// RefreshTokenPair 使用Refresh Token刷新Token对 +// 实现Token轮转:每次刷新都会生成新的Refresh Token,旧Token立即失效 +func (s *AuthService) RefreshTokenPair(ctx context.Context, refreshToken string) (*TokenPair, error) { + // 检查 refreshTokenCache 是否可用 + if s.refreshTokenCache == nil { + return nil, ErrRefreshTokenInvalid + } + + // 验证Token格式 + if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { + return nil, ErrRefreshTokenInvalid + } + + tokenHash := hashToken(refreshToken) + + // 获取Token数据 + data, err := s.refreshTokenCache.GetRefreshToken(ctx, tokenHash) + if err != nil { + if errors.Is(err, ErrRefreshTokenNotFound) { + // Token不存在,可能是已被使用(Token轮转)或已过期 + log.Printf("[Auth] Refresh token not found, possible reuse attack") + return nil, ErrRefreshTokenInvalid + } + log.Printf("[Auth] Error getting refresh token: %v", err) + return nil, ErrServiceUnavailable + } + + // 检查Token是否过期 + if time.Now().After(data.ExpiresAt) { + // 删除过期Token + _ = s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) + return nil, ErrRefreshTokenExpired + } + + // 获取用户信息 + user, err := s.userRepo.GetByID(ctx, data.UserID) + if err != nil { + if errors.Is(err, ErrUserNotFound) { + // 用户已删除,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrRefreshTokenInvalid + } + log.Printf("[Auth] Database error getting user for token refresh: %v", err) + return nil, ErrServiceUnavailable + } + + // 检查用户状态 + if !user.IsActive() { + // 用户被禁用,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrUserNotActive + } + + // 检查TokenVersion(密码更改后所有Token失效) + if data.TokenVersion != user.TokenVersion { + // TokenVersion不匹配,撤销整个Token家族 + _ = s.refreshTokenCache.DeleteTokenFamily(ctx, data.FamilyID) + return nil, ErrTokenRevoked + } + + // Token轮转:立即使旧Token失效 + if err := s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash); err != nil { + log.Printf("[Auth] Failed to delete old refresh token: %v", err) + // 继续处理,不影响主流程 + } + + // 生成新的Token对,保持同一个家族ID + return s.GenerateTokenPair(ctx, user, data.FamilyID) +} + +// RevokeRefreshToken 撤销单个Refresh Token +func (s *AuthService) RevokeRefreshToken(ctx context.Context, refreshToken string) error { + if s.refreshTokenCache == nil { + return nil // No-op if cache not configured + } + if !strings.HasPrefix(refreshToken, refreshTokenPrefix) { + return ErrRefreshTokenInvalid + } + + tokenHash := hashToken(refreshToken) + return s.refreshTokenCache.DeleteRefreshToken(ctx, tokenHash) +} + +// RevokeAllUserSessions 撤销用户的所有会话(所有Refresh Token) +// 用于密码更改或用户主动登出所有设备 +func (s *AuthService) RevokeAllUserSessions(ctx context.Context, userID int64) error { + if s.refreshTokenCache == nil { + return nil // No-op if cache not configured + } + return s.refreshTokenCache.DeleteUserRefreshTokens(ctx, userID) +} + +// hashToken 计算Token的SHA256哈希 +func hashToken(token string) string { + hash := sha256.Sum256([]byte(token)) + return hex.EncodeToString(hash[:]) +} diff --git a/backend/internal/service/auth_service_register_test.go b/backend/internal/service/auth_service_register_test.go index aa3c769e..f1685be5 100644 --- a/backend/internal/service/auth_service_register_test.go +++ b/backend/internal/service/auth_service_register_test.go @@ -116,6 +116,7 @@ func newAuthService(repo *userRepoStub, settings map[string]string, emailCache E return NewAuthService( repo, nil, // redeemRepo + nil, // refreshTokenCache cfg, settingService, emailService, diff --git a/backend/internal/service/refresh_token_cache.go b/backend/internal/service/refresh_token_cache.go new file mode 100644 index 00000000..91b3924f --- /dev/null +++ b/backend/internal/service/refresh_token_cache.go @@ -0,0 +1,73 @@ +package service + +import ( + "context" + "errors" + "time" +) + +// ErrRefreshTokenNotFound is returned when a refresh token is not found in cache. +// This is used to abstract away the underlying cache implementation (e.g., redis.Nil). +var ErrRefreshTokenNotFound = errors.New("refresh token not found") + +// RefreshTokenData 存储在Redis中的Refresh Token数据 +type RefreshTokenData struct { + UserID int64 `json:"user_id"` + TokenVersion int64 `json:"token_version"` // 用于检测密码更改后的Token失效 + FamilyID string `json:"family_id"` // Token家族ID,用于防重放攻击 + CreatedAt time.Time `json:"created_at"` + ExpiresAt time.Time `json:"expires_at"` +} + +// RefreshTokenCache 管理Refresh Token的Redis缓存 +// 用于JWT Token刷新机制,支持Token轮转和防重放攻击 +// +// Key 格式: +// - refresh_token:{token_hash} -> RefreshTokenData (JSON) +// - user_refresh_tokens:{user_id} -> Set +// - token_family:{family_id} -> Set +type RefreshTokenCache interface { + // StoreRefreshToken 存储Refresh Token + // tokenHash: Token的SHA256哈希值(不存储原始Token) + // data: Token关联的数据 + // ttl: Token过期时间 + StoreRefreshToken(ctx context.Context, tokenHash string, data *RefreshTokenData, ttl time.Duration) error + + // GetRefreshToken 获取Refresh Token数据 + // 返回 (data, nil) 如果Token存在 + // 返回 (nil, ErrRefreshTokenNotFound) 如果Token不存在 + // 返回 (nil, err) 如果发生其他错误 + GetRefreshToken(ctx context.Context, tokenHash string) (*RefreshTokenData, error) + + // DeleteRefreshToken 删除单个Refresh Token + // 用于Token轮转时使旧Token失效 + DeleteRefreshToken(ctx context.Context, tokenHash string) error + + // DeleteUserRefreshTokens 删除用户的所有Refresh Token + // 用于密码更改或用户主动登出所有设备 + DeleteUserRefreshTokens(ctx context.Context, userID int64) error + + // DeleteTokenFamily 删除整个Token家族 + // 用于检测到Token重放攻击时,撤销整个会话链 + DeleteTokenFamily(ctx context.Context, familyID string) error + + // AddToUserTokenSet 将Token添加到用户的Token集合 + // 用于跟踪用户的所有活跃Refresh Token + AddToUserTokenSet(ctx context.Context, userID int64, tokenHash string, ttl time.Duration) error + + // AddToFamilyTokenSet 将Token添加到家族Token集合 + // 用于跟踪同一登录会话的所有Token + AddToFamilyTokenSet(ctx context.Context, familyID string, tokenHash string, ttl time.Duration) error + + // GetUserTokenHashes 获取用户的所有Token哈希 + // 用于批量删除用户Token + GetUserTokenHashes(ctx context.Context, userID int64) ([]string, error) + + // GetFamilyTokenHashes 获取家族的所有Token哈希 + // 用于批量删除家族Token + GetFamilyTokenHashes(ctx context.Context, familyID string) ([]string, error) + + // IsTokenInFamily 检查Token是否属于指定家族 + // 用于验证Token家族关系 + IsTokenInFamily(ctx context.Context, familyID string, tokenHash string) (bool, error) +} diff --git a/frontend/src/api/auth.ts b/frontend/src/api/auth.ts index 40c9c5a4..e196e234 100644 --- a/frontend/src/api/auth.ts +++ b/frontend/src/api/auth.ts @@ -35,6 +35,22 @@ export function setAuthToken(token: string): void { localStorage.setItem('auth_token', token) } +/** + * Store refresh token in localStorage + */ +export function setRefreshToken(token: string): void { + localStorage.setItem('refresh_token', token) +} + +/** + * Store token expiration timestamp in localStorage + * Converts expires_in (seconds) to absolute timestamp (milliseconds) + */ +export function setTokenExpiresAt(expiresIn: number): void { + const expiresAt = Date.now() + expiresIn * 1000 + localStorage.setItem('token_expires_at', String(expiresAt)) +} + /** * Get authentication token from localStorage */ @@ -42,12 +58,29 @@ export function getAuthToken(): string | null { return localStorage.getItem('auth_token') } +/** + * Get refresh token from localStorage + */ +export function getRefreshToken(): string | null { + return localStorage.getItem('refresh_token') +} + +/** + * Get token expiration timestamp from localStorage + */ +export function getTokenExpiresAt(): number | null { + const value = localStorage.getItem('token_expires_at') + return value ? parseInt(value, 10) : null +} + /** * Clear authentication token from localStorage */ export function clearAuthToken(): void { localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') } /** @@ -61,6 +94,12 @@ export async function login(credentials: LoginRequest): Promise { // Only store token if 2FA is not required if (!isTotp2FARequired(data)) { setAuthToken(data.access_token) + if (data.refresh_token) { + setRefreshToken(data.refresh_token) + } + if (data.expires_in) { + setTokenExpiresAt(data.expires_in) + } localStorage.setItem('auth_user', JSON.stringify(data.user)) } @@ -77,6 +116,12 @@ export async function login2FA(request: TotpLogin2FARequest): Promise // Store token and user data setAuthToken(data.access_token) + if (data.refresh_token) { + setRefreshToken(data.refresh_token) + } + if (data.expires_in) { + setTokenExpiresAt(data.expires_in) + } localStorage.setItem('auth_user', JSON.stringify(data.user)) return data @@ -108,11 +159,62 @@ export async function getCurrentUser() { /** * User logout * Clears authentication token and user data from localStorage + * Optionally revokes the refresh token on the server */ -export function logout(): void { +export async function logout(): Promise { + const refreshToken = getRefreshToken() + + // Try to revoke the refresh token on the server + if (refreshToken) { + try { + await apiClient.post('/auth/logout', { refresh_token: refreshToken }) + } catch { + // Ignore errors - we still want to clear local state + } + } + clearAuthToken() - // Optionally redirect to login page - // window.location.href = '/login'; +} + +/** + * Refresh token response + */ +export interface RefreshTokenResponse { + access_token: string + refresh_token: string + expires_in: number + token_type: string +} + +/** + * Refresh the access token using the refresh token + * @returns New token pair + */ +export async function refreshToken(): Promise { + const currentRefreshToken = getRefreshToken() + if (!currentRefreshToken) { + throw new Error('No refresh token available') + } + + const { data } = await apiClient.post('/auth/refresh', { + refresh_token: currentRefreshToken + }) + + // Update tokens in localStorage + setAuthToken(data.access_token) + setRefreshToken(data.refresh_token) + setTokenExpiresAt(data.expires_in) + + return data +} + +/** + * Revoke all sessions for the current user + * @returns Response with message + */ +export async function revokeAllSessions(): Promise<{ message: string }> { + const { data } = await apiClient.post<{ message: string }>('/auth/revoke-all-sessions') + return data } /** @@ -242,14 +344,20 @@ export const authAPI = { logout, isAuthenticated, setAuthToken, + setRefreshToken, + setTokenExpiresAt, getAuthToken, + getRefreshToken, + getTokenExpiresAt, clearAuthToken, getPublicSettings, sendVerifyCode, validatePromoCode, validateInvitationCode, forgotPassword, - resetPassword + resetPassword, + refreshToken, + revokeAllSessions } export default authAPI diff --git a/frontend/src/api/client.ts b/frontend/src/api/client.ts index 3827498b..22db5a44 100644 --- a/frontend/src/api/client.ts +++ b/frontend/src/api/client.ts @@ -1,9 +1,9 @@ /** * Axios HTTP Client Configuration - * Base client with interceptors for authentication and error handling + * Base client with interceptors for authentication, token refresh, and error handling */ -import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig } from 'axios' +import axios, { AxiosInstance, AxiosError, InternalAxiosRequestConfig, AxiosResponse } from 'axios' import type { ApiResponse } from '@/types' import { getLocale } from '@/i18n' @@ -19,6 +19,28 @@ export const apiClient: AxiosInstance = axios.create({ } }) +// ==================== Token Refresh State ==================== + +// Track if a token refresh is in progress to prevent multiple simultaneous refresh requests +let isRefreshing = false +// Queue of requests waiting for token refresh +let refreshSubscribers: Array<(token: string) => void> = [] + +/** + * Subscribe to token refresh completion + */ +function subscribeTokenRefresh(callback: (token: string) => void): void { + refreshSubscribers.push(callback) +} + +/** + * Notify all subscribers that token has been refreshed + */ +function onTokenRefreshed(token: string): void { + refreshSubscribers.forEach((callback) => callback(token)) + refreshSubscribers = [] +} + // ==================== Request Interceptor ==================== // Get user's timezone @@ -61,7 +83,7 @@ apiClient.interceptors.request.use( // ==================== Response Interceptor ==================== apiClient.interceptors.response.use( - (response) => { + (response: AxiosResponse) => { // Unwrap standard API response format { code, message, data } const apiResponse = response.data as ApiResponse if (apiResponse && typeof apiResponse === 'object' && 'code' in apiResponse) { @@ -79,13 +101,15 @@ apiClient.interceptors.response.use( } return response }, - (error: AxiosError>) => { + async (error: AxiosError>) => { // Request cancellation: keep the original axios cancellation error so callers can ignore it. // Otherwise we'd misclassify it as a generic "network error". if (error.code === 'ERR_CANCELED' || axios.isCancel(error)) { return Promise.reject(error) } + const originalRequest = error.config as InternalAxiosRequestConfig & { _retry?: boolean } + // Handle common errors if (error.response) { const { status, data } = error.response @@ -120,23 +144,116 @@ apiClient.interceptors.response.use( }) } - // 401: Unauthorized - clear token and redirect to login - if (status === 401) { - const hasToken = !!localStorage.getItem('auth_token') - const url = error.config?.url || '' + // 401: Try to refresh the token if we have a refresh token + // This handles TOKEN_EXPIRED, INVALID_TOKEN, TOKEN_REVOKED, etc. + if (status === 401 && !originalRequest._retry) { + const refreshToken = localStorage.getItem('refresh_token') const isAuthEndpoint = url.includes('/auth/login') || url.includes('/auth/register') || url.includes('/auth/refresh') + + // If we have a refresh token and this is not an auth endpoint, try to refresh + if (refreshToken && !isAuthEndpoint) { + if (isRefreshing) { + // Wait for the ongoing refresh to complete + return new Promise((resolve, reject) => { + subscribeTokenRefresh((newToken: string) => { + if (newToken) { + // Mark as retried to prevent infinite loop if retry also returns 401 + originalRequest._retry = true + if (originalRequest.headers) { + originalRequest.headers.Authorization = `Bearer ${newToken}` + } + resolve(apiClient(originalRequest)) + } else { + // Refresh failed, reject with original error + reject({ + status, + code: apiData.code, + message: apiData.message || apiData.detail || error.message + }) + } + }) + }) + } + + originalRequest._retry = true + isRefreshing = true + + try { + // Call refresh endpoint directly to avoid circular dependency + const refreshResponse = await axios.post( + `${API_BASE_URL}/auth/refresh`, + { refresh_token: refreshToken }, + { headers: { 'Content-Type': 'application/json' } } + ) + + const refreshData = refreshResponse.data as ApiResponse<{ + access_token: string + refresh_token: string + expires_in: number + }> + + if (refreshData.code === 0 && refreshData.data) { + const { access_token, refresh_token: newRefreshToken, expires_in } = refreshData.data + + // Update tokens in localStorage (convert expires_in to timestamp) + localStorage.setItem('auth_token', access_token) + localStorage.setItem('refresh_token', newRefreshToken) + localStorage.setItem('token_expires_at', String(Date.now() + expires_in * 1000)) + + // Notify subscribers with new token + onTokenRefreshed(access_token) + + // Retry the original request with new token + if (originalRequest.headers) { + originalRequest.headers.Authorization = `Bearer ${access_token}` + } + + isRefreshing = false + return apiClient(originalRequest) + } + + // Refresh response was not successful, fall through to clear auth + throw new Error('Token refresh failed') + } catch (refreshError) { + // Refresh failed - notify subscribers with empty token + onTokenRefreshed('') + isRefreshing = false + + // Clear tokens and redirect to login + localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') + localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') + sessionStorage.setItem('auth_expired', '1') + + if (!window.location.pathname.includes('/login')) { + window.location.href = '/login' + } + + return Promise.reject({ + status: 401, + code: 'TOKEN_REFRESH_FAILED', + message: 'Session expired. Please log in again.' + }) + } + } + + // No refresh token or is auth endpoint - clear auth and redirect + const hasToken = !!localStorage.getItem('auth_token') const headers = error.config?.headers as Record | undefined const authHeader = headers?.Authorization ?? headers?.authorization const sentAuth = typeof authHeader === 'string' ? authHeader.trim() !== '' : Array.isArray(authHeader) - ? authHeader.length > 0 - : !!authHeader + ? authHeader.length > 0 + : !!authHeader localStorage.removeItem('auth_token') + localStorage.removeItem('refresh_token') localStorage.removeItem('auth_user') + localStorage.removeItem('token_expires_at') if ((hasToken || sentAuth) && !isAuthEndpoint) { sessionStorage.setItem('auth_expired', '1') } diff --git a/frontend/src/components/layout/AppHeader.vue b/frontend/src/components/layout/AppHeader.vue index 6b5849c0..a6b4030f 100644 --- a/frontend/src/components/layout/AppHeader.vue +++ b/frontend/src/components/layout/AppHeader.vue @@ -283,7 +283,12 @@ function closeDropdown() { async function handleLogout() { closeDropdown() - authStore.logout() + try { + await authStore.logout() + } catch (error) { + // Ignore logout errors - still redirect to login + console.error('Logout error:', error) + } await router.push('/login') } diff --git a/frontend/src/stores/auth.ts b/frontend/src/stores/auth.ts index e4612f5e..22cad50a 100644 --- a/frontend/src/stores/auth.ts +++ b/frontend/src/stores/auth.ts @@ -1,6 +1,6 @@ /** * Authentication Store - * Manages user authentication state, login/logout, and token persistence + * Manages user authentication state, login/logout, token refresh, and token persistence */ import { defineStore } from 'pinia' @@ -10,15 +10,21 @@ import type { User, LoginRequest, RegisterRequest, AuthResponse } from '@/types' const AUTH_TOKEN_KEY = 'auth_token' const AUTH_USER_KEY = 'auth_user' -const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds +const REFRESH_TOKEN_KEY = 'refresh_token' +const TOKEN_EXPIRES_AT_KEY = 'token_expires_at' // 存储过期时间戳而非有效期 +const AUTO_REFRESH_INTERVAL = 60 * 1000 // 60 seconds for user data refresh +const TOKEN_REFRESH_BUFFER = 120 * 1000 // 120 seconds before expiry to refresh token export const useAuthStore = defineStore('auth', () => { // ==================== State ==================== const user = ref(null) const token = ref(null) + const refreshTokenValue = ref(null) + const tokenExpiresAt = ref(null) // 过期时间戳(毫秒) const runMode = ref<'standard' | 'simple'>('standard') let refreshIntervalId: ReturnType | null = null + let tokenRefreshTimeoutId: ReturnType | null = null // ==================== Computed ==================== @@ -42,19 +48,29 @@ export const useAuthStore = defineStore('auth', () => { function checkAuth(): void { const savedToken = localStorage.getItem(AUTH_TOKEN_KEY) const savedUser = localStorage.getItem(AUTH_USER_KEY) + const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY) + const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY) if (savedToken && savedUser) { try { token.value = savedToken user.value = JSON.parse(savedUser) + refreshTokenValue.value = savedRefreshToken + tokenExpiresAt.value = savedExpiresAt ? parseInt(savedExpiresAt, 10) : null // Immediately refresh user data from backend (async, don't block) refreshUser().catch((error) => { console.error('Failed to refresh user on init:', error) }) - // Start auto-refresh interval + // Start auto-refresh interval for user data startAutoRefresh() + + // Start proactive token refresh if we have refresh token and expiry info + // Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired) + if (savedRefreshToken && tokenExpiresAt.value !== null) { + scheduleTokenRefreshAt(tokenExpiresAt.value) + } } catch (error) { console.error('Failed to parse saved user data:', error) clearAuth() @@ -89,6 +105,76 @@ export const useAuthStore = defineStore('auth', () => { } } + /** + * Schedule proactive token refresh before expiry (based on expiry timestamp) + * @param expiresAtMs - Token expiry timestamp in milliseconds + */ + function scheduleTokenRefreshAt(expiresAtMs: number): void { + // Clear any existing timeout + if (tokenRefreshTimeoutId) { + clearTimeout(tokenRefreshTimeoutId) + tokenRefreshTimeoutId = null + } + + // Calculate remaining time until refresh (buffer time before expiry) + const now = Date.now() + const refreshInMs = Math.max(0, expiresAtMs - now - TOKEN_REFRESH_BUFFER) + + if (refreshInMs <= 0) { + // Token is about to expire or already expired, refresh immediately + performTokenRefresh() + return + } + + tokenRefreshTimeoutId = setTimeout(() => { + performTokenRefresh() + }, refreshInMs) + } + + /** + * Schedule proactive token refresh before expiry (based on expires_in seconds) + * @param expiresInSeconds - Token expiry time in seconds from now + */ + function scheduleTokenRefresh(expiresInSeconds: number): void { + const expiresAtMs = Date.now() + expiresInSeconds * 1000 + tokenExpiresAt.value = expiresAtMs + localStorage.setItem(TOKEN_EXPIRES_AT_KEY, String(expiresAtMs)) + scheduleTokenRefreshAt(expiresAtMs) + } + + /** + * Perform the actual token refresh + */ + async function performTokenRefresh(): Promise { + if (!refreshTokenValue.value) { + return + } + + try { + const response = await authAPI.refreshToken() + + // Update state + token.value = response.access_token + refreshTokenValue.value = response.refresh_token + + // Schedule next refresh (this also updates tokenExpiresAt and localStorage) + scheduleTokenRefresh(response.expires_in) + } catch (error) { + console.error('Token refresh failed:', error) + // Don't clear auth here - the interceptor will handle 401 errors + } + } + + /** + * Stop token refresh timeout + */ + function stopTokenRefresh(): void { + if (tokenRefreshTimeoutId) { + clearTimeout(tokenRefreshTimeoutId) + tokenRefreshTimeoutId = null + } + } + /** * User login * @param credentials - Login credentials (email and password) @@ -141,6 +227,12 @@ export const useAuthStore = defineStore('auth', () => { // Store token and user token.value = response.access_token + // Store refresh token if present + if (response.refresh_token) { + refreshTokenValue.value = response.refresh_token + localStorage.setItem(REFRESH_TOKEN_KEY, response.refresh_token) + } + // Extract run_mode if present if (response.user.run_mode) { runMode.value = response.user.run_mode @@ -152,8 +244,14 @@ export const useAuthStore = defineStore('auth', () => { localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userData)) - // Start auto-refresh interval + // Start auto-refresh interval for user data startAutoRefresh() + + // Start proactive token refresh if we have refresh token and expiry info + // scheduleTokenRefresh will also store the expiry timestamp + if (response.refresh_token && response.expires_in) { + scheduleTokenRefresh(response.expires_in) + } } /** @@ -166,24 +264,10 @@ export const useAuthStore = defineStore('auth', () => { try { const response = await authAPI.register(userData) - // Store token and user - token.value = response.access_token + // Use the common helper to set auth state + setAuthFromResponse(response) - // Extract run_mode if present - if (response.user.run_mode) { - runMode.value = response.user.run_mode - } - const { run_mode: _run_mode, ...userDataWithoutRunMode } = response.user - user.value = userDataWithoutRunMode - - // Persist to localStorage - localStorage.setItem(AUTH_TOKEN_KEY, response.access_token) - localStorage.setItem(AUTH_USER_KEY, JSON.stringify(userDataWithoutRunMode)) - - // Start auto-refresh interval - startAutoRefresh() - - return userDataWithoutRunMode + return user.value! } catch (error) { // Clear any partial state on error clearAuth() @@ -193,18 +277,41 @@ export const useAuthStore = defineStore('auth', () => { /** * 直接设置 token(用于 OAuth/SSO 回调),并加载当前用户信息。 + * 会自动读取 localStorage 中已设置的 refresh_token 和 token_expires_in * @param newToken - 后端签发的 JWT access token */ async function setToken(newToken: string): Promise { // Clear any previous state first (avoid mixing sessions) - clearAuth() + // Note: Don't clear localStorage here as OAuth callback may have set refresh_token + stopAutoRefresh() + stopTokenRefresh() + token.value = null + user.value = null token.value = newToken localStorage.setItem(AUTH_TOKEN_KEY, newToken) + // Read refresh token and expires_at from localStorage if set by OAuth callback + const savedRefreshToken = localStorage.getItem(REFRESH_TOKEN_KEY) + const savedExpiresAt = localStorage.getItem(TOKEN_EXPIRES_AT_KEY) + + if (savedRefreshToken) { + refreshTokenValue.value = savedRefreshToken + } + if (savedExpiresAt) { + tokenExpiresAt.value = parseInt(savedExpiresAt, 10) + } + try { const userData = await refreshUser() startAutoRefresh() + + // Start proactive token refresh if we have refresh token and expiry info + // Note: use !== null to handle case when tokenExpiresAt.value is 0 (expired) + if (savedRefreshToken && tokenExpiresAt.value !== null) { + scheduleTokenRefreshAt(tokenExpiresAt.value) + } + return userData } catch (error) { clearAuth() @@ -216,9 +323,9 @@ export const useAuthStore = defineStore('auth', () => { * User logout * Clears all authentication state and persisted data */ - function logout(): void { - // Call API logout (client-side cleanup) - authAPI.logout() + async function logout(): Promise { + // Call API logout (revokes refresh token on server) + await authAPI.logout() // Clear state clearAuth() @@ -263,11 +370,17 @@ export const useAuthStore = defineStore('auth', () => { function clearAuth(): void { // Stop auto-refresh stopAutoRefresh() + // Stop token refresh + stopTokenRefresh() token.value = null + refreshTokenValue.value = null + tokenExpiresAt.value = null user.value = null localStorage.removeItem(AUTH_TOKEN_KEY) localStorage.removeItem(AUTH_USER_KEY) + localStorage.removeItem(REFRESH_TOKEN_KEY) + localStorage.removeItem(TOKEN_EXPIRES_AT_KEY) } // ==================== Return Store API ==================== diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index 12449d3c..eb53de44 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -92,6 +92,8 @@ export interface PublicSettings { export interface AuthResponse { access_token: string + refresh_token?: string // New: Refresh Token for token renewal + expires_in?: number // New: Access Token expiry time in seconds token_type: string user: User & { run_mode?: 'standard' | 'simple' } } diff --git a/frontend/src/views/auth/LinuxDoCallbackView.vue b/frontend/src/views/auth/LinuxDoCallbackView.vue index c6f93e6b..4dbca1df 100644 --- a/frontend/src/views/auth/LinuxDoCallbackView.vue +++ b/frontend/src/views/auth/LinuxDoCallbackView.vue @@ -71,6 +71,8 @@ onMounted(async () => { const params = parseFragmentParams() const token = params.get('access_token') || '' + const refreshToken = params.get('refresh_token') || '' + const expiresInStr = params.get('expires_in') || '' const redirect = sanitizeRedirectPath( params.get('redirect') || (route.query.redirect as string | undefined) || '/dashboard' ) @@ -92,6 +94,17 @@ onMounted(async () => { } try { + // Store refresh token and expires_at (convert to timestamp) if provided + if (refreshToken) { + localStorage.setItem('refresh_token', refreshToken) + } + if (expiresInStr) { + const expiresIn = parseInt(expiresInStr, 10) + if (!isNaN(expiresIn)) { + localStorage.setItem('token_expires_at', String(Date.now() + expiresIn * 1000)) + } + } + await authStore.setToken(token) appStore.showSuccess(t('auth.loginSuccess')) await router.replace(redirect) From 39a0359dd5a26b5e79f94b6c83610448b3003a78 Mon Sep 17 00:00:00 2001 From: Lemon Date: Thu, 5 Feb 2026 12:48:05 +0800 Subject: [PATCH 06/14] feat: enhance HTTP/2 Cleartext (h2c) configuration options --- backend/internal/config/config.go | 34 +++++++++++++++++++++++-------- backend/internal/server/http.go | 30 ++++++++++++++++++++------- config.yaml | 29 +++++++++++++++++++++++--- deploy/.env.example | 23 ++++++++++++++++++++- deploy/config.example.yaml | 29 +++++++++++++++++++++++--- 5 files changed, 123 insertions(+), 22 deletions(-) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index 0790ed06..f0eb1ceb 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -144,13 +144,24 @@ type PricingConfig struct { } type ServerConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Mode string `mapstructure:"mode"` // debug/release - ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) - IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) - TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) - EnableH2C bool `mapstructure:"enable_h2c"` // 启用 HTTP/2 Cleartext (h2c) + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Mode string `mapstructure:"mode"` // debug/release + ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) + TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) + MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制 + H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置 +} + +// H2CConfig HTTP/2 Cleartext 配置 +type H2CConfig struct { + Enabled bool `mapstructure:"enabled"` // 是否启用 H2C + MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量 + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒) + MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节) + MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节) + MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节) } type CORSConfig struct { @@ -688,7 +699,14 @@ func setDefaults() { viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 viper.SetDefault("server.trusted_proxies", []string{}) - viper.SetDefault("server.enable_h2c", false) // 默认关闭 h2c + viper.SetDefault("server.max_request_body_size", int64(100*1024*1024)) + // H2C 默认配置 + viper.SetDefault("server.h2c.enabled", false) + viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 + viper.SetDefault("server.h2c.idle_timeout", 75) // 75 秒 + viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) // 1MB(够用) + viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB + viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB // CORS viper.SetDefault("cors.allowed_origins", []string{}) diff --git a/backend/internal/server/http.go b/backend/internal/server/http.go index f22445b8..d2d8ed40 100644 --- a/backend/internal/server/http.go +++ b/backend/internal/server/http.go @@ -60,16 +60,32 @@ func ProvideRouter( func ProvideHTTPServer(cfg *config.Config, router *gin.Engine) *http.Server { httpHandler := http.Handler(router) + globalMaxSize := cfg.Server.MaxRequestBodySize + if globalMaxSize <= 0 { + globalMaxSize = cfg.Gateway.MaxBodySize + } + if globalMaxSize > 0 { + httpHandler = http.MaxBytesHandler(httpHandler, globalMaxSize) + log.Printf("Global max request body size: %d bytes (%.2f MB)", globalMaxSize, float64(globalMaxSize)/(1<<20)) + } + // 根据配置决定是否启用 H2C - if cfg.Server.EnableH2C { + if cfg.Server.H2C.Enabled { + h2cConfig := cfg.Server.H2C httpHandler = h2c.NewHandler(router, &http2.Server{ - MaxConcurrentStreams: 250, // 最大并发流数量 - IdleTimeout: 300 * time.Second, - MaxReadFrameSize: 4 << 20, // 4MB - MaxUploadBufferPerConnection: 8 << 20, // 8MB - MaxUploadBufferPerStream: 2 << 20, // 2MB + MaxConcurrentStreams: h2cConfig.MaxConcurrentStreams, + IdleTimeout: time.Duration(h2cConfig.IdleTimeout) * time.Second, + MaxReadFrameSize: uint32(h2cConfig.MaxReadFrameSize), + MaxUploadBufferPerConnection: int32(h2cConfig.MaxUploadBufferPerConnection), + MaxUploadBufferPerStream: int32(h2cConfig.MaxUploadBufferPerStream), }) - log.Println("HTTP/2 Cleartext (h2c) enabled") + log.Printf("HTTP/2 Cleartext (h2c) enabled: max_concurrent_streams=%d, idle_timeout=%ds, max_read_frame_size=%d, max_upload_buffer_per_connection=%d, max_upload_buffer_per_stream=%d", + h2cConfig.MaxConcurrentStreams, + h2cConfig.IdleTimeout, + h2cConfig.MaxReadFrameSize, + h2cConfig.MaxUploadBufferPerConnection, + h2cConfig.MaxUploadBufferPerStream, + ) } return &http.Server{ diff --git a/config.yaml b/config.yaml index e79500b7..1cbd8c11 100644 --- a/config.yaml +++ b/config.yaml @@ -23,9 +23,32 @@ server: # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] - # Enable HTTP/2 Cleartext (h2c) for client connections - # 启用 HTTP/2 Cleartext (h2c) 客户端连接 - enable_h2c: true + # Global max request body size in bytes (default: 100MB) + # 全局最大请求体大小(字节,默认 100MB) + # Applies to all requests, especially important for h2c first request memory protection + # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 + max_request_body_size: 104857600 + # HTTP/2 Cleartext (h2c) configuration + # HTTP/2 Cleartext (h2c) 配置 + h2c: + # Enable HTTP/2 Cleartext for client connections + # 启用 HTTP/2 Cleartext 客户端连接 + enabled: true + # Max concurrent streams per connection + # 每个连接的最大并发流数量 + max_concurrent_streams: 50 + # Idle timeout for connections (seconds) + # 连接空闲超时时间(秒) + idle_timeout: 75 + # Max frame size in bytes (default: 1MB) + # 最大帧大小(字节,默认 1MB) + max_read_frame_size: 1048576 + # Max upload buffer per connection in bytes (default: 2MB) + # 每个连接的最大上传缓冲区(字节,默认 2MB) + max_upload_buffer_per_connection: 2097152 + # Max upload buffer per stream in bytes (default: 512KB) + # 每个流的最大上传缓冲区(字节,默认 512KB) + max_upload_buffer_per_stream: 524288 # ============================================================================= # Run Mode Configuration diff --git a/deploy/.env.example b/deploy/.env.example index a0c98bec..c5e850ae 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -20,9 +20,30 @@ SERVER_PORT=8080 # Server mode: release or debug SERVER_MODE=release +# Global max request body size in bytes (default: 100MB) +# 全局最大请求体大小(字节,默认 100MB) +# Applies to all requests, especially important for h2c first request memory protection +# 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 +SERVER_MAX_REQUEST_BODY_SIZE=104857600 + # Enable HTTP/2 Cleartext (h2c) for client connections # 启用 HTTP/2 Cleartext (h2c) 客户端连接 -SERVER_ENABLE_H2C=true +SERVER_H2C_ENABLED=true +# H2C max concurrent streams (default: 50) +# H2C 最大并发流数量(默认 50) +SERVER_H2C_MAX_CONCURRENT_STREAMS=50 +# H2C idle timeout in seconds (default: 75) +# H2C 空闲超时时间(秒,默认 75) +SERVER_H2C_IDLE_TIMEOUT=75 +# H2C max read frame size in bytes (default: 1048576 = 1MB) +# H2C 最大帧大小(字节,默认 1048576 = 1MB) +SERVER_H2C_MAX_READ_FRAME_SIZE=1048576 +# H2C max upload buffer per connection in bytes (default: 2097152 = 2MB) +# H2C 每个连接的最大上传缓冲区(字节,默认 2097152 = 2MB) +SERVER_H2C_MAX_UPLOAD_BUFFER_PER_CONNECTION=2097152 +# H2C max upload buffer per stream in bytes (default: 524288 = 512KB) +# H2C 每个流的最大上传缓冲区(字节,默认 524288 = 512KB) +SERVER_H2C_MAX_UPLOAD_BUFFER_PER_STREAM=524288 # 运行模式: standard (默认) 或 simple (内部自用) # standard: 完整 SaaS 功能,包含计费/余额校验;simple: 隐藏 SaaS 功能并跳过计费/余额校验 diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 339b3d89..d9f5f2ab 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -23,9 +23,32 @@ server: # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 trusted_proxies: [] - # Enable HTTP/2 Cleartext (h2c) for client connections - # 启用 HTTP/2 Cleartext (h2c) 客户端连接 - enable_h2c: true + # Global max request body size in bytes (default: 100MB) + # 全局最大请求体大小(字节,默认 100MB) + # Applies to all requests, especially important for h2c first request memory protection + # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 + max_request_body_size: 104857600 + # HTTP/2 Cleartext (h2c) configuration + # HTTP/2 Cleartext (h2c) 配置 + h2c: + # Enable HTTP/2 Cleartext for client connections + # 启用 HTTP/2 Cleartext 客户端连接 + enabled: true + # Max concurrent streams per connection + # 每个连接的最大并发流数量 + max_concurrent_streams: 50 + # Idle timeout for connections (seconds) + # 连接空闲超时时间(秒) + idle_timeout: 75 + # Max frame size in bytes (default: 1MB) + # 最大帧大小(字节,默认 1MB) + max_read_frame_size: 1048576 + # Max upload buffer per connection in bytes (default: 2MB) + # 每个连接的最大上传缓冲区(字节,默认 2MB) + max_upload_buffer_per_connection: 2097152 + # Max upload buffer per stream in bytes (default: 512KB) + # 每个流的最大上传缓冲区(字节,默认 512KB) + max_upload_buffer_per_stream: 524288 # ============================================================================= # Run Mode Configuration From 6d0152c8e2e32f00e88de16fefa456cf4cadb758 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 13:39:31 +0800 Subject: [PATCH 07/14] =?UTF-8?q?chore:=20=E7=A7=BB=E9=99=A4=E5=A4=9A?= =?UTF-8?q?=E4=BD=99=E7=9A=84=E6=96=87=E6=A1=A3/=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E7=A4=BA=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Linux DO Connect.md | 368 ----------------------------- config.yaml | 556 -------------------------------------------- 2 files changed, 924 deletions(-) delete mode 100644 Linux DO Connect.md delete mode 100644 config.yaml diff --git a/Linux DO Connect.md b/Linux DO Connect.md deleted file mode 100644 index 7ca1260f..00000000 --- a/Linux DO Connect.md +++ /dev/null @@ -1,368 +0,0 @@ -# Linux DO Connect - -OAuth(Open Authorization)是一个开放的网络授权标准,目前最新版本为 OAuth 2.0。我们日常使用的第三方登录(如 Google 账号登录)就采用了该标准。OAuth 允许用户授权第三方应用访问存储在其他服务提供商(如 Google)上的信息,无需在不同平台上重复填写注册信息。用户授权后,平台可以直接访问用户的账户信息进行身份验证,而用户无需向第三方应用提供密码。 - -目前系统已实现完整的 OAuth2 授权码(code)方式鉴权,但界面等配套功能还在持续完善中。让我们一起打造一个更完善的共享方案。 - -## 基本介绍 - -这是一套标准的 OAuth2 鉴权系统,可以让开发者共享论坛的用户基本信息。 - -- 可获取字段: - -| 参数 | 说明 | -| ----------------- | ------------------------------- | -| `id` | 用户唯一标识(不可变) | -| `username` | 论坛用户名 | -| `name` | 论坛用户昵称(可变) | -| `avatar_template` | 用户头像模板URL(支持多种尺寸) | -| `active` | 账号活跃状态 | -| `trust_level` | 信任等级(0-4) | -| `silenced` | 禁言状态 | -| `external_ids` | 外部ID关联信息 | -| `api_key` | API访问密钥 | - -通过这些信息,公益网站/接口可以实现: - -1. 基于 `id` 的服务频率限制 -2. 基于 `trust_level` 的服务额度分配 -3. 基于用户信息的滥用举报机制 - -## 相关端点 - -- Authorize 端点: `https://connect.linux.do/oauth2/authorize` -- Token 端点:`https://connect.linux.do/oauth2/token` -- 用户信息 端点:`https://connect.linux.do/api/user` - -## 申请使用 - -- 访问 [Connect.Linux.Do](https://connect.linux.do/) 申请接入你的应用。 - -![linuxdoconnect_1](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_1.png&w=1080&q=75) - -- 点击 **`我的应用接入`** - **`申请新接入`**,填写相关信息。其中 **`回调地址`** 是你的应用接收用户信息的地址。 - -![linuxdoconnect_2](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_2.png&w=1080&q=75) - -- 申请成功后,你将获得 **`Client Id`** 和 **`Client Secret`**,这是你应用的唯一身份凭证。 - -![linuxdoconnect_3](https://wiki.linux.do/_next/image?url=%2Flinuxdoconnect_3.png&w=1080&q=75) - -## 接入 Linux Do - -JavaScript -```JavaScript -// 安装第三方请求库(或使用原生的 Fetch API),本例中使用 axios -// npm install axios - -// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 -const axios = require('axios'); -const readline = require('readline'); - -// 配置信息(建议通过环境变量配置,避免使用硬编码) -const CLIENT_ID = '你的 Client ID'; -const CLIENT_SECRET = '你的 Client Secret'; -const REDIRECT_URI = '你的回调地址'; -const AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; -const TOKEN_URL = 'https://connect.linux.do/oauth2/token'; -const USER_INFO_URL = 'https://connect.linux.do/api/user'; - -// 第一步:生成授权 URL -function getAuthUrl() { - const params = new URLSearchParams({ - client_id: CLIENT_ID, - redirect_uri: REDIRECT_URI, - response_type: 'code', - scope: 'user' - }); - - return `${AUTH_URL}?${params.toString()}`; -} - -// 第二步:获取 code 参数 -function getCode() { - return new Promise((resolve) => { - // 本例中使用终端输入来模拟流程,仅供本地测试 - // 请在实际应用中替换为真实的处理逻辑 - const rl = readline.createInterface({ input: process.stdin, output: process.stdout }); - rl.question('从回调 URL 中提取出 code,粘贴到此处并按回车:', (answer) => { - rl.close(); - resolve(answer.trim()); - }); - }); -} - -// 第三步:使用 code 参数获取访问令牌 -async function getAccessToken(code) { - try { - const form = new URLSearchParams({ - client_id: CLIENT_ID, - client_secret: CLIENT_SECRET, - code: code, - redirect_uri: REDIRECT_URI, - grant_type: 'authorization_code' - }).toString(); - - const response = await axios.post(TOKEN_URL, form, { - // 提醒:需正确配置请求头,否则无法正常获取访问令牌 - headers: { - 'Content-Type': 'application/x-www-form-urlencoded', - 'Accept': 'application/json' - } - }); - - return response.data; - } catch (error) { - console.error(`获取访问令牌失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); - throw error; - } -} - -// 第四步:使用访问令牌获取用户信息 -async function getUserInfo(accessToken) { - try { - const response = await axios.get(USER_INFO_URL, { - headers: { - Authorization: `Bearer ${accessToken}` - } - }); - - return response.data; - } catch (error) { - console.error(`获取用户信息失败:${error.response ? JSON.stringify(error.response.data) : error.message}`); - throw error; - } -} - -// 主流程 -async function main() { - // 1. 生成授权 URL,前端引导用户访问授权页 - const authUrl = getAuthUrl(); - console.log(`请访问此 URL 授权:${authUrl} -`); - - // 2. 用户授权后,从回调 URL 获取 code 参数 - const code = await getCode(); - - try { - // 3. 使用 code 参数获取访问令牌 - const tokenData = await getAccessToken(code); - const accessToken = tokenData.access_token; - - // 4. 使用访问令牌获取用户信息 - if (accessToken) { - const userInfo = await getUserInfo(accessToken); - console.log(` -获取用户信息成功:${JSON.stringify(userInfo, null, 2)}`); - } else { - console.log(` -获取访问令牌失败:${JSON.stringify(tokenData)}`); - } - } catch (error) { - console.error('发生错误:', error); - } -} -``` -Python -```python -# 安装第三方请求库,本例中使用 requests -# pip install requests - -# 通过 OAuth2 获取 Linux Do 用户信息的参考流程 -import requests -import json - -# 配置信息(建议通过环境变量配置,避免使用硬编码) -CLIENT_ID = '你的 Client ID' -CLIENT_SECRET = '你的 Client Secret' -REDIRECT_URI = '你的回调地址' -AUTH_URL = 'https://connect.linux.do/oauth2/authorize' -TOKEN_URL = 'https://connect.linux.do/oauth2/token' -USER_INFO_URL = 'https://connect.linux.do/api/user' - -# 第一步:生成授权 URL -def get_auth_url(): - params = { - 'client_id': CLIENT_ID, - 'redirect_uri': REDIRECT_URI, - 'response_type': 'code', - 'scope': 'user' - } - auth_url = f"{AUTH_URL}?{'&'.join(f'{k}={v}' for k, v in params.items())}" - return auth_url - -# 第二步:获取 code 参数 -def get_code(): - # 本例中使用终端输入来模拟流程,仅供本地测试 - # 请在实际应用中替换为真实的处理逻辑 - return input('从回调 URL 中提取出 code,粘贴到此处并按回车:').strip() - -# 第三步:使用 code 参数获取访问令牌 -def get_access_token(code): - try: - data = { - 'client_id': CLIENT_ID, - 'client_secret': CLIENT_SECRET, - 'code': code, - 'redirect_uri': REDIRECT_URI, - 'grant_type': 'authorization_code' - } - # 提醒:需正确配置请求头,否则无法正常获取访问令牌 - headers = { - 'Content-Type': 'application/x-www-form-urlencoded', - 'Accept': 'application/json' - } - response = requests.post(TOKEN_URL, data=data, headers=headers) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"获取访问令牌失败:{e}") - return None - -# 第四步:使用访问令牌获取用户信息 -def get_user_info(access_token): - try: - headers = { - 'Authorization': f'Bearer {access_token}' - } - response = requests.get(USER_INFO_URL, headers=headers) - response.raise_for_status() - return response.json() - except requests.exceptions.RequestException as e: - print(f"获取用户信息失败:{e}") - return None - -# 主流程 -if __name__ == '__main__': - # 1. 生成授权 URL,前端引导用户访问授权页 - auth_url = get_auth_url() - print(f'请访问此 URL 授权:{auth_url} -') - - # 2. 用户授权后,从回调 URL 获取 code 参数 - code = get_code() - - # 3. 使用 code 参数获取访问令牌 - token_data = get_access_token(code) - if token_data: - access_token = token_data.get('access_token') - - # 4. 使用访问令牌获取用户信息 - if access_token: - user_info = get_user_info(access_token) - if user_info: - print(f" -获取用户信息成功:{json.dumps(user_info, indent=2)}") - else: - print(" -获取用户信息失败") - else: - print(f" -获取访问令牌失败:{json.dumps(token_data, indent=2)}") - else: - print(" -获取访问令牌失败") -``` -PHP -```php -// 通过 OAuth2 获取 Linux Do 用户信息的参考流程 - -// 配置信息 -$CLIENT_ID = '你的 Client ID'; -$CLIENT_SECRET = '你的 Client Secret'; -$REDIRECT_URI = '你的回调地址'; -$AUTH_URL = 'https://connect.linux.do/oauth2/authorize'; -$TOKEN_URL = 'https://connect.linux.do/oauth2/token'; -$USER_INFO_URL = 'https://connect.linux.do/api/user'; - -// 生成授权 URL -function getAuthUrl($clientId, $redirectUri) { - global $AUTH_URL; - return $AUTH_URL . '?' . http_build_query([ - 'client_id' => $clientId, - 'redirect_uri' => $redirectUri, - 'response_type' => 'code', - 'scope' => 'user' - ]); -} - -// 使用 code 参数获取用户信息(合并获取令牌和获取用户信息的步骤) -function getUserInfoWithCode($code, $clientId, $clientSecret, $redirectUri) { - global $TOKEN_URL, $USER_INFO_URL; - - // 1. 获取访问令牌 - $ch = curl_init($TOKEN_URL); - curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); - curl_setopt($ch, CURLOPT_POST, true); - curl_setopt($ch, CURLOPT_POSTFIELDS, http_build_query([ - 'client_id' => $clientId, - 'client_secret' => $clientSecret, - 'code' => $code, - 'redirect_uri' => $redirectUri, - 'grant_type' => 'authorization_code' - ])); - curl_setopt($ch, CURLOPT_HTTPHEADER, [ - 'Content-Type: application/x-www-form-urlencoded', - 'Accept: application/json' - ]); - - $tokenResponse = curl_exec($ch); - curl_close($ch); - - $tokenData = json_decode($tokenResponse, true); - if (!isset($tokenData['access_token'])) { - return ['error' => '获取访问令牌失败', 'details' => $tokenData]; - } - - // 2. 获取用户信息 - $ch = curl_init($USER_INFO_URL); - curl_setopt($ch, CURLOPT_RETURNTRANSFER, true); - curl_setopt($ch, CURLOPT_HTTPHEADER, [ - 'Authorization: Bearer ' . $tokenData['access_token'] - ]); - - $userResponse = curl_exec($ch); - curl_close($ch); - - return json_decode($userResponse, true); -} - -// 主流程 -// 1. 生成授权 URL -$authUrl = getAuthUrl($CLIENT_ID, $REDIRECT_URI); -echo "使用 Linux Do 登录"; - -// 2. 处理回调并获取用户信息 -if (isset($_GET['code'])) { - $userInfo = getUserInfoWithCode( - $_GET['code'], - $CLIENT_ID, - $CLIENT_SECRET, - $REDIRECT_URI - ); - - if (isset($userInfo['error'])) { - echo '错误: ' . $userInfo['error']; - } else { - echo '欢迎, ' . $userInfo['name'] . '!'; - // 处理用户登录逻辑... - } -} -``` - -## 使用说明 - -### 授权流程 - -1. 用户点击应用中的’使用 Linux Do 登录’按钮 -2. 系统将用户重定向至 Linux Do 的授权页面 -3. 用户完成授权后,系统自动重定向回应用并携带授权码 -4. 应用使用授权码获取访问令牌 -5. 使用访问令牌获取用户信息 - -### 安全建议 - -- 切勿在前端代码中暴露 Client Secret -- 对所有用户输入数据进行严格验证 -- 确保使用 HTTPS 协议传输数据 -- 定期更新并妥善保管 Client Secret \ No newline at end of file diff --git a/config.yaml b/config.yaml deleted file mode 100644 index 1cbd8c11..00000000 --- a/config.yaml +++ /dev/null @@ -1,556 +0,0 @@ -# Sub2API Configuration File -# Sub2API 配置文件 -# -# Copy this file to /etc/sub2api/config.yaml and modify as needed -# 复制此文件到 /etc/sub2api/config.yaml 并根据需要修改 -# -# Documentation / 文档: https://github.com/Wei-Shaw/sub2api - -# ============================================================================= -# Server Configuration -# 服务器配置 -# ============================================================================= -server: - # Bind address (0.0.0.0 for all interfaces) - # 绑定地址(0.0.0.0 表示监听所有网络接口) - host: "0.0.0.0" - # Port to listen on - # 监听端口 - port: 8080 - # Mode: "debug" for development, "release" for production - # 运行模式:"debug" 用于开发,"release" 用于生产环境 - mode: "release" - # Trusted proxies for X-Forwarded-For parsing (CIDR/IP). Empty disables trusted proxies. - # 信任的代理地址(CIDR/IP 格式),用于解析 X-Forwarded-For 头。留空则禁用代理信任。 - trusted_proxies: [] - # Global max request body size in bytes (default: 100MB) - # 全局最大请求体大小(字节,默认 100MB) - # Applies to all requests, especially important for h2c first request memory protection - # 适用于所有请求,对 h2c 第一请求的内存保护尤为重要 - max_request_body_size: 104857600 - # HTTP/2 Cleartext (h2c) configuration - # HTTP/2 Cleartext (h2c) 配置 - h2c: - # Enable HTTP/2 Cleartext for client connections - # 启用 HTTP/2 Cleartext 客户端连接 - enabled: true - # Max concurrent streams per connection - # 每个连接的最大并发流数量 - max_concurrent_streams: 50 - # Idle timeout for connections (seconds) - # 连接空闲超时时间(秒) - idle_timeout: 75 - # Max frame size in bytes (default: 1MB) - # 最大帧大小(字节,默认 1MB) - max_read_frame_size: 1048576 - # Max upload buffer per connection in bytes (default: 2MB) - # 每个连接的最大上传缓冲区(字节,默认 2MB) - max_upload_buffer_per_connection: 2097152 - # Max upload buffer per stream in bytes (default: 512KB) - # 每个流的最大上传缓冲区(字节,默认 512KB) - max_upload_buffer_per_stream: 524288 - -# ============================================================================= -# Run Mode Configuration -# 运行模式配置 -# ============================================================================= -# Run mode: "standard" (default) or "simple" (for internal use) -# 运行模式:"standard"(默认)或 "simple"(内部使用) -# - standard: Full SaaS features with billing/balance checks -# - standard: 完整 SaaS 功能,包含计费和余额校验 -# - simple: Hides SaaS features and skips billing/balance checks -# - simple: 隐藏 SaaS 功能,跳过计费和余额校验 -run_mode: "standard" - -# ============================================================================= -# CORS Configuration -# 跨域资源共享 (CORS) 配置 -# ============================================================================= -cors: - # Allowed origins list. Leave empty to disable cross-origin requests. - # 允许的来源列表。留空则禁用跨域请求。 - allowed_origins: [] - # Allow credentials (cookies/authorization headers). Cannot be used with "*". - # 允许携带凭证(cookies/授权头)。不能与 "*" 通配符同时使用。 - allow_credentials: true - -# ============================================================================= -# Security Configuration -# 安全配置 -# ============================================================================= -security: - url_allowlist: - # Enable URL allowlist validation (disable to skip all URL checks) - # 启用 URL 白名单验证(禁用则跳过所有 URL 检查) - enabled: false - # Allowed upstream hosts for API proxying - # 允许代理的上游 API 主机列表 - upstream_hosts: - - "api.openai.com" - - "api.anthropic.com" - - "api.kimi.com" - - "open.bigmodel.cn" - - "api.minimaxi.com" - - "generativelanguage.googleapis.com" - - "cloudcode-pa.googleapis.com" - - "*.openai.azure.com" - # Allowed hosts for pricing data download - # 允许下载定价数据的主机列表 - pricing_hosts: - - "raw.githubusercontent.com" - # Allowed hosts for CRS sync (required when using CRS sync) - # 允许 CRS 同步的主机列表(使用 CRS 同步功能时必须配置) - crs_hosts: [] - # Allow localhost/private IPs for upstream/pricing/CRS (use only in trusted networks) - # 允许本地/私有 IP 地址用于上游/定价/CRS(仅在可信网络中使用) - allow_private_hosts: true - # Allow http:// URLs when allowlist is disabled (default: false, require https) - # 白名单禁用时是否允许 http:// URL(默认: false,要求 https) - allow_insecure_http: true - response_headers: - # Enable configurable response header filtering (disable to use default allowlist) - # 启用可配置的响应头过滤(禁用则使用默认白名单) - enabled: false - # Extra allowed response headers from upstream - # 额外允许的上游响应头 - additional_allowed: [] - # Force-remove response headers from upstream - # 强制移除的上游响应头 - force_remove: [] - csp: - # Enable Content-Security-Policy header - # 启用内容安全策略 (CSP) 响应头 - enabled: true - # Default CSP policy (override if you host assets on other domains) - # 默认 CSP 策略(如果静态资源托管在其他域名,请自行覆盖) - policy: "default-src 'self'; script-src 'self'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" - proxy_probe: - # Allow skipping TLS verification for proxy probe (debug only) - # 允许代理探测时跳过 TLS 证书验证(仅用于调试) - insecure_skip_verify: false - -# ============================================================================= -# Gateway Configuration -# 网关配置 -# ============================================================================= -gateway: - # Timeout for waiting upstream response headers (seconds) - # 等待上游响应头超时时间(秒) - response_header_timeout: 600 - # Max request body size in bytes (default: 100MB) - # 请求体最大字节数(默认 100MB) - max_body_size: 104857600 - # Connection pool isolation strategy: - # 连接池隔离策略: - # - proxy: Isolate by proxy, same proxy shares connection pool (suitable for few proxies, many accounts) - # - proxy: 按代理隔离,同一代理共享连接池(适合代理少、账户多) - # - account: Isolate by account, same account shares connection pool (suitable for few accounts, strict isolation) - # - account: 按账户隔离,同一账户共享连接池(适合账户少、需严格隔离) - # - account_proxy: Isolate by account+proxy combination (default, finest granularity) - # - account_proxy: 按账户+代理组合隔离(默认,最细粒度) - connection_pool_isolation: "account_proxy" - # HTTP upstream connection pool settings (HTTP/2 + multi-proxy scenario defaults) - # HTTP 上游连接池配置(HTTP/2 + 多代理场景默认值) - # Max idle connections across all hosts - # 所有主机的最大空闲连接数 - max_idle_conns: 240 - # Max idle connections per host - # 每个主机的最大空闲连接数 - max_idle_conns_per_host: 120 - # Max connections per host - # 每个主机的最大连接数 - max_conns_per_host: 240 - # Idle connection timeout (seconds) - # 空闲连接超时时间(秒) - idle_conn_timeout_seconds: 90 - # Upstream client cache settings - # 上游连接池客户端缓存配置 - # max_upstream_clients: Max cached clients, evicts least recently used when exceeded - # max_upstream_clients: 最大缓存客户端数量,超出后淘汰最久未使用的 - max_upstream_clients: 5000 - # client_idle_ttl_seconds: Client idle reclaim threshold (seconds), reclaimed when idle and no active requests - # client_idle_ttl_seconds: 客户端空闲回收阈值(秒),超时且无活跃请求时回收 - client_idle_ttl_seconds: 900 - # Concurrency slot expiration time (minutes) - # 并发槽位过期时间(分钟) - concurrency_slot_ttl_minutes: 30 - # Stream data interval timeout (seconds), 0=disable - # 流数据间隔超时(秒),0=禁用 - stream_data_interval_timeout: 180 - # Stream keepalive interval (seconds), 0=disable - # 流式 keepalive 间隔(秒),0=禁用 - stream_keepalive_interval: 10 - # SSE max line size in bytes (default: 40MB) - # SSE 单行最大字节数(默认 40MB) - max_line_size: 41943040 - # Log upstream error response body summary (safe/truncated; does not log request content) - # 记录上游错误响应体摘要(安全/截断;不记录请求内容) - log_upstream_error_body: true - # Max bytes to log from upstream error body - # 记录上游错误响应体的最大字节数 - log_upstream_error_body_max_bytes: 2048 - # Auto inject anthropic-beta header for API-key accounts when needed (default: off) - # 需要时自动为 API-key 账户注入 anthropic-beta 头(默认:关闭) - inject_beta_for_apikey: false - # Allow failover on selected 400 errors (default: off) - # 允许在特定 400 错误时进行故障转移(默认:关闭) - failover_on_400: false - -# ============================================================================= -# API Key Auth Cache Configuration -# API Key 认证缓存配置 -# ============================================================================= -api_key_auth_cache: - # L1 cache size (entries), in-process LRU/TTL cache - # L1 缓存容量(条目数),进程内 LRU/TTL 缓存 - l1_size: 65535 - # L1 cache TTL (seconds) - # L1 缓存 TTL(秒) - l1_ttl_seconds: 15 - # L2 cache TTL (seconds), stored in Redis - # L2 缓存 TTL(秒),Redis 中存储 - l2_ttl_seconds: 300 - # Negative cache TTL (seconds) - # 负缓存 TTL(秒) - negative_ttl_seconds: 30 - # TTL jitter percent (0-100) - # TTL 抖动百分比(0-100) - jitter_percent: 10 - # Enable singleflight for cache misses - # 缓存未命中时启用 singleflight 合并回源 - singleflight: true - -# ============================================================================= -# Dashboard Cache Configuration -# 仪表盘缓存配置 -# ============================================================================= -dashboard_cache: - # Enable dashboard cache - # 启用仪表盘缓存 - enabled: true - # Redis key prefix for multi-environment isolation - # Redis key 前缀,用于多环境隔离 - key_prefix: "sub2api:" - # Fresh TTL (seconds); within this window cached stats are considered fresh - # 新鲜阈值(秒);命中后处于该窗口视为新鲜数据 - stats_fresh_ttl_seconds: 15 - # Cache TTL (seconds) stored in Redis - # Redis 缓存 TTL(秒) - stats_ttl_seconds: 30 - # Async refresh timeout (seconds) - # 异步刷新超时(秒) - stats_refresh_timeout_seconds: 30 - -# ============================================================================= -# Dashboard Aggregation Configuration -# 仪表盘预聚合配置(重启生效) -# ============================================================================= -dashboard_aggregation: - # Enable aggregation job - # 启用聚合作业 - enabled: true - # Refresh interval (seconds) - # 刷新间隔(秒) - interval_seconds: 60 - # Lookback window (seconds) for late-arriving data - # 回看窗口(秒),处理迟到数据 - lookback_seconds: 120 - # Allow manual backfill - # 允许手动回填 - backfill_enabled: false - # Backfill max range (days) - # 回填最大跨度(天) - backfill_max_days: 31 - # Recompute recent N days on startup - # 启动时重算最近 N 天 - recompute_days: 2 - # Retention windows (days) - # 保留窗口(天) - retention: - # Raw usage_logs retention - # 原始 usage_logs 保留天数 - usage_logs_days: 90 - # Hourly aggregation retention - # 小时聚合保留天数 - hourly_days: 180 - # Daily aggregation retention - # 日聚合保留天数 - daily_days: 730 - -# ============================================================================= -# Usage Cleanup Task Configuration -# 使用记录清理任务配置(重启生效) -# ============================================================================= -usage_cleanup: - # Enable cleanup task worker - # 启用清理任务执行器 - enabled: true - # Max date range (days) per task - # 单次任务最大时间跨度(天) - max_range_days: 31 - # Batch delete size - # 单批删除数量 - batch_size: 5000 - # Worker interval (seconds) - # 执行器轮询间隔(秒) - worker_interval_seconds: 10 - # Task execution timeout (seconds) - # 单次任务最大执行时长(秒) - task_timeout_seconds: 1800 - -# ============================================================================= -# Concurrency Wait Configuration -# 并发等待配置 -# ============================================================================= -concurrency: - # SSE ping interval during concurrency wait (seconds) - # 并发等待期间的 SSE ping 间隔(秒) - ping_interval: 10 - -# ============================================================================= -# Database Configuration (PostgreSQL) -# 数据库配置 (PostgreSQL) -# ============================================================================= -database: - # Database host address - # 数据库主机地址 - host: "localhost" - # Database port - # 数据库端口 - port: 5432 - # Database username - # 数据库用户名 - user: "postgres" - # Database password - # 数据库密码 - password: "your_secure_password_here" - # Database name - # 数据库名称 - dbname: "sub2api" - # SSL mode: disable, require, verify-ca, verify-full - # SSL 模式:disable(禁用), require(要求), verify-ca(验证CA), verify-full(完全验证) - sslmode: "disable" - -# ============================================================================= -# Redis Configuration -# Redis 配置 -# ============================================================================= -redis: - # Redis host address - # Redis 主机地址 - host: "localhost" - # Redis port - # Redis 端口 - port: 6379 - # Redis password (leave empty if no password is set) - # Redis 密码(如果未设置密码则留空) - password: "" - # Database number (0-15) - # 数据库编号(0-15) - db: 0 - # Enable TLS/SSL connection - # 是否启用 TLS/SSL 连接 - enable_tls: false - -# ============================================================================= -# Ops Monitoring (Optional) -# 运维监控 (可选) -# ============================================================================= -ops: - # Hard switch: disable all ops background jobs and APIs when false - # 硬开关:为 false 时禁用所有 Ops 后台任务与接口 - enabled: true - - # Prefer pre-aggregated tables (ops_metrics_hourly/ops_metrics_daily) for long-window dashboard queries. - # 优先使用预聚合表(用于长时间窗口查询性能) - use_preaggregated_tables: false - - # Data cleanup configuration - # 数据清理配置(vNext 默认统一保留 30 天) - cleanup: - enabled: true - # Cron expression (minute hour dom month dow), e.g. "0 2 * * *" = daily at 2 AM - # Cron 表达式(分 时 日 月 周),例如 "0 2 * * *" = 每天凌晨 2 点 - schedule: "0 2 * * *" - error_log_retention_days: 30 - minute_metrics_retention_days: 30 - hourly_metrics_retention_days: 30 - - # Pre-aggregation configuration - # 预聚合任务配置 - aggregation: - enabled: true - - # OpsMetricsCollector Redis cache (reduces duplicate expensive window aggregation in multi-replica deployments) - # 指标采集 Redis 缓存(多副本部署时减少重复计算) - metrics_collector_cache: - enabled: true - ttl: 65s - -# ============================================================================= -# JWT Configuration -# JWT 配置 -# ============================================================================= -jwt: - # IMPORTANT: Change this to a random string in production! - # 重要:生产环境中请更改为随机字符串! - # Generate with / 生成命令: openssl rand -hex 32 - secret: "change-this-to-a-secure-random-string" - # Token expiration time in hours (max 24) - # 令牌过期时间(小时,最大 24) - expire_hour: 24 - -# ============================================================================= -# Default Settings -# 默认设置 -# ============================================================================= -default: - # Initial admin account (created on first run) - # 初始管理员账户(首次运行时创建) - admin_email: "admin@example.com" - admin_password: "admin123" - - # Default settings for new users - # 新用户默认设置 - # Max concurrent requests per user - # 每用户最大并发请求数 - user_concurrency: 5 - # Initial balance for new users - # 新用户初始余额 - user_balance: 0 - - # API key settings - # API 密钥设置 - # Prefix for generated API keys - # 生成的 API 密钥前缀 - api_key_prefix: "sk-" - - # Rate multiplier (affects billing calculation) - # 费率倍数(影响计费计算) - rate_multiplier: 1.0 - -# ============================================================================= -# Rate Limiting -# 速率限制 -# ============================================================================= -rate_limit: - # Cooldown time (in minutes) when upstream returns 529 (overloaded) - # 上游返回 529(过载)时的冷却时间(分钟) - overload_cooldown_minutes: 10 - -# ============================================================================= -# Pricing Data Source (Optional) -# 定价数据源(可选) -# ============================================================================= -pricing: - # URL to fetch model pricing data (default: LiteLLM) - # 获取模型定价数据的 URL(默认:LiteLLM) - remote_url: "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json" - # Hash verification URL (optional) - # 哈希校验 URL(可选) - hash_url: "" - # Local data directory for caching - # 本地数据缓存目录 - data_dir: "./data" - # Fallback pricing file - # 备用定价文件 - fallback_file: "./resources/model-pricing/model_prices_and_context_window.json" - # Update interval in hours - # 更新间隔(小时) - update_interval_hours: 24 - # Hash check interval in minutes - # 哈希检查间隔(分钟) - hash_check_interval_minutes: 10 - -# ============================================================================= -# Billing Configuration -# 计费配置 -# ============================================================================= -billing: - circuit_breaker: - # Enable circuit breaker for billing service - # 启用计费服务熔断器 - enabled: true - # Number of failures before opening circuit - # 触发熔断的失败次数阈值 - failure_threshold: 5 - # Time to wait before attempting reset (seconds) - # 熔断后重试等待时间(秒) - reset_timeout_seconds: 30 - # Number of requests to allow in half-open state - # 半开状态允许通过的请求数 - half_open_requests: 3 - -# ============================================================================= -# Turnstile Configuration -# Turnstile 人机验证配置 -# ============================================================================= -turnstile: - # Require Turnstile in release mode (when enabled, login/register will fail if not configured) - # 在 release 模式下要求 Turnstile 验证(启用后,若未配置则登录/注册会失败) - required: false - -# ============================================================================= -# Gemini OAuth (Required for Gemini accounts) -# Gemini OAuth 配置(Gemini 账户必需) -# ============================================================================= -# Sub2API supports TWO Gemini OAuth modes: -# Sub2API 支持两种 Gemini OAuth 模式: -# -# 1. Code Assist OAuth (requires GCP project_id) -# 1. Code Assist OAuth(需要 GCP project_id) -# - Uses: cloudcode-pa.googleapis.com (Code Assist API) -# - 使用:cloudcode-pa.googleapis.com(Code Assist API) -# -# 2. AI Studio OAuth (no project_id needed) -# 2. AI Studio OAuth(不需要 project_id) -# - Uses: generativelanguage.googleapis.com (AI Studio API) -# - 使用:generativelanguage.googleapis.com(AI Studio API) -# -# Default: Uses Gemini CLI's public OAuth credentials (same as Google's official CLI tool) -# 默认:使用 Gemini CLI 的公开 OAuth 凭证(与 Google 官方 CLI 工具相同) -gemini: - oauth: - # Gemini CLI public OAuth credentials (works for both Code Assist and AI Studio) - # Gemini CLI 公开 OAuth 凭证(适用于 Code Assist 和 AI Studio) - client_id: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com" - client_secret: "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" - # Optional scopes (space-separated). Leave empty to auto-select based on oauth_type. - # 可选的权限范围(空格分隔)。留空则根据 oauth_type 自动选择。 - scopes: "" - quota: - # Optional: local quota simulation for Gemini Code Assist (local billing). - # 可选:Gemini Code Assist 本地配额模拟(本地计费)。 - # These values are used for UI progress + precheck scheduling, not official Google quotas. - # 这些值用于 UI 进度显示和预检调度,并非 Google 官方配额。 - tiers: - LEGACY: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 50 - # Flash model requests per day - # Flash 模型每日请求数 - flash_rpd: 1500 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 30 - PRO: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 1500 - # Flash model requests per day - # Flash 模型每日请求数 - flash_rpd: 4000 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 5 - ULTRA: - # Pro model requests per day - # Pro 模型每日请求数 - pro_rpd: 2000 - # Flash model requests per day (0 = unlimited) - # Flash 模型每日请求数(0 = 无限制) - flash_rpd: 0 - # Cooldown time (minutes) after hitting quota - # 达到配额后的冷却时间(分钟) - cooldown_minutes: 5 From 979114db45fe817a2741d63f971c2287533b0bc0 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 5 Feb 2026 13:57:02 +0800 Subject: [PATCH 08/14] =?UTF-8?q?fix(gemini):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=B7=B2=E6=B3=A8=E5=86=8C=E7=94=A8=E6=88=B7=20OAuth=20?= =?UTF-8?q?=E6=8E=88=E6=9D=83=E6=97=B6=E9=94=99=E8=AF=AF=E8=B0=83=E7=94=A8?= =?UTF-8?q?=20onboardUser=20=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题:Google One Ultra 等已注册用户在 OAuth 授权时,如果 LoadCodeAssist 返回了 currentTier/paidTier 但没有返回 cloudaicompanionProject,之前的 逻辑会继续调用 onboardUser,导致 INVALID_ARGUMENT 错误。 修复:对齐 Gemini CLI 的处理逻辑: - 当检测到用户已注册(有 currentTier/paidTier)时,不再调用 onboardUser - 先尝试从 Cloud Resource Manager 获取可用项目 - 如果仍无法获取,返回友好的错误提示,引导用户手动填写 Project ID 这个修复解决了 Google One 订阅用户无法正常授权的问题。 --- .../internal/service/gemini_oauth_service.go | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index bc84baeb..fd2932e6 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } + // 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。 + // 当 LoadCodeAssist 返回了 currentTier / paidTier(表示账号已注册)但没有返回 cloudaicompanionProject 时: + // - 不要再调用 onboardUser(通常不会再分配 project_id,且可能触发 INVALID_ARGUMENT) + // - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id + if loadResp != nil { + registeredTierID := strings.TrimSpace(loadResp.GetTier()) + if registeredTierID != "" { + // 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。 + log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) + + // Try to get project from Cloud Resource Manager + fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) + if fbErr == nil && strings.TrimSpace(fallback) != "" { + log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) + return strings.TrimSpace(fallback), tierID, nil + } + + // No project found - user must provide project_id manually + log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") + return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID) + } + } + + // 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser + log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) + req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ From 2b192f7dcab70187999c0e04743ff33e024e37f7 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 16:00:34 +0800 Subject: [PATCH 09/14] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E4=B8=93=E5=B1=9E=E5=88=86=E7=BB=84=E5=80=8D=E7=8E=87?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 7 +- backend/go.sum | 8 + .../internal/handler/admin/user_handler.go | 4 + backend/internal/handler/api_key_handler.go | 18 + backend/internal/handler/dto/mappers.go | 5 +- backend/internal/handler/dto/types.go | 3 + .../repository/user_group_rate_repo.go | 113 ++++++ backend/internal/repository/wire.go | 1 + backend/internal/server/api_contract_test.go | 4 +- .../middleware/api_key_auth_google_test.go | 2 + .../server/middleware/api_key_auth_test.go | 8 +- backend/internal/server/routes/user.go | 1 + backend/internal/service/admin_service.go | 41 ++- backend/internal/service/api_key_service.go | 46 ++- .../service/api_key_service_cache_test.go | 18 +- backend/internal/service/gateway_service.go | 21 +- backend/internal/service/user.go | 26 +- backend/internal/service/user_group_rate.go | 25 ++ .../047_add_user_group_rate_multipliers.sql | 19 + frontend/src/api/groups.ts | 12 +- .../admin/user/UserAllowedGroupsModal.vue | 337 ++++++++++++++++-- frontend/src/components/common/GroupBadge.vue | 29 +- .../src/components/common/GroupOptionItem.vue | 5 +- frontend/src/i18n/locales/en.ts | 10 + frontend/src/i18n/locales/zh.ts | 10 + frontend/src/types/index.ts | 5 + frontend/src/views/user/KeysView.vue | 16 + 27 files changed, 705 insertions(+), 89 deletions(-) create mode 100644 backend/internal/repository/user_group_rate_repo.go create mode 100644 backend/internal/service/user_group_rate.go create mode 100644 backend/migrations/047_add_user_group_rate_multipliers.sql diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 47b1e8ac..3ca86f91 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -59,8 +59,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) + userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) - apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) @@ -100,7 +101,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -153,7 +154,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/go.sum b/backend/go.sum index 171995c7..3000eb38 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -170,6 +170,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -203,6 +205,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -230,6 +234,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -252,6 +258,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index ac76689d..1c772e7d 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -45,6 +45,9 @@ type UpdateUserRequest struct { Concurrency *int `json:"concurrency"` Status string `json:"status" binding:"omitempty,oneof=active disabled"` AllowedGroups *[]int64 `json:"allowed_groups"` + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 `json:"group_rates"` } // UpdateBalanceRequest represents balance update request @@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) { Concurrency: req.Concurrency, Status: req.Status, AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 9717194b..f1a18ad2 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { } response.Success(c, out) } + +// GetUserGroupRates 获取当前用户的专属分组倍率配置 +// GET /api/v1/groups/rates +func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, rates) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 4f8d1eeb..da0e9fc6 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 8e6faf02..71bb1ed4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -29,6 +29,9 @@ type AdminUser struct { User Notes string `json:"notes"` + // GroupRates 用户专属分组倍率配置 + // map[groupID]rateMultiplier + GroupRates map[int64]float64 `json:"group_rates,omitempty"` } type APIKey struct { diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go new file mode 100644 index 00000000..eb65403b --- /dev/null +++ b/backend/internal/repository/user_group_rate_repo.go @@ -0,0 +1,113 @@ +package repository + +import ( + "context" + "database/sql" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type userGroupRateRepository struct { + sql sqlExecutor +} + +// NewUserGroupRateRepository 创建用户专属分组倍率仓储 +func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { + return &userGroupRateRepository{sql: sqlDB} +} + +// GetByUserID 获取用户的所有专属分组倍率 +func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { + query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` + rows, err := r.sql.QueryContext(ctx, query, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := make(map[int64]float64) + for rows.Next() { + var groupID int64 + var rate float64 + if err := rows.Scan(&groupID, &rate); err != nil { + return nil, err + } + result[groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +// GetByUserAndGroup 获取用户在特定分组的专属倍率 +func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` + var rate float64 + err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &rate, nil +} + +// SyncUserGroupRates 同步用户的分组专属倍率 +func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { + if len(rates) == 0 { + // 如果传入空 map,删除该用户的所有专属倍率 + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + return err + } + + // 分离需要删除和需要 upsert 的记录 + var toDelete []int64 + toUpsert := make(map[int64]float64) + for groupID, rate := range rates { + if rate == nil { + toDelete = append(toDelete, groupID) + } else { + toUpsert[groupID] = *rate + } + } + + // 删除指定的记录 + for _, groupID := range toDelete { + _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, + userID, groupID) + if err != nil { + return err + } + } + + // Upsert 记录 + now := time.Now() + for groupID, rate := range toUpsert { + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) + VALUES ($1, $2, $3, $4, $4) + ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 + `, userID, groupID, rate, now) + if err != nil { + return err + } + } + + return nil +} + +// DeleteByGroupID 删除指定分组的所有用户专属倍率 +func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) + return err +} + +// DeleteByUserID 删除指定用户的所有专属倍率 +func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + return err +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 857ce3e8..5437de35 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -66,6 +66,7 @@ var ProviderSet = wire.NewSet( NewUserSubscriptionRepository, NewUserAttributeDefinitionRepository, NewUserAttributeValueRepository, + NewUserGroupRateRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index e197b776..f5f8cda7 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps { } userService := service.NewUserService(userRepo, nil) - apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) @@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index c14582bd..38b93cb2 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService nil, // userRepo (unused in GetByKey) nil, // groupRepo nil, // userSubRepo + nil, // userGroupRateRepo nil, // cache &config.Config{}, ) @@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { nil, nil, nil, + nil, &config.Config{RunMode: config.RunModeSimple}, ) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index a03f6168..9d514818 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) @@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeStandard} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) now := time.Now() sub := &service.UserSubscription{ @@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.GET("/t", func(c *gin.Context) { @@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index 5581e1e1..d0ed2489 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -49,6 +49,7 @@ func RegisterUserRoutes( groups := authenticated.Group("/groups") { groups.GET("/available", h.APIKey.GetAvailableGroups) + groups.GET("/rates", h.APIKey.GetUserGroupRates) } // 使用记录 diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index c512f235..f215f82e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -93,6 +93,9 @@ type UpdateUserInput struct { Concurrency *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 } type CreateGroupInput struct { @@ -293,6 +296,7 @@ type adminServiceImpl struct { proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository + userGroupRateRepo UserGroupRateRepository billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache @@ -307,6 +311,7 @@ func NewAdminService( proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, + userGroupRateRepo UserGroupRateRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, @@ -319,6 +324,7 @@ func NewAdminService( proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, + userGroupRateRepo: userGroupRateRepo, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, @@ -333,11 +339,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi if err != nil { return nil, 0, err } + // 批量加载用户专属分组倍率 + if s.userGroupRateRepo != nil && len(users) > 0 { + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } + } return users, result.Total, nil } func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { - return s.userRepo.GetByID(ctx, id) + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + // 加载用户专属分组倍率 + if s.userGroupRateRepo != nil { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) + if err != nil { + log.Printf("failed to load user group rates: user_id=%d err=%v", id, err) + } else { + user.GroupRates = rates + } + } + return user, nil } func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { @@ -406,6 +436,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + + // 同步用户专属分组倍率 + if input.GroupRates != nil && s.userGroupRateRepo != nil { + if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { + log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err) + } + } + if s.authCacheInvalidator != nil { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) @@ -941,6 +979,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { if err != nil { return err } + // 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理 // 事务成功后,异步失效受影响用户的订阅缓存 if len(affectedUserIDs) > 0 && s.billingCacheService != nil { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index b27682f3..cb1dd60a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct { // APIKeyService API Key服务 type APIKeyService struct { - apiKeyRepo APIKeyRepository - userRepo UserRepository - groupRepo GroupRepository - userSubRepo UserSubscriptionRepository - cache APIKeyCache - cfg *config.Config - authCacheL1 *ristretto.Cache - authCfg apiKeyAuthCacheConfig - authGroup singleflight.Group + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -132,16 +133,18 @@ func NewAPIKeyService( userRepo UserRepository, groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache APIKeyCache, cfg *config.Config, ) *APIKeyService { svc := &APIKeyService{ - apiKeyRepo: apiKeyRepo, - userRepo: userRepo, - groupRepo: groupRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, + apiKeyRepo: apiKeyRepo, + userRepo: userRepo, + groupRepo: groupRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + cfg: cfg, } svc.initAuthCache(cfg) return svc @@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword return keys, nil } +// GetUserGroupRates 获取用户的专属分组倍率配置 +// 返回 map[groupID]rateMultiplier +func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user group rates: %w", err) + } + return rates, nil +} + // CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) // Returns nil if valid, error if invalid func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 1099b1d2..14ecbf39 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) groupID := int64(9) cacheEntry := &APIKeyAuthCacheEntry{ @@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return &APIKeyAuthCacheEntry{NotFound: true}, nil } @@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return nil, redis.Nil } @@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { L1TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) require.NotNil(t, svc.authCacheL1) _, err := svc.GetByKey(context.Background(), "k-l1") @@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByUserID(context.Background(), 7) require.Len(t, cache.deleteAuthKeys, 2) @@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { L2TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByGroupID(context.Background(), 9) require.Len(t, cache.deleteAuthKeys, 2) @@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { L2TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByKey(context.Background(), "k1") require.Len(t, cache.deleteAuthKeys, 1) @@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return nil, redis.Nil } @@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { Singleflight: true, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) start := make(chan struct{}) wg := sync.WaitGroup{} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8c88c0a9..9036955a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -384,6 +384,7 @@ type GatewayService struct { usageLogRepo UsageLogRepository userRepo UserRepository userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository cache GatewayCache cfg *config.Config schedulerSnapshot *SchedulerSnapshotService @@ -405,6 +406,7 @@ func NewGatewayService( usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, @@ -424,6 +426,7 @@ func NewGatewayService( usageLogRepo: usageLogRepo, userRepo: userRepo, userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, cache: cache, cfg: cfg, schedulerSnapshot: schedulerSnapshot, @@ -4609,10 +4612,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu account := input.Account subscription := input.Subscription - // 获取费率倍数 + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier + + // 检查用户专属倍率 + if s.userGroupRateRepo != nil { + if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { + multiplier = *userRate + } + } } var cost *CostBreakdown @@ -4773,10 +4783,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * account := input.Account subscription := input.Subscription - // 获取费率倍数 + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier + + // 检查用户专属倍率 + if s.userGroupRateRepo != nil { + if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { + multiplier = *userRate + } + } } var cost *CostBreakdown diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 0f589eb3..e56d83bf 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -21,6 +21,10 @@ type User struct { CreatedAt time.Time UpdatedAt time.Time + // GroupRates 用户专属分组倍率配置 + // map[groupID]rateMultiplier + GroupRates map[int64]float64 + // TOTP 双因素认证字段 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpEnabled bool // 是否启用 TOTP @@ -40,18 +44,20 @@ func (u *User) IsActive() bool { // CanBindGroup checks whether a user can bind to a given group. // For standard groups: -// - If AllowedGroups is non-empty, only allow binding to IDs in that list. -// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group. +// - Public groups (non-exclusive): all users can bind +// - Exclusive groups: only users with the group in AllowedGroups can bind func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { - if len(u.AllowedGroups) > 0 { - for _, id := range u.AllowedGroups { - if id == groupID { - return true - } - } - return false + // 公开分组(非专属):所有用户都可以绑定 + if !isExclusive { + return true } - return !isExclusive + // 专属分组:需要在 AllowedGroups 中 + for _, id := range u.AllowedGroups { + if id == groupID { + return true + } + } + return false } func (u *User) SetPassword(password string) error { diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go new file mode 100644 index 00000000..9eb5f067 --- /dev/null +++ b/backend/internal/service/user_group_rate.go @@ -0,0 +1,25 @@ +package service + +import "context" + +// UserGroupRateRepository 用户专属分组倍率仓储接口 +// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +type UserGroupRateRepository interface { + // GetByUserID 获取用户的所有专属分组倍率 + // 返回 map[groupID]rateMultiplier + GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) + + // GetByUserAndGroup 获取用户在特定分组的专属倍率 + // 如果未设置专属倍率,返回 nil + GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) + + // SyncUserGroupRates 同步用户的分组专属倍率 + // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 + SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error + + // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) + DeleteByGroupID(ctx context.Context, groupID int64) error + + // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用) + DeleteByUserID(ctx context.Context, userID int64) error +} diff --git a/backend/migrations/047_add_user_group_rate_multipliers.sql b/backend/migrations/047_add_user_group_rate_multipliers.sql new file mode 100644 index 00000000..a37d5bcd --- /dev/null +++ b/backend/migrations/047_add_user_group_rate_multipliers.sql @@ -0,0 +1,19 @@ +-- 用户专属分组倍率表 +-- 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +CREATE TABLE IF NOT EXISTS user_group_rate_multipliers ( + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE, + rate_multiplier DECIMAL(10,4) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (user_id, group_id) +); + +-- 按 group_id 查询索引(删除分组时清理关联记录) +CREATE INDEX IF NOT EXISTS idx_user_group_rate_multipliers_group_id + ON user_group_rate_multipliers(group_id); + +COMMENT ON TABLE user_group_rate_multipliers IS '用户专属分组倍率配置'; +COMMENT ON COLUMN user_group_rate_multipliers.user_id IS '用户ID'; +COMMENT ON COLUMN user_group_rate_multipliers.group_id IS '分组ID'; +COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率(覆盖分组默认倍率)'; diff --git a/frontend/src/api/groups.ts b/frontend/src/api/groups.ts index 0f366d51..0963a7a6 100644 --- a/frontend/src/api/groups.ts +++ b/frontend/src/api/groups.ts @@ -18,8 +18,18 @@ export async function getAvailable(): Promise { return data } +/** + * Get current user's custom group rate multipliers + * @returns Map of group_id to custom rate_multiplier + */ +export async function getUserGroupRates(): Promise> { + const { data } = await apiClient.get | null>('/groups/rates') + return data || {} +} + export const userGroupsAPI = { - getAvailable + getAvailable, + getUserGroupRates } export default userGroupsAPI diff --git a/frontend/src/components/admin/user/UserAllowedGroupsModal.vue b/frontend/src/components/admin/user/UserAllowedGroupsModal.vue index 825d2be5..bccc22c7 100644 --- a/frontend/src/components/admin/user/UserAllowedGroupsModal.vue +++ b/frontend/src/components/admin/user/UserAllowedGroupsModal.vue @@ -1,59 +1,328 @@ + + diff --git a/frontend/src/components/common/GroupBadge.vue b/frontend/src/components/common/GroupBadge.vue index 239d0452..83f4b8aa 100644 --- a/frontend/src/components/common/GroupBadge.vue +++ b/frontend/src/components/common/GroupBadge.vue @@ -11,7 +11,14 @@ {{ name }} - {{ labelText }} + + @@ -27,6 +34,7 @@ interface Props { platform?: GroupPlatform subscriptionType?: SubscriptionType rateMultiplier?: number + userRateMultiplier?: number | null // 用户专属倍率 showRate?: boolean daysRemaining?: number | null // 剩余天数(订阅类型时使用) } @@ -34,20 +42,31 @@ interface Props { const props = withDefaults(defineProps(), { subscriptionType: 'standard', showRate: true, - daysRemaining: null + daysRemaining: null, + userRateMultiplier: null }) const { t } = useI18n() const isSubscription = computed(() => props.subscriptionType === 'subscription') +// 是否有专属倍率(且与默认倍率不同) +const hasCustomRate = computed(() => { + return ( + props.userRateMultiplier !== null && + props.userRateMultiplier !== undefined && + props.rateMultiplier !== undefined && + props.userRateMultiplier !== props.rateMultiplier + ) +}) + // 是否显示右侧标签 const showLabel = computed(() => { if (!props.showRate) return false // 订阅类型:显示天数或"订阅" if (isSubscription.value) return true - // 标准类型:显示倍率 - return props.rateMultiplier !== undefined + // 标准类型:显示倍率(包括专属倍率) + return props.rateMultiplier !== undefined || hasCustomRate.value }) // Label text @@ -71,7 +90,7 @@ const labelClass = computed(() => { const base = 'px-1.5 py-0.5 rounded text-[10px] font-semibold' if (!isSubscription.value) { - // Standard: subtle background + // Standard: subtle background (不再为专属倍率使用不同的背景色) return `${base} bg-black/10 dark:bg-white/10` } diff --git a/frontend/src/components/common/GroupOptionItem.vue b/frontend/src/components/common/GroupOptionItem.vue index 3283c330..44750350 100644 --- a/frontend/src/components/common/GroupOptionItem.vue +++ b/frontend/src/components/common/GroupOptionItem.vue @@ -9,6 +9,7 @@ :platform="platform" :subscription-type="subscriptionType" :rate-multiplier="rateMultiplier" + :user-rate-multiplier="userRateMultiplier" /> (), { subscriptionType: 'standard', selected: false, - showCheckmark: true + showCheckmark: true, + userRateMultiplier: null }) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index fb255c1a..a4571b10 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -849,6 +849,16 @@ export default { allowedGroupsUpdated: 'Allowed groups updated successfully', failedToLoadGroups: 'Failed to load groups', failedToUpdateAllowedGroups: 'Failed to update allowed groups', + // User Group Configuration + groupConfig: 'User Group Configuration', + groupConfigHint: 'Configure custom rate multipliers for user {email} (overrides group defaults)', + exclusiveGroups: 'Exclusive Groups', + publicGroups: 'Public Groups (Default Available)', + defaultRate: 'Default Rate', + customRate: 'Custom Rate', + useDefaultRate: 'Use Default', + customRatePlaceholder: 'Leave empty for default', + groupConfigUpdated: 'Group configuration updated successfully', deposit: 'Deposit', withdraw: 'Withdraw', depositAmount: 'Deposit Amount', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index e964aae2..8c6b1d91 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -910,6 +910,16 @@ export default { allowedGroupsUpdated: '允许分组更新成功', failedToLoadGroups: '加载分组列表失败', failedToUpdateAllowedGroups: '更新允许分组失败', + // 用户分组配置 + groupConfig: '用户分组配置', + groupConfigHint: '为用户 {email} 配置专属分组倍率(覆盖分组默认倍率)', + exclusiveGroups: '专属分组', + publicGroups: '公开分组(默认可用)', + defaultRate: '默认倍率', + customRate: '专属倍率', + useDefaultRate: '使用默认', + customRatePlaceholder: '留空使用默认', + groupConfigUpdated: '分组配置更新成功', deposit: '充值', withdraw: '退款', depositAmount: '充值金额', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eb53de44..a87ae4ca 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -41,6 +41,8 @@ export interface User { export interface AdminUser extends User { // 管理员备注(普通用户接口不返回) notes: string + // 用户专属分组倍率配置 (group_id -> rate_multiplier) + group_rates?: Record } export interface LoginRequest { @@ -966,6 +968,9 @@ export interface UpdateUserRequest { concurrency?: number status?: 'active' | 'disabled' allowed_groups?: number[] | null + // 用户专属分组倍率配置 (group_id -> rate_multiplier | null) + // null 表示删除该分组的专属倍率 + group_rates?: Record } export interface ChangePasswordRequest { diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 51b015fa..80a64f2e 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -73,6 +73,7 @@ :platform="row.group.platform" :subscription-type="row.group.subscription_type" :rate-multiplier="row.group.rate_multiplier" + :user-rate-multiplier="userGroupRates[row.group.id]" /> {{ t('keys.noGroup') @@ -272,6 +273,7 @@ :platform="(option as unknown as GroupOption).platform" :subscription-type="(option as unknown as GroupOption).subscriptionType" :rate-multiplier="(option as unknown as GroupOption).rate" + :user-rate-multiplier="(option as unknown as GroupOption).userRate" /> {{ t('keys.selectGroup') }} @@ -281,6 +283,7 @@ :platform="(option as unknown as GroupOption).platform" :subscription-type="(option as unknown as GroupOption).subscriptionType" :rate-multiplier="(option as unknown as GroupOption).rate" + :user-rate-multiplier="(option as unknown as GroupOption).userRate" :description="(option as unknown as GroupOption).description" :selected="selected" /> @@ -667,6 +670,7 @@ :platform="option.platform" :subscription-type="option.subscriptionType" :rate-multiplier="option.rate" + :user-rate-multiplier="option.userRate" :description="option.description" :selected=" selectedKeyForGroup?.group_id === option.value || @@ -718,6 +722,7 @@ interface GroupOption { label: string description: string | null rate: number + userRate: number | null subscriptionType: SubscriptionType platform: GroupPlatform } @@ -742,6 +747,7 @@ const groups = ref([]) const loading = ref(false) const submitting = ref(false) const usageStats = ref>({}) +const userGroupRates = ref>({}) const pagination = ref({ page: 1, @@ -825,6 +831,7 @@ const groupOptions = computed(() => label: group.name, description: group.description, rate: group.rate_multiplier, + userRate: userGroupRates.value[group.id] ?? null, subscriptionType: group.subscription_type, platform: group.platform })) @@ -899,6 +906,14 @@ const loadGroups = async () => { } } +const loadUserGroupRates = async () => { + try { + userGroupRates.value = await userGroupsAPI.getUserGroupRates() + } catch (error) { + console.error('Failed to load user group rates:', error) + } +} + const loadPublicSettings = async () => { try { publicSettings.value = await authAPI.getPublicSettings() @@ -1268,6 +1283,7 @@ const closeCcsClientSelect = () => { onMounted(() => { loadApiKeys() loadGroups() + loadUserGroupRates() loadPublicSettings() document.addEventListener('click', closeGroupSelector) }) From 1d8b686446cc5374ddb1b192116a9afd8395b197 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 16:17:11 +0800 Subject: [PATCH 10/14] =?UTF-8?q?chore:=20=E7=A7=BB=E9=99=A4=E6=97=A0?= =?UTF-8?q?=E5=85=B3=E7=9A=84md=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PR_DESCRIPTION.md | 164 ---------------------------------------------- 1 file changed, 164 deletions(-) delete mode 100644 PR_DESCRIPTION.md diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md deleted file mode 100644 index b240f45c..00000000 --- a/PR_DESCRIPTION.md +++ /dev/null @@ -1,164 +0,0 @@ -## 概述 - -全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。 - -## 主要改动 - -### 1. 错误日志查询优化 - -**功能特性:** -- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情 -- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等) -- 改进查询参数处理,简化代码结构 -- 增强错误分类和标准化处理 -- 支持错误解决状态追踪(resolved 字段) - -**技术实现:** -- `ops_handler.go` - 新增单条错误日志查询接口 -- `ops_repo.go` - 优化数据查询和过滤条件构建 -- `ops_models.go` - 扩展错误日志数据模型 -- 前端 API 接口同步更新 - -### 2. 告警静默功能 - -**功能特性:** -- 支持按规则、平台、分组、区域等维度静默告警 -- 可设置静默时长和原因说明 -- 静默记录可追溯,记录创建人和创建时间 -- 自动过期机制,避免永久静默 - -**技术实现:** -- `037_ops_alert_silences.sql` - 新增告警静默表 -- `ops_alerts.go` - 告警静默逻辑实现 -- `ops_alerts_handler.go` - 告警静默 API 接口 -- `OpsAlertEventsCard.vue` - 前端告警静默操作界面 - -**数据库结构:** - -| 字段 | 类型 | 说明 | -|------|------|------| -| rule_id | BIGINT | 告警规则 ID | -| platform | VARCHAR(64) | 平台标识 | -| group_id | BIGINT | 分组 ID(可选) | -| region | VARCHAR(64) | 区域(可选) | -| until | TIMESTAMPTZ | 静默截止时间 | -| reason | TEXT | 静默原因 | -| created_by | BIGINT | 创建人 ID | - -### 3. 错误分类标准化 - -**功能特性:** -- 统一错误阶段分类(request|auth|routing|upstream|network|internal) -- 规范错误归属分类(client|provider|platform) -- 标准化错误来源分类(client_request|upstream_http|gateway) -- 自动迁移历史数据到新分类体系 - -**技术实现:** -- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移 -- 自动映射历史遗留分类到新标准 -- 自动解决已恢复的上游错误(客户端状态码 < 400) - -### 4. Gateway 服务集成 - -**功能特性:** -- 完善各 Gateway 服务的 Ops 集成 -- 统一错误日志记录接口 -- 增强上游错误追踪能力 - -**涉及服务:** -- `antigravity_gateway_service.go` - Antigravity 网关集成 -- `gateway_service.go` - 通用网关集成 -- `gemini_messages_compat_service.go` - Gemini 兼容层集成 -- `openai_gateway_service.go` - OpenAI 网关集成 - -### 5. 前端 UI 优化 - -**代码重构:** -- 大幅简化错误详情模态框代码(从 828 行优化到 450 行) -- 优化错误日志表格组件,提升可读性 -- 清理未使用的 i18n 翻译,减少冗余 -- 统一组件代码风格和格式 -- 优化骨架屏组件,更好匹配实际看板布局 - -**布局改进:** -- 修复模态框内容溢出和滚动问题 -- 优化表格布局,使用 flex 布局确保正确显示 -- 改进看板头部布局和交互 -- 提升响应式体验 -- 骨架屏支持全屏模式适配 - -**交互优化:** -- 优化告警事件卡片功能和展示 -- 改进错误详情展示逻辑 -- 增强请求详情模态框 -- 完善运行时设置卡片 -- 改进加载动画效果 - -### 6. 国际化完善 - -**文案补充:** -- 补充错误日志相关的英文翻译 -- 添加告警静默功能的中英文文案 -- 完善提示文本和错误信息 -- 统一术语翻译标准 - -## 文件变更 - -**后端(26 个文件):** -- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强 -- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化 -- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强 -- `backend/internal/repository/ops_repo.go` - 数据访问层重构 -- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强 -- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件) -- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件) -- `backend/internal/server/routes/admin.go` - 路由配置更新 -- `backend/migrations/*.sql` - 数据库迁移(2 个文件) -- 测试文件更新(5 个文件) - -**前端(13 个文件):** -- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化 -- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件) -- `frontend/src/api/admin/ops.ts` - API 接口扩展 -- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件) - -## 代码统计 - -- 44 个文件修改 -- 3733 行新增 -- 995 行删除 -- 净增加 2738 行 - -## 核心改进 - -**可维护性提升:** -- 重构核心服务层,职责更清晰 -- 简化前端组件代码,降低复杂度 -- 统一代码风格和命名规范 -- 清理冗余代码和未使用的翻译 -- 标准化错误分类体系 - -**功能完善:** -- 告警静默功能,减少告警噪音 -- 错误日志查询优化,提升运维效率 -- Gateway 服务集成完善,统一监控能力 -- 错误解决状态追踪,便于问题管理 - -**用户体验优化:** -- 修复多个 UI 布局问题 -- 优化交互流程 -- 完善国际化支持 -- 提升响应式体验 -- 改进加载状态展示 - -## 测试验证 - -- ✅ 错误日志查询和过滤功能 -- ✅ 告警静默创建和自动过期 -- ✅ 错误分类标准化迁移 -- ✅ Gateway 服务错误日志记录 -- ✅ 前端组件布局和交互 -- ✅ 骨架屏全屏模式适配 -- ✅ 国际化文本完整性 -- ✅ API 接口功能正确性 -- ✅ 数据库迁移执行成功 From d2527e36eb8a6b57bbf1f9c1629da206bc067900 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 5 Feb 2026 20:13:06 +0800 Subject: [PATCH 11/14] =?UTF-8?q?feat(gemini):=20=E5=A2=9E=E5=BC=BA=20API?= =?UTF-8?q?=20=E6=8E=88=E6=9D=83=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E8=87=AA=E5=8A=A8=E6=8F=90=E5=8F=96=E5=B9=B6=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E6=BF=80=E6=B4=BB=20URL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当 Gemini for Google Cloud API 未启用时(SERVICE_DISABLED 错误), 系统现在会: - 自动检测 403 PERMISSION_DENIED 错误 - 从错误响应中提取 API 激活 URL - 向用户显示清晰的错误消息和可点击的激活链接 - 提供操作指引(启用后等待几分钟) 新增文件: - internal/pkg/googleapi/error.go: Google API 错误解析器 - internal/pkg/googleapi/error_test.go: 完整的测试覆盖 - GEMINI_API_ERROR_HANDLING.md: 实现文档 修改文件: - internal/repository/geminicli_codeassist_client.go: 在 LoadCodeAssist 和 OnboardUser 中增强错误处理 这大大改善了用户体验,用户不再需要手动从错误日志中查找激活 URL。 --- backend/internal/pkg/googleapi/error.go | 109 +++++++++++++ backend/internal/pkg/googleapi/error_test.go | 143 ++++++++++++++++++ .../repository/geminicli_codeassist_client.go | 35 ++++- 3 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 backend/internal/pkg/googleapi/error.go create mode 100644 backend/internal/pkg/googleapi/error_test.go diff --git a/backend/internal/pkg/googleapi/error.go b/backend/internal/pkg/googleapi/error.go new file mode 100644 index 00000000..b6374e02 --- /dev/null +++ b/backend/internal/pkg/googleapi/error.go @@ -0,0 +1,109 @@ +// Package googleapi provides helpers for Google-style API responses. +package googleapi + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ErrorResponse represents a Google API error response +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +// ErrorDetail contains the error details from Google API +type ErrorDetail struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + Details []json.RawMessage `json:"details,omitempty"` +} + +// ErrorDetailInfo contains additional error information +type ErrorDetailInfo struct { + Type string `json:"@type"` + Reason string `json:"reason,omitempty"` + Domain string `json:"domain,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// ErrorHelp contains help links +type ErrorHelp struct { + Type string `json:"@type"` + Links []HelpLink `json:"links,omitempty"` +} + +// HelpLink represents a help link +type HelpLink struct { + Description string `json:"description"` + URL string `json:"url"` +} + +// ParseError parses a Google API error response and extracts key information +func ParseError(body string) (*ErrorResponse, error) { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return &errResp, nil +} + +// ExtractActivationURL extracts the API activation URL from error details +func ExtractActivationURL(body string) string { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return "" + } + + // Check error details for activation URL + for _, detailRaw := range errResp.Error.Details { + // Parse as ErrorDetailInfo + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if info.Metadata != nil { + if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" { + return activationURL + } + } + } + + // Parse as ErrorHelp + var help ErrorHelp + if err := json.Unmarshal(detailRaw, &help); err == nil { + for _, link := range help.Links { + if strings.Contains(link.Description, "activation") || + strings.Contains(link.Description, "API activation") || + strings.Contains(link.URL, "/apis/api/") { + return link.URL + } + } + } + } + + return "" +} + +// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error +func IsServiceDisabledError(body string) bool { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return false + } + + // Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason + if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" { + return false + } + + for _, detailRaw := range errResp.Error.Details { + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if info.Reason == "SERVICE_DISABLED" { + return true + } + } + } + + return false +} diff --git a/backend/internal/pkg/googleapi/error_test.go b/backend/internal/pkg/googleapi/error_test.go new file mode 100644 index 00000000..992dcf85 --- /dev/null +++ b/backend/internal/pkg/googleapi/error_test.go @@ -0,0 +1,143 @@ +package googleapi + +import ( + "testing" +) + +func TestExtractActivationURL(t *testing.T) { + // Test case from the user's error message + errorBody := `{ + "error": { + "code": 403, + "message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.", + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "SERVICE_DISABLED", + "domain": "googleapis.com", + "metadata": { + "service": "cloudaicompanion.googleapis.com", + "activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843", + "consumer": "projects/project-6eca5881-ab73-4736-843", + "serviceTitle": "Gemini for Google Cloud API", + "containerInfo": "project-6eca5881-ab73-4736-843" + } + }, + { + "@type": "type.googleapis.com/google.rpc.LocalizedMessage", + "locale": "en-US", + "message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry." + }, + { + "@type": "type.googleapis.com/google.rpc.Help", + "links": [ + { + "description": "Google developers console API activation", + "url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843" + } + ] + } + ] + } + }` + + activationURL := ExtractActivationURL(errorBody) + expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843" + + if activationURL != expectedURL { + t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL) + } +} + +func TestIsServiceDisabledError(t *testing.T) { + tests := []struct { + name string + body string + expected bool + }{ + { + name: "SERVICE_DISABLED error", + body: `{ + "error": { + "code": 403, + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "SERVICE_DISABLED" + } + ] + } + }`, + expected: true, + }, + { + name: "Other 403 error", + body: `{ + "error": { + "code": 403, + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "OTHER_REASON" + } + ] + } + }`, + expected: false, + }, + { + name: "404 error", + body: `{ + "error": { + "code": 404, + "status": "NOT_FOUND" + } + }`, + expected: false, + }, + { + name: "Invalid JSON", + body: `invalid json`, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsServiceDisabledError(tt.body) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestParseError(t *testing.T) { + errorBody := `{ + "error": { + "code": 403, + "message": "API not enabled", + "status": "PERMISSION_DENIED" + } + }` + + errResp, err := ParseError(errorBody) + if err != nil { + t.Fatalf("Failed to parse error: %v", err) + } + + if errResp.Error.Code != 403 { + t.Errorf("Expected code 403, got %d", errResp.Error.Code) + } + + if errResp.Error.Status != "PERMISSION_DENIED" { + t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status) + } + + if errResp.Error.Message != "API not enabled" { + t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message) + } +} diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index d7f54e85..b63be1ad 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/imroc/req/v3" @@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccessState() { - body := geminicli.SanitizeBodyForLogs(resp.String()) - fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body) - return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body) + body := resp.String() + sanitizedBody := geminicli.SanitizeBodyForLogs(body) + fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody) + + // Check if this is a SERVICE_DISABLED error and extract activation URL + if googleapi.IsServiceDisabledError(body) { + activationURL := googleapi.ExtractActivationURL(body) + if activationURL != "" { + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + } + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + } + + return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody) } fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out) return &out, nil @@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccessState() { - body := geminicli.SanitizeBodyForLogs(resp.String()) - fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body) - return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body) + body := resp.String() + sanitizedBody := geminicli.SanitizeBodyForLogs(body) + fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody) + + // Check if this is a SERVICE_DISABLED error and extract activation URL + if googleapi.IsServiceDisabledError(body) { + activationURL := googleapi.ExtractActivationURL(body) + if activationURL != "" { + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + } + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + } + + return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody) } fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out) return &out, nil From 7b46bbb6286afd1f07d9b9fe39e6051ce0e5f0e3 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 5 Feb 2026 20:47:15 +0800 Subject: [PATCH 12/14] =?UTF-8?q?fix(lint):=20=E4=BF=AE=E5=A4=8D=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E6=B6=88=E6=81=AF=E5=A4=A7=E5=86=99=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E4=BB=A5=E7=AC=A6=E5=90=88=20Go=20=E6=83=AF=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/repository/geminicli_codeassist_client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index b63be1ad..4f63280d 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -47,9 +47,9 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo if googleapi.IsServiceDisabledError(body) { activationURL := googleapi.ExtractActivationURL(body) if activationURL != "" { - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) } - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") } return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody) @@ -87,9 +87,9 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken if googleapi.IsServiceDisabledError(body) { activationURL := googleapi.ExtractActivationURL(body) if activationURL != "" { - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) } - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") } return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody) From 39e05a2dad412b161d9feec39cd7e9bbb17e3213 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 21:52:54 +0800 Subject: [PATCH 13/14] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E9=94=99=E8=AF=AF=E9=80=8F=E4=BC=A0=E8=A7=84=E5=88=99?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 支持管理员配置上游错误如何返回给客户端: - 新增 ErrorPassthroughRule 数据模型和 Ent Schema - 实现规则的 CRUD API(/admin/error-passthrough-rules) - 支持按错误码、关键词匹配,支持 any/all 匹配模式 - 支持按平台过滤(anthropic/openai/gemini/antigravity) - 支持透传或自定义响应状态码和错误消息 - 实现两级缓存(Redis + 本地内存)和多实例同步 - 集成到 gateway_handler 的错误处理流程 - 新增前端管理界面组件 - 新增单元测试覆盖核心匹配逻辑 优化: - 移除 refreshLocalCache 中的冗余排序(数据库已排序) - 后端 Validate() 增加匹配条件非空校验 --- backend/cmd/server/wire_gen.go | 10 +- backend/ent/client.go | 171 +- backend/ent/ent.go | 2 + backend/ent/errorpassthroughrule.go | 269 ++++ .../errorpassthroughrule.go | 161 ++ backend/ent/errorpassthroughrule/where.go | 635 ++++++++ backend/ent/errorpassthroughrule_create.go | 1382 +++++++++++++++++ backend/ent/errorpassthroughrule_delete.go | 88 ++ backend/ent/errorpassthroughrule_query.go | 564 +++++++ backend/ent/errorpassthroughrule_update.go | 823 ++++++++++ backend/ent/hook/hook.go | 12 + backend/ent/intercept/intercept.go | 30 + backend/ent/migrate/schema.go | 40 + backend/ent/mutation.go | 1268 +++++++++++++++ backend/ent/predicate/predicate.go | 3 + backend/ent/runtime/runtime.go | 56 + backend/ent/schema/error_passthrough_rule.go | 121 ++ backend/ent/tx.go | 3 + .../admin/error_passthrough_handler.go | 273 ++++ backend/internal/handler/gateway_handler.go | 59 +- .../internal/handler/gemini_v1beta_handler.go | 41 +- backend/internal/handler/handler.go | 1 + .../handler/openai_gateway_handler.go | 68 +- backend/internal/handler/wire.go | 3 + .../internal/model/error_passthrough_rule.go | 74 + .../repository/error_passthrough_cache.go | 128 ++ .../repository/error_passthrough_repo.go | 178 +++ backend/internal/repository/wire.go | 2 + backend/internal/server/routes/admin.go | 14 + .../service/antigravity_gateway_service.go | 7 +- .../service/error_passthrough_service.go | 300 ++++ .../service/error_passthrough_service_test.go | 755 +++++++++ backend/internal/service/gateway_service.go | 19 +- .../service/gemini_messages_compat_service.go | 8 +- .../service/openai_gateway_service.go | 4 +- backend/internal/service/wire.go | 1 + .../048_add_error_passthrough_rules.sql | 24 + frontend/src/api/admin/errorPassthrough.ts | 134 ++ frontend/src/api/admin/index.ts | 8 +- .../admin/ErrorPassthroughRulesModal.vue | 623 ++++++++ frontend/src/i18n/locales/en.ts | 74 + frontend/src/i18n/locales/zh.ts | 74 + frontend/src/views/admin/AccountsView.vue | 13 + 43 files changed, 8456 insertions(+), 67 deletions(-) create mode 100644 backend/ent/errorpassthroughrule.go create mode 100644 backend/ent/errorpassthroughrule/errorpassthroughrule.go create mode 100644 backend/ent/errorpassthroughrule/where.go create mode 100644 backend/ent/errorpassthroughrule_create.go create mode 100644 backend/ent/errorpassthroughrule_delete.go create mode 100644 backend/ent/errorpassthroughrule_query.go create mode 100644 backend/ent/errorpassthroughrule_update.go create mode 100644 backend/ent/schema/error_passthrough_rule.go create mode 100644 backend/internal/handler/admin/error_passthrough_handler.go create mode 100644 backend/internal/model/error_passthrough_rule.go create mode 100644 backend/internal/repository/error_passthrough_cache.go create mode 100644 backend/internal/repository/error_passthrough_repo.go create mode 100644 backend/internal/service/error_passthrough_service.go create mode 100644 backend/internal/service/error_passthrough_service_test.go create mode 100644 backend/migrations/048_add_error_passthrough_rules.sql create mode 100644 frontend/src/api/admin/errorPassthrough.ts create mode 100644 frontend/src/components/admin/ErrorPassthroughRulesModal.vue diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 3ca86f91..8184bc1c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -174,9 +174,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig) + errorPassthroughRepository := repository.NewErrorPassthroughRepository(client) + errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) + errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) + errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) diff --git a/backend/ent/client.go b/backend/ent/client.go index a17721da..a791c081 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -52,6 +53,8 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. + ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // PromoCode is the client for interacting with the PromoCode builders. @@ -94,6 +97,7 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) @@ -204,6 +208,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AccountGroup: NewAccountGroupClient(cfg), Announcement: NewAnnouncementClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), @@ -241,6 +246,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AccountGroup: NewAccountGroupClient(cfg), Announcement: NewAnnouncementClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), @@ -284,9 +290,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, + c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Use(hooks...) } @@ -297,9 +304,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, + c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Intercept(interceptors...) } @@ -318,6 +326,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *ErrorPassthroughRuleMutation: + return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *PromoCodeMutation: @@ -1161,6 +1171,139 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. +type ErrorPassthroughRuleClient struct { + config +} + +// NewErrorPassthroughRuleClient returns a client for the ErrorPassthroughRule from the given config. +func NewErrorPassthroughRuleClient(c config) *ErrorPassthroughRuleClient { + return &ErrorPassthroughRuleClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `errorpassthroughrule.Hooks(f(g(h())))`. +func (c *ErrorPassthroughRuleClient) Use(hooks ...Hook) { + c.hooks.ErrorPassthroughRule = append(c.hooks.ErrorPassthroughRule, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `errorpassthroughrule.Intercept(f(g(h())))`. +func (c *ErrorPassthroughRuleClient) Intercept(interceptors ...Interceptor) { + c.inters.ErrorPassthroughRule = append(c.inters.ErrorPassthroughRule, interceptors...) +} + +// Create returns a builder for creating a ErrorPassthroughRule entity. +func (c *ErrorPassthroughRuleClient) Create() *ErrorPassthroughRuleCreate { + mutation := newErrorPassthroughRuleMutation(c.config, OpCreate) + return &ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ErrorPassthroughRule entities. +func (c *ErrorPassthroughRuleClient) CreateBulk(builders ...*ErrorPassthroughRuleCreate) *ErrorPassthroughRuleCreateBulk { + return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ErrorPassthroughRuleClient) MapCreateBulk(slice any, setFunc func(*ErrorPassthroughRuleCreate, int)) *ErrorPassthroughRuleCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ErrorPassthroughRuleCreateBulk{err: fmt.Errorf("calling to ErrorPassthroughRuleClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ErrorPassthroughRuleCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Update() *ErrorPassthroughRuleUpdate { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdate) + return &ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ErrorPassthroughRuleClient) UpdateOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRule(_m)) + return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *ErrorPassthroughRuleClient) UpdateOneID(id int64) *ErrorPassthroughRuleUpdateOne { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRuleID(id)) + return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Delete() *ErrorPassthroughRuleDelete { + mutation := newErrorPassthroughRuleMutation(c.config, OpDelete) + return &ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ErrorPassthroughRuleClient) DeleteOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ErrorPassthroughRuleClient) DeleteOneID(id int64) *ErrorPassthroughRuleDeleteOne { + builder := c.Delete().Where(errorpassthroughrule.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ErrorPassthroughRuleDeleteOne{builder} +} + +// Query returns a query builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Query() *ErrorPassthroughRuleQuery { + return &ErrorPassthroughRuleQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeErrorPassthroughRule}, + inters: c.Interceptors(), + } +} + +// Get returns a ErrorPassthroughRule entity by its id. +func (c *ErrorPassthroughRuleClient) Get(ctx context.Context, id int64) (*ErrorPassthroughRule, error) { + return c.Query().Where(errorpassthroughrule.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ErrorPassthroughRuleClient) GetX(ctx context.Context, id int64) *ErrorPassthroughRule { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *ErrorPassthroughRuleClient) Hooks() []Hook { + return c.hooks.ErrorPassthroughRule +} + +// Interceptors returns the client interceptors. +func (c *ErrorPassthroughRuleClient) Interceptors() []Interceptor { + return c.inters.ErrorPassthroughRule +} + +func (c *ErrorPassthroughRuleClient) mutate(ctx context.Context, m *ErrorPassthroughRuleMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ErrorPassthroughRule mutation op: %q", m.Op()) + } +} + // GroupClient is a client for the Group schema. type GroupClient struct { config @@ -3462,16 +3605,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Hook + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, + ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Interceptor + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, + ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 05e30ba7..5767a167 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -95,6 +96,7 @@ func checkColumn(t, c string) error { accountgroup.Table: accountgroup.ValidColumn, announcement.Table: announcement.ValidColumn, announcementread.Table: announcementread.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, diff --git a/backend/ent/errorpassthroughrule.go b/backend/ent/errorpassthroughrule.go new file mode 100644 index 00000000..1932f626 --- /dev/null +++ b/backend/ent/errorpassthroughrule.go @@ -0,0 +1,269 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" +) + +// ErrorPassthroughRule is the model entity for the ErrorPassthroughRule schema. +type ErrorPassthroughRule struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Enabled holds the value of the "enabled" field. + Enabled bool `json:"enabled,omitempty"` + // Priority holds the value of the "priority" field. + Priority int `json:"priority,omitempty"` + // ErrorCodes holds the value of the "error_codes" field. + ErrorCodes []int `json:"error_codes,omitempty"` + // Keywords holds the value of the "keywords" field. + Keywords []string `json:"keywords,omitempty"` + // MatchMode holds the value of the "match_mode" field. + MatchMode string `json:"match_mode,omitempty"` + // Platforms holds the value of the "platforms" field. + Platforms []string `json:"platforms,omitempty"` + // PassthroughCode holds the value of the "passthrough_code" field. + PassthroughCode bool `json:"passthrough_code,omitempty"` + // ResponseCode holds the value of the "response_code" field. + ResponseCode *int `json:"response_code,omitempty"` + // PassthroughBody holds the value of the "passthrough_body" field. + PassthroughBody bool `json:"passthrough_body,omitempty"` + // CustomMessage holds the value of the "custom_message" field. + CustomMessage *string `json:"custom_message,omitempty"` + // Description holds the value of the "description" field. + Description *string `json:"description,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms: + values[i] = new([]byte) + case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody: + values[i] = new(sql.NullBool) + case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode: + values[i] = new(sql.NullInt64) + case errorpassthroughrule.FieldName, errorpassthroughrule.FieldMatchMode, errorpassthroughrule.FieldCustomMessage, errorpassthroughrule.FieldDescription: + values[i] = new(sql.NullString) + case errorpassthroughrule.FieldCreatedAt, errorpassthroughrule.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ErrorPassthroughRule fields. +func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case errorpassthroughrule.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case errorpassthroughrule.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case errorpassthroughrule.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case errorpassthroughrule.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + _m.Name = value.String + } + case errorpassthroughrule.FieldEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field enabled", values[i]) + } else if value.Valid { + _m.Enabled = value.Bool + } + case errorpassthroughrule.FieldPriority: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field priority", values[i]) + } else if value.Valid { + _m.Priority = int(value.Int64) + } + case errorpassthroughrule.FieldErrorCodes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field error_codes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ErrorCodes); err != nil { + return fmt.Errorf("unmarshal field error_codes: %w", err) + } + } + case errorpassthroughrule.FieldKeywords: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field keywords", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Keywords); err != nil { + return fmt.Errorf("unmarshal field keywords: %w", err) + } + } + case errorpassthroughrule.FieldMatchMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field match_mode", values[i]) + } else if value.Valid { + _m.MatchMode = value.String + } + case errorpassthroughrule.FieldPlatforms: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field platforms", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Platforms); err != nil { + return fmt.Errorf("unmarshal field platforms: %w", err) + } + } + case errorpassthroughrule.FieldPassthroughCode: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field passthrough_code", values[i]) + } else if value.Valid { + _m.PassthroughCode = value.Bool + } + case errorpassthroughrule.FieldResponseCode: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_code", values[i]) + } else if value.Valid { + _m.ResponseCode = new(int) + *_m.ResponseCode = int(value.Int64) + } + case errorpassthroughrule.FieldPassthroughBody: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field passthrough_body", values[i]) + } else if value.Valid { + _m.PassthroughBody = value.Bool + } + case errorpassthroughrule.FieldCustomMessage: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field custom_message", values[i]) + } else if value.Valid { + _m.CustomMessage = new(string) + *_m.CustomMessage = value.String + } + case errorpassthroughrule.FieldDescription: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field description", values[i]) + } else if value.Valid { + _m.Description = new(string) + *_m.Description = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the ErrorPassthroughRule. +// This includes values selected through modifiers, order, etc. +func (_m *ErrorPassthroughRule) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this ErrorPassthroughRule. +// Note that you need to call ErrorPassthroughRule.Unwrap() before calling this method if this ErrorPassthroughRule +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *ErrorPassthroughRule) Update() *ErrorPassthroughRuleUpdateOne { + return NewErrorPassthroughRuleClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ErrorPassthroughRule entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *ErrorPassthroughRule) Unwrap() *ErrorPassthroughRule { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ErrorPassthroughRule is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *ErrorPassthroughRule) String() string { + var builder strings.Builder + builder.WriteString("ErrorPassthroughRule(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(_m.Name) + builder.WriteString(", ") + builder.WriteString("enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.Enabled)) + builder.WriteString(", ") + builder.WriteString("priority=") + builder.WriteString(fmt.Sprintf("%v", _m.Priority)) + builder.WriteString(", ") + builder.WriteString("error_codes=") + builder.WriteString(fmt.Sprintf("%v", _m.ErrorCodes)) + builder.WriteString(", ") + builder.WriteString("keywords=") + builder.WriteString(fmt.Sprintf("%v", _m.Keywords)) + builder.WriteString(", ") + builder.WriteString("match_mode=") + builder.WriteString(_m.MatchMode) + builder.WriteString(", ") + builder.WriteString("platforms=") + builder.WriteString(fmt.Sprintf("%v", _m.Platforms)) + builder.WriteString(", ") + builder.WriteString("passthrough_code=") + builder.WriteString(fmt.Sprintf("%v", _m.PassthroughCode)) + builder.WriteString(", ") + if v := _m.ResponseCode; v != nil { + builder.WriteString("response_code=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("passthrough_body=") + builder.WriteString(fmt.Sprintf("%v", _m.PassthroughBody)) + builder.WriteString(", ") + if v := _m.CustomMessage; v != nil { + builder.WriteString("custom_message=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.Description; v != nil { + builder.WriteString("description=") + builder.WriteString(*v) + } + builder.WriteByte(')') + return builder.String() +} + +// ErrorPassthroughRules is a parsable slice of ErrorPassthroughRule. +type ErrorPassthroughRules []*ErrorPassthroughRule diff --git a/backend/ent/errorpassthroughrule/errorpassthroughrule.go b/backend/ent/errorpassthroughrule/errorpassthroughrule.go new file mode 100644 index 00000000..d7be4f03 --- /dev/null +++ b/backend/ent/errorpassthroughrule/errorpassthroughrule.go @@ -0,0 +1,161 @@ +// Code generated by ent, DO NOT EDIT. + +package errorpassthroughrule + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the errorpassthroughrule type in the database. + Label = "error_passthrough_rule" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldEnabled holds the string denoting the enabled field in the database. + FieldEnabled = "enabled" + // FieldPriority holds the string denoting the priority field in the database. + FieldPriority = "priority" + // FieldErrorCodes holds the string denoting the error_codes field in the database. + FieldErrorCodes = "error_codes" + // FieldKeywords holds the string denoting the keywords field in the database. + FieldKeywords = "keywords" + // FieldMatchMode holds the string denoting the match_mode field in the database. + FieldMatchMode = "match_mode" + // FieldPlatforms holds the string denoting the platforms field in the database. + FieldPlatforms = "platforms" + // FieldPassthroughCode holds the string denoting the passthrough_code field in the database. + FieldPassthroughCode = "passthrough_code" + // FieldResponseCode holds the string denoting the response_code field in the database. + FieldResponseCode = "response_code" + // FieldPassthroughBody holds the string denoting the passthrough_body field in the database. + FieldPassthroughBody = "passthrough_body" + // FieldCustomMessage holds the string denoting the custom_message field in the database. + FieldCustomMessage = "custom_message" + // FieldDescription holds the string denoting the description field in the database. + FieldDescription = "description" + // Table holds the table name of the errorpassthroughrule in the database. + Table = "error_passthrough_rules" +) + +// Columns holds all SQL columns for errorpassthroughrule fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldEnabled, + FieldPriority, + FieldErrorCodes, + FieldKeywords, + FieldMatchMode, + FieldPlatforms, + FieldPassthroughCode, + FieldResponseCode, + FieldPassthroughBody, + FieldCustomMessage, + FieldDescription, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // NameValidator is a validator for the "name" field. It is called by the builders before save. + NameValidator func(string) error + // DefaultEnabled holds the default value on creation for the "enabled" field. + DefaultEnabled bool + // DefaultPriority holds the default value on creation for the "priority" field. + DefaultPriority int + // DefaultMatchMode holds the default value on creation for the "match_mode" field. + DefaultMatchMode string + // MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save. + MatchModeValidator func(string) error + // DefaultPassthroughCode holds the default value on creation for the "passthrough_code" field. + DefaultPassthroughCode bool + // DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field. + DefaultPassthroughBody bool +) + +// OrderOption defines the ordering options for the ErrorPassthroughRule queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByEnabled orders the results by the enabled field. +func ByEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEnabled, opts...).ToFunc() +} + +// ByPriority orders the results by the priority field. +func ByPriority(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPriority, opts...).ToFunc() +} + +// ByMatchMode orders the results by the match_mode field. +func ByMatchMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMatchMode, opts...).ToFunc() +} + +// ByPassthroughCode orders the results by the passthrough_code field. +func ByPassthroughCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassthroughCode, opts...).ToFunc() +} + +// ByResponseCode orders the results by the response_code field. +func ByResponseCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseCode, opts...).ToFunc() +} + +// ByPassthroughBody orders the results by the passthrough_body field. +func ByPassthroughBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassthroughBody, opts...).ToFunc() +} + +// ByCustomMessage orders the results by the custom_message field. +func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCustomMessage, opts...).ToFunc() +} + +// ByDescription orders the results by the description field. +func ByDescription(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDescription, opts...).ToFunc() +} diff --git a/backend/ent/errorpassthroughrule/where.go b/backend/ent/errorpassthroughrule/where.go new file mode 100644 index 00000000..56839d52 --- /dev/null +++ b/backend/ent/errorpassthroughrule/where.go @@ -0,0 +1,635 @@ +// Code generated by ent, DO NOT EDIT. + +package errorpassthroughrule + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v)) +} + +// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ. +func Enabled(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v)) +} + +// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ. +func Priority(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v)) +} + +// MatchMode applies equality check predicate on the "match_mode" field. It's identical to MatchModeEQ. +func MatchMode(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v)) +} + +// PassthroughCode applies equality check predicate on the "passthrough_code" field. It's identical to PassthroughCodeEQ. +func PassthroughCode(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v)) +} + +// ResponseCode applies equality check predicate on the "response_code" field. It's identical to ResponseCodeEQ. +func ResponseCode(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v)) +} + +// PassthroughBody applies equality check predicate on the "passthrough_body" field. It's identical to PassthroughBodyEQ. +func PassthroughBody(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v)) +} + +// CustomMessage applies equality check predicate on the "custom_message" field. It's identical to CustomMessageEQ. +func CustomMessage(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) +} + +// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. +func Description(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldName, v)) +} + +// EnabledEQ applies the EQ predicate on the "enabled" field. +func EnabledEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v)) +} + +// EnabledNEQ applies the NEQ predicate on the "enabled" field. +func EnabledNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldEnabled, v)) +} + +// PriorityEQ applies the EQ predicate on the "priority" field. +func PriorityEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v)) +} + +// PriorityNEQ applies the NEQ predicate on the "priority" field. +func PriorityNEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPriority, v)) +} + +// PriorityIn applies the In predicate on the "priority" field. +func PriorityIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldPriority, vs...)) +} + +// PriorityNotIn applies the NotIn predicate on the "priority" field. +func PriorityNotIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldPriority, vs...)) +} + +// PriorityGT applies the GT predicate on the "priority" field. +func PriorityGT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldPriority, v)) +} + +// PriorityGTE applies the GTE predicate on the "priority" field. +func PriorityGTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldPriority, v)) +} + +// PriorityLT applies the LT predicate on the "priority" field. +func PriorityLT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldPriority, v)) +} + +// PriorityLTE applies the LTE predicate on the "priority" field. +func PriorityLTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldPriority, v)) +} + +// ErrorCodesIsNil applies the IsNil predicate on the "error_codes" field. +func ErrorCodesIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldErrorCodes)) +} + +// ErrorCodesNotNil applies the NotNil predicate on the "error_codes" field. +func ErrorCodesNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldErrorCodes)) +} + +// KeywordsIsNil applies the IsNil predicate on the "keywords" field. +func KeywordsIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldKeywords)) +} + +// KeywordsNotNil applies the NotNil predicate on the "keywords" field. +func KeywordsNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldKeywords)) +} + +// MatchModeEQ applies the EQ predicate on the "match_mode" field. +func MatchModeEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v)) +} + +// MatchModeNEQ applies the NEQ predicate on the "match_mode" field. +func MatchModeNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldMatchMode, v)) +} + +// MatchModeIn applies the In predicate on the "match_mode" field. +func MatchModeIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldMatchMode, vs...)) +} + +// MatchModeNotIn applies the NotIn predicate on the "match_mode" field. +func MatchModeNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldMatchMode, vs...)) +} + +// MatchModeGT applies the GT predicate on the "match_mode" field. +func MatchModeGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldMatchMode, v)) +} + +// MatchModeGTE applies the GTE predicate on the "match_mode" field. +func MatchModeGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldMatchMode, v)) +} + +// MatchModeLT applies the LT predicate on the "match_mode" field. +func MatchModeLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldMatchMode, v)) +} + +// MatchModeLTE applies the LTE predicate on the "match_mode" field. +func MatchModeLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldMatchMode, v)) +} + +// MatchModeContains applies the Contains predicate on the "match_mode" field. +func MatchModeContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldMatchMode, v)) +} + +// MatchModeHasPrefix applies the HasPrefix predicate on the "match_mode" field. +func MatchModeHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldMatchMode, v)) +} + +// MatchModeHasSuffix applies the HasSuffix predicate on the "match_mode" field. +func MatchModeHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldMatchMode, v)) +} + +// MatchModeEqualFold applies the EqualFold predicate on the "match_mode" field. +func MatchModeEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldMatchMode, v)) +} + +// MatchModeContainsFold applies the ContainsFold predicate on the "match_mode" field. +func MatchModeContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldMatchMode, v)) +} + +// PlatformsIsNil applies the IsNil predicate on the "platforms" field. +func PlatformsIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldPlatforms)) +} + +// PlatformsNotNil applies the NotNil predicate on the "platforms" field. +func PlatformsNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldPlatforms)) +} + +// PassthroughCodeEQ applies the EQ predicate on the "passthrough_code" field. +func PassthroughCodeEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v)) +} + +// PassthroughCodeNEQ applies the NEQ predicate on the "passthrough_code" field. +func PassthroughCodeNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughCode, v)) +} + +// ResponseCodeEQ applies the EQ predicate on the "response_code" field. +func ResponseCodeEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v)) +} + +// ResponseCodeNEQ applies the NEQ predicate on the "response_code" field. +func ResponseCodeNEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldResponseCode, v)) +} + +// ResponseCodeIn applies the In predicate on the "response_code" field. +func ResponseCodeIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldResponseCode, vs...)) +} + +// ResponseCodeNotIn applies the NotIn predicate on the "response_code" field. +func ResponseCodeNotIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldResponseCode, vs...)) +} + +// ResponseCodeGT applies the GT predicate on the "response_code" field. +func ResponseCodeGT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldResponseCode, v)) +} + +// ResponseCodeGTE applies the GTE predicate on the "response_code" field. +func ResponseCodeGTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldResponseCode, v)) +} + +// ResponseCodeLT applies the LT predicate on the "response_code" field. +func ResponseCodeLT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldResponseCode, v)) +} + +// ResponseCodeLTE applies the LTE predicate on the "response_code" field. +func ResponseCodeLTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldResponseCode, v)) +} + +// ResponseCodeIsNil applies the IsNil predicate on the "response_code" field. +func ResponseCodeIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldResponseCode)) +} + +// ResponseCodeNotNil applies the NotNil predicate on the "response_code" field. +func ResponseCodeNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldResponseCode)) +} + +// PassthroughBodyEQ applies the EQ predicate on the "passthrough_body" field. +func PassthroughBodyEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v)) +} + +// PassthroughBodyNEQ applies the NEQ predicate on the "passthrough_body" field. +func PassthroughBodyNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughBody, v)) +} + +// CustomMessageEQ applies the EQ predicate on the "custom_message" field. +func CustomMessageEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) +} + +// CustomMessageNEQ applies the NEQ predicate on the "custom_message" field. +func CustomMessageNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCustomMessage, v)) +} + +// CustomMessageIn applies the In predicate on the "custom_message" field. +func CustomMessageIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCustomMessage, vs...)) +} + +// CustomMessageNotIn applies the NotIn predicate on the "custom_message" field. +func CustomMessageNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCustomMessage, vs...)) +} + +// CustomMessageGT applies the GT predicate on the "custom_message" field. +func CustomMessageGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCustomMessage, v)) +} + +// CustomMessageGTE applies the GTE predicate on the "custom_message" field. +func CustomMessageGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCustomMessage, v)) +} + +// CustomMessageLT applies the LT predicate on the "custom_message" field. +func CustomMessageLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCustomMessage, v)) +} + +// CustomMessageLTE applies the LTE predicate on the "custom_message" field. +func CustomMessageLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCustomMessage, v)) +} + +// CustomMessageContains applies the Contains predicate on the "custom_message" field. +func CustomMessageContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldCustomMessage, v)) +} + +// CustomMessageHasPrefix applies the HasPrefix predicate on the "custom_message" field. +func CustomMessageHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldCustomMessage, v)) +} + +// CustomMessageHasSuffix applies the HasSuffix predicate on the "custom_message" field. +func CustomMessageHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldCustomMessage, v)) +} + +// CustomMessageIsNil applies the IsNil predicate on the "custom_message" field. +func CustomMessageIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldCustomMessage)) +} + +// CustomMessageNotNil applies the NotNil predicate on the "custom_message" field. +func CustomMessageNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldCustomMessage)) +} + +// CustomMessageEqualFold applies the EqualFold predicate on the "custom_message" field. +func CustomMessageEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldCustomMessage, v)) +} + +// CustomMessageContainsFold applies the ContainsFold predicate on the "custom_message" field. +func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v)) +} + +// DescriptionEQ applies the EQ predicate on the "description" field. +func DescriptionEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) +} + +// DescriptionNEQ applies the NEQ predicate on the "description" field. +func DescriptionNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldDescription, v)) +} + +// DescriptionIn applies the In predicate on the "description" field. +func DescriptionIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldDescription, vs...)) +} + +// DescriptionNotIn applies the NotIn predicate on the "description" field. +func DescriptionNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldDescription, vs...)) +} + +// DescriptionGT applies the GT predicate on the "description" field. +func DescriptionGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldDescription, v)) +} + +// DescriptionGTE applies the GTE predicate on the "description" field. +func DescriptionGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldDescription, v)) +} + +// DescriptionLT applies the LT predicate on the "description" field. +func DescriptionLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldDescription, v)) +} + +// DescriptionLTE applies the LTE predicate on the "description" field. +func DescriptionLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldDescription, v)) +} + +// DescriptionContains applies the Contains predicate on the "description" field. +func DescriptionContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldDescription, v)) +} + +// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field. +func DescriptionHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldDescription, v)) +} + +// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field. +func DescriptionHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldDescription, v)) +} + +// DescriptionIsNil applies the IsNil predicate on the "description" field. +func DescriptionIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldDescription)) +} + +// DescriptionNotNil applies the NotNil predicate on the "description" field. +func DescriptionNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldDescription)) +} + +// DescriptionEqualFold applies the EqualFold predicate on the "description" field. +func DescriptionEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldDescription, v)) +} + +// DescriptionContainsFold applies the ContainsFold predicate on the "description" field. +func DescriptionContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldDescription, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.NotPredicates(p)) +} diff --git a/backend/ent/errorpassthroughrule_create.go b/backend/ent/errorpassthroughrule_create.go new file mode 100644 index 00000000..4dc08dce --- /dev/null +++ b/backend/ent/errorpassthroughrule_create.go @@ -0,0 +1,1382 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" +) + +// ErrorPassthroughRuleCreate is the builder for creating a ErrorPassthroughRule entity. +type ErrorPassthroughRuleCreate struct { + config + mutation *ErrorPassthroughRuleMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *ErrorPassthroughRuleCreate) SetCreatedAt(v time.Time) *ErrorPassthroughRuleCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableCreatedAt(v *time.Time) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *ErrorPassthroughRuleCreate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableUpdatedAt(v *time.Time) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetName sets the "name" field. +func (_c *ErrorPassthroughRuleCreate) SetName(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetName(v) + return _c +} + +// SetEnabled sets the "enabled" field. +func (_c *ErrorPassthroughRuleCreate) SetEnabled(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetEnabled(v) + return _c +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetEnabled(*v) + } + return _c +} + +// SetPriority sets the "priority" field. +func (_c *ErrorPassthroughRuleCreate) SetPriority(v int) *ErrorPassthroughRuleCreate { + _c.mutation.SetPriority(v) + return _c +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePriority(v *int) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPriority(*v) + } + return _c +} + +// SetErrorCodes sets the "error_codes" field. +func (_c *ErrorPassthroughRuleCreate) SetErrorCodes(v []int) *ErrorPassthroughRuleCreate { + _c.mutation.SetErrorCodes(v) + return _c +} + +// SetKeywords sets the "keywords" field. +func (_c *ErrorPassthroughRuleCreate) SetKeywords(v []string) *ErrorPassthroughRuleCreate { + _c.mutation.SetKeywords(v) + return _c +} + +// SetMatchMode sets the "match_mode" field. +func (_c *ErrorPassthroughRuleCreate) SetMatchMode(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetMatchMode(v) + return _c +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetMatchMode(*v) + } + return _c +} + +// SetPlatforms sets the "platforms" field. +func (_c *ErrorPassthroughRuleCreate) SetPlatforms(v []string) *ErrorPassthroughRuleCreate { + _c.mutation.SetPlatforms(v) + return _c +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_c *ErrorPassthroughRuleCreate) SetPassthroughCode(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetPassthroughCode(v) + return _c +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPassthroughCode(*v) + } + return _c +} + +// SetResponseCode sets the "response_code" field. +func (_c *ErrorPassthroughRuleCreate) SetResponseCode(v int) *ErrorPassthroughRuleCreate { + _c.mutation.SetResponseCode(v) + return _c +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetResponseCode(*v) + } + return _c +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_c *ErrorPassthroughRuleCreate) SetPassthroughBody(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetPassthroughBody(v) + return _c +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPassthroughBody(*v) + } + return _c +} + +// SetCustomMessage sets the "custom_message" field. +func (_c *ErrorPassthroughRuleCreate) SetCustomMessage(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetCustomMessage(v) + return _c +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetCustomMessage(*v) + } + return _c +} + +// SetDescription sets the "description" field. +func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetDescription(v) + return _c +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableDescription(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetDescription(*v) + } + return _c +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_c *ErrorPassthroughRuleCreate) Mutation() *ErrorPassthroughRuleMutation { + return _c.mutation +} + +// Save creates the ErrorPassthroughRule in the database. +func (_c *ErrorPassthroughRuleCreate) Save(ctx context.Context) (*ErrorPassthroughRule, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ErrorPassthroughRuleCreate) SaveX(ctx context.Context) *ErrorPassthroughRule { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ErrorPassthroughRuleCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *ErrorPassthroughRuleCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := errorpassthroughrule.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Enabled(); !ok { + v := errorpassthroughrule.DefaultEnabled + _c.mutation.SetEnabled(v) + } + if _, ok := _c.mutation.Priority(); !ok { + v := errorpassthroughrule.DefaultPriority + _c.mutation.SetPriority(v) + } + if _, ok := _c.mutation.MatchMode(); !ok { + v := errorpassthroughrule.DefaultMatchMode + _c.mutation.SetMatchMode(v) + } + if _, ok := _c.mutation.PassthroughCode(); !ok { + v := errorpassthroughrule.DefaultPassthroughCode + _c.mutation.SetPassthroughCode(v) + } + if _, ok := _c.mutation.PassthroughBody(); !ok { + v := errorpassthroughrule.DefaultPassthroughBody + _c.mutation.SetPassthroughBody(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *ErrorPassthroughRuleCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ErrorPassthroughRule.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ErrorPassthroughRule.updated_at"`)} + } + if _, ok := _c.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ErrorPassthroughRule.name"`)} + } + if v, ok := _c.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if _, ok := _c.mutation.Enabled(); !ok { + return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "ErrorPassthroughRule.enabled"`)} + } + if _, ok := _c.mutation.Priority(); !ok { + return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "ErrorPassthroughRule.priority"`)} + } + if _, ok := _c.mutation.MatchMode(); !ok { + return &ValidationError{Name: "match_mode", err: errors.New(`ent: missing required field "ErrorPassthroughRule.match_mode"`)} + } + if v, ok := _c.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + if _, ok := _c.mutation.PassthroughCode(); !ok { + return &ValidationError{Name: "passthrough_code", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_code"`)} + } + if _, ok := _c.mutation.PassthroughBody(); !ok { + return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)} + } + return nil +} + +func (_c *ErrorPassthroughRuleCreate) sqlSave(ctx context.Context) (*ErrorPassthroughRule, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlgraph.CreateSpec) { + var ( + _node = &ErrorPassthroughRule{config: _c.config} + _spec = sqlgraph.NewCreateSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := _c.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + _node.Enabled = value + } + if value, ok := _c.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + _node.Priority = value + } + if value, ok := _c.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + _node.ErrorCodes = value + } + if value, ok := _c.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + _node.Keywords = value + } + if value, ok := _c.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + _node.MatchMode = value + } + if value, ok := _c.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + _node.Platforms = value + } + if value, ok := _c.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + _node.PassthroughCode = value + } + if value, ok := _c.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + _node.ResponseCode = &value + } + if value, ok := _c.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + _node.PassthroughBody = value + } + if value, ok := _c.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + _node.CustomMessage = &value + } + if value, ok := _c.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + _node.Description = &value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ErrorPassthroughRule.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ErrorPassthroughRuleUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreate) OnConflict(opts ...sql.ConflictOption) *ErrorPassthroughRuleUpsertOne { + _c.conflict = opts + return &ErrorPassthroughRuleUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreate) OnConflictColumns(columns ...string) *ErrorPassthroughRuleUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ErrorPassthroughRuleUpsertOne{ + create: _c, + } +} + +type ( + // ErrorPassthroughRuleUpsertOne is the builder for "upsert"-ing + // one ErrorPassthroughRule node. + ErrorPassthroughRuleUpsertOne struct { + create *ErrorPassthroughRuleCreate + } + + // ErrorPassthroughRuleUpsert is the "OnConflict" setter. + ErrorPassthroughRuleUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsert) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateUpdatedAt() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldUpdatedAt) + return u +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsert) SetName(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateName() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldName) + return u +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsert) SetEnabled(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldEnabled, v) + return u +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateEnabled() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldEnabled) + return u +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsert) SetPriority(v int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPriority, v) + return u +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePriority() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPriority) + return u +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsert) AddPriority(v int) *ErrorPassthroughRuleUpsert { + u.Add(errorpassthroughrule.FieldPriority, v) + return u +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsert) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldErrorCodes, v) + return u +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateErrorCodes() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldErrorCodes) + return u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsert) ClearErrorCodes() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldErrorCodes) + return u +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsert) SetKeywords(v []string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldKeywords, v) + return u +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateKeywords() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldKeywords) + return u +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsert) ClearKeywords() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldKeywords) + return u +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsert) SetMatchMode(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldMatchMode, v) + return u +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateMatchMode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldMatchMode) + return u +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsert) SetPlatforms(v []string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPlatforms, v) + return u +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePlatforms() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPlatforms) + return u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsert) ClearPlatforms() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldPlatforms) + return u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsert) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPassthroughCode, v) + return u +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePassthroughCode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPassthroughCode) + return u +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) SetResponseCode(v int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldResponseCode, v) + return u +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateResponseCode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldResponseCode) + return u +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) AddResponseCode(v int) *ErrorPassthroughRuleUpsert { + u.Add(errorpassthroughrule.FieldResponseCode, v) + return u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) ClearResponseCode() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldResponseCode) + return u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsert) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPassthroughBody, v) + return u +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePassthroughBody() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPassthroughBody) + return u +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsert) SetCustomMessage(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldCustomMessage, v) + return u +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateCustomMessage() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldCustomMessage) + return u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldCustomMessage) + return u +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldDescription, v) + return u +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateDescription() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldDescription) + return u +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsert) ClearDescription() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldDescription) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertOne) UpdateNewValues() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(errorpassthroughrule.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertOne) Ignore() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ErrorPassthroughRuleUpsertOne) DoNothing() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ErrorPassthroughRuleCreate.OnConflict +// documentation for more info. +func (u *ErrorPassthroughRuleUpsertOne) Update(set func(*ErrorPassthroughRuleUpsert)) *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ErrorPassthroughRuleUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsertOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateUpdatedAt() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsertOne) SetName(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateName() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsertOne) SetEnabled(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateEnabled() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateEnabled() + }) +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPriority(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPriority(v) + }) +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsertOne) AddPriority(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddPriority(v) + }) +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePriority() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePriority() + }) +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetErrorCodes(v) + }) +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateErrorCodes() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateErrorCodes() + }) +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearErrorCodes() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearErrorCodes() + }) +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsertOne) SetKeywords(v []string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetKeywords(v) + }) +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateKeywords() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateKeywords() + }) +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearKeywords() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearKeywords() + }) +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsertOne) SetMatchMode(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetMatchMode(v) + }) +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateMatchMode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateMatchMode() + }) +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPlatforms(v) + }) +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePlatforms() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePlatforms() + }) +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearPlatforms() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearPlatforms() + }) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughCode(v) + }) +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePassthroughCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughCode() + }) +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) SetResponseCode(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetResponseCode(v) + }) +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) AddResponseCode(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddResponseCode(v) + }) +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateResponseCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateResponseCode() + }) +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearResponseCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearResponseCode() + }) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughBody(v) + }) +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePassthroughBody() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughBody() + }) +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetCustomMessage(v) + }) +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateCustomMessage() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateCustomMessage() + }) +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearCustomMessage() + }) +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateDescription() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearDescription() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearDescription() + }) +} + +// Exec executes the query. +func (u *ErrorPassthroughRuleUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ErrorPassthroughRuleCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ErrorPassthroughRuleUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// ErrorPassthroughRuleCreateBulk is the builder for creating many ErrorPassthroughRule entities in bulk. +type ErrorPassthroughRuleCreateBulk struct { + config + err error + builders []*ErrorPassthroughRuleCreate + conflict []sql.ConflictOption +} + +// Save creates the ErrorPassthroughRule entities in the database. +func (_c *ErrorPassthroughRuleCreateBulk) Save(ctx context.Context) ([]*ErrorPassthroughRule, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*ErrorPassthroughRule, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ErrorPassthroughRuleMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreateBulk) SaveX(ctx context.Context) []*ErrorPassthroughRule { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ErrorPassthroughRuleCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ErrorPassthroughRule.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ErrorPassthroughRuleUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreateBulk) OnConflict(opts ...sql.ConflictOption) *ErrorPassthroughRuleUpsertBulk { + _c.conflict = opts + return &ErrorPassthroughRuleUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreateBulk) OnConflictColumns(columns ...string) *ErrorPassthroughRuleUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ErrorPassthroughRuleUpsertBulk{ + create: _c, + } +} + +// ErrorPassthroughRuleUpsertBulk is the builder for "upsert"-ing +// a bulk of ErrorPassthroughRule nodes. +type ErrorPassthroughRuleUpsertBulk struct { + create *ErrorPassthroughRuleCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertBulk) UpdateNewValues() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(errorpassthroughrule.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertBulk) Ignore() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ErrorPassthroughRuleUpsertBulk) DoNothing() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ErrorPassthroughRuleCreateBulk.OnConflict +// documentation for more info. +func (u *ErrorPassthroughRuleUpsertBulk) Update(set func(*ErrorPassthroughRuleUpsert)) *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ErrorPassthroughRuleUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateUpdatedAt() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetName(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateName() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetEnabled(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateEnabled() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateEnabled() + }) +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPriority(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPriority(v) + }) +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsertBulk) AddPriority(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddPriority(v) + }) +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePriority() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePriority() + }) +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetErrorCodes(v) + }) +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateErrorCodes() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateErrorCodes() + }) +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearErrorCodes() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearErrorCodes() + }) +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetKeywords(v []string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetKeywords(v) + }) +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateKeywords() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateKeywords() + }) +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearKeywords() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearKeywords() + }) +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetMatchMode(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetMatchMode(v) + }) +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateMatchMode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateMatchMode() + }) +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPlatforms(v []string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPlatforms(v) + }) +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePlatforms() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePlatforms() + }) +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearPlatforms() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearPlatforms() + }) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughCode(v) + }) +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePassthroughCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughCode() + }) +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetResponseCode(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetResponseCode(v) + }) +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) AddResponseCode(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddResponseCode(v) + }) +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateResponseCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateResponseCode() + }) +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearResponseCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearResponseCode() + }) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughBody(v) + }) +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePassthroughBody() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughBody() + }) +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetCustomMessage(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetCustomMessage(v) + }) +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateCustomMessage() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateCustomMessage() + }) +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearCustomMessage() + }) +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateDescription() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearDescription() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearDescription() + }) +} + +// Exec executes the query. +func (u *ErrorPassthroughRuleUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ErrorPassthroughRuleCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ErrorPassthroughRuleCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/errorpassthroughrule_delete.go b/backend/ent/errorpassthroughrule_delete.go new file mode 100644 index 00000000..943c7e2b --- /dev/null +++ b/backend/ent/errorpassthroughrule_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleDelete is the builder for deleting a ErrorPassthroughRule entity. +type ErrorPassthroughRuleDelete struct { + config + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleDelete builder. +func (_d *ErrorPassthroughRuleDelete) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ErrorPassthroughRuleDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ErrorPassthroughRuleDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ErrorPassthroughRuleDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ErrorPassthroughRuleDeleteOne is the builder for deleting a single ErrorPassthroughRule entity. +type ErrorPassthroughRuleDeleteOne struct { + _d *ErrorPassthroughRuleDelete +} + +// Where appends a list predicates to the ErrorPassthroughRuleDelete builder. +func (_d *ErrorPassthroughRuleDeleteOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *ErrorPassthroughRuleDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{errorpassthroughrule.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ErrorPassthroughRuleDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/errorpassthroughrule_query.go b/backend/ent/errorpassthroughrule_query.go new file mode 100644 index 00000000..bfab5bd8 --- /dev/null +++ b/backend/ent/errorpassthroughrule_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleQuery is the builder for querying ErrorPassthroughRule entities. +type ErrorPassthroughRuleQuery struct { + config + ctx *QueryContext + order []errorpassthroughrule.OrderOption + inters []Interceptor + predicates []predicate.ErrorPassthroughRule + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ErrorPassthroughRuleQuery builder. +func (_q *ErrorPassthroughRuleQuery) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ErrorPassthroughRuleQuery) Limit(limit int) *ErrorPassthroughRuleQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *ErrorPassthroughRuleQuery) Offset(offset int) *ErrorPassthroughRuleQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ErrorPassthroughRuleQuery) Unique(unique bool) *ErrorPassthroughRuleQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ErrorPassthroughRuleQuery) Order(o ...errorpassthroughrule.OrderOption) *ErrorPassthroughRuleQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first ErrorPassthroughRule entity from the query. +// Returns a *NotFoundError when no ErrorPassthroughRule was found. +func (_q *ErrorPassthroughRuleQuery) First(ctx context.Context) (*ErrorPassthroughRule, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{errorpassthroughrule.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) FirstX(ctx context.Context) *ErrorPassthroughRule { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ErrorPassthroughRule ID from the query. +// Returns a *NotFoundError when no ErrorPassthroughRule ID was found. +func (_q *ErrorPassthroughRuleQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{errorpassthroughrule.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ErrorPassthroughRule entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ErrorPassthroughRule entity is found. +// Returns a *NotFoundError when no ErrorPassthroughRule entities are found. +func (_q *ErrorPassthroughRuleQuery) Only(ctx context.Context) (*ErrorPassthroughRule, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{errorpassthroughrule.Label} + default: + return nil, &NotSingularError{errorpassthroughrule.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) OnlyX(ctx context.Context) *ErrorPassthroughRule { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ErrorPassthroughRule ID in the query. +// Returns a *NotSingularError when more than one ErrorPassthroughRule ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ErrorPassthroughRuleQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{errorpassthroughrule.Label} + default: + err = &NotSingularError{errorpassthroughrule.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ErrorPassthroughRules. +func (_q *ErrorPassthroughRuleQuery) All(ctx context.Context) ([]*ErrorPassthroughRule, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ErrorPassthroughRule, *ErrorPassthroughRuleQuery]() + return withInterceptors[[]*ErrorPassthroughRule](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) AllX(ctx context.Context) []*ErrorPassthroughRule { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ErrorPassthroughRule IDs. +func (_q *ErrorPassthroughRuleQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(errorpassthroughrule.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ErrorPassthroughRuleQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ErrorPassthroughRuleQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *ErrorPassthroughRuleQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ErrorPassthroughRuleQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *ErrorPassthroughRuleQuery) Clone() *ErrorPassthroughRuleQuery { + if _q == nil { + return nil + } + return &ErrorPassthroughRuleQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]errorpassthroughrule.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ErrorPassthroughRule{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ErrorPassthroughRule.Query(). +// GroupBy(errorpassthroughrule.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ErrorPassthroughRuleQuery) GroupBy(field string, fields ...string) *ErrorPassthroughRuleGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &ErrorPassthroughRuleGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = errorpassthroughrule.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.ErrorPassthroughRule.Query(). +// Select(errorpassthroughrule.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *ErrorPassthroughRuleQuery) Select(fields ...string) *ErrorPassthroughRuleSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ErrorPassthroughRuleSelect{ErrorPassthroughRuleQuery: _q} + sbuild.label = errorpassthroughrule.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ErrorPassthroughRuleSelect configured with the given aggregations. +func (_q *ErrorPassthroughRuleQuery) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *ErrorPassthroughRuleQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !errorpassthroughrule.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ErrorPassthroughRuleQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ErrorPassthroughRule, error) { + var ( + nodes = []*ErrorPassthroughRule{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ErrorPassthroughRule).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ErrorPassthroughRule{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *ErrorPassthroughRuleQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ErrorPassthroughRuleQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID) + for i := range fields { + if fields[i] != errorpassthroughrule.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ErrorPassthroughRuleQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(errorpassthroughrule.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = errorpassthroughrule.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *ErrorPassthroughRuleQuery) ForUpdate(opts ...sql.LockOption) *ErrorPassthroughRuleQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *ErrorPassthroughRuleQuery) ForShare(opts ...sql.LockOption) *ErrorPassthroughRuleQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ErrorPassthroughRuleGroupBy is the group-by builder for ErrorPassthroughRule entities. +type ErrorPassthroughRuleGroupBy struct { + selector + build *ErrorPassthroughRuleQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ErrorPassthroughRuleGroupBy) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *ErrorPassthroughRuleGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ErrorPassthroughRuleGroupBy) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ErrorPassthroughRuleSelect is the builder for selecting fields of ErrorPassthroughRule entities. +type ErrorPassthroughRuleSelect struct { + *ErrorPassthroughRuleQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ErrorPassthroughRuleSelect) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *ErrorPassthroughRuleSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleSelect](ctx, _s.ErrorPassthroughRuleQuery, _s, _s.inters, v) +} + +func (_s *ErrorPassthroughRuleSelect) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/errorpassthroughrule_update.go b/backend/ent/errorpassthroughrule_update.go new file mode 100644 index 00000000..9d52aa49 --- /dev/null +++ b/backend/ent/errorpassthroughrule_update.go @@ -0,0 +1,823 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleUpdate is the builder for updating ErrorPassthroughRule entities. +type ErrorPassthroughRuleUpdate struct { + config + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder. +func (_u *ErrorPassthroughRuleUpdate) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ErrorPassthroughRuleUpdate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ErrorPassthroughRuleUpdate) SetName(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableName(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ErrorPassthroughRuleUpdate) SetEnabled(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetPriority sets the "priority" field. +func (_u *ErrorPassthroughRuleUpdate) SetPriority(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.ResetPriority() + _u.mutation.SetPriority(v) + return _u +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPriority(*v) + } + return _u +} + +// AddPriority adds value to the "priority" field. +func (_u *ErrorPassthroughRuleUpdate) AddPriority(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.AddPriority(v) + return _u +} + +// SetErrorCodes sets the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdate { + _u.mutation.SetErrorCodes(v) + return _u +} + +// AppendErrorCodes appends value to the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendErrorCodes(v) + return _u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) ClearErrorCodes() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearErrorCodes() + return _u +} + +// SetKeywords sets the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) SetKeywords(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetKeywords(v) + return _u +} + +// AppendKeywords appends value to the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) AppendKeywords(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendKeywords(v) + return _u +} + +// ClearKeywords clears the value of the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) ClearKeywords() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearKeywords() + return _u +} + +// SetMatchMode sets the "match_mode" field. +func (_u *ErrorPassthroughRuleUpdate) SetMatchMode(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetMatchMode(v) + return _u +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetMatchMode(*v) + } + return _u +} + +// SetPlatforms sets the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) SetPlatforms(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPlatforms(v) + return _u +} + +// AppendPlatforms appends value to the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendPlatforms(v) + return _u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) ClearPlatforms() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearPlatforms() + return _u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_u *ErrorPassthroughRuleUpdate) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPassthroughCode(v) + return _u +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPassthroughCode(*v) + } + return _u +} + +// SetResponseCode sets the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) SetResponseCode(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.ResetResponseCode() + _u.mutation.SetResponseCode(v) + return _u +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetResponseCode(*v) + } + return _u +} + +// AddResponseCode adds value to the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) AddResponseCode(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.AddResponseCode(v) + return _u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) ClearResponseCode() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearResponseCode() + return _u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_u *ErrorPassthroughRuleUpdate) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPassthroughBody(v) + return _u +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPassthroughBody(*v) + } + return _u +} + +// SetCustomMessage sets the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdate) SetCustomMessage(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetCustomMessage(v) + return _u +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetCustomMessage(*v) + } + return _u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearCustomMessage() + return _u +} + +// SetDescription sets the "description" field. +func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ErrorPassthroughRuleUpdate) ClearDescription() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearDescription() + return _u +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_u *ErrorPassthroughRuleUpdate) Mutation() *ErrorPassthroughRuleMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ErrorPassthroughRuleUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ErrorPassthroughRuleUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ErrorPassthroughRuleUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ErrorPassthroughRuleUpdate) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if v, ok := _u.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + return nil +} + +func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPriority(); ok { + _spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedErrorCodes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value) + }) + } + if _u.mutation.ErrorCodesCleared() { + _spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON) + } + if value, ok := _u.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeywords(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldKeywords, value) + }) + } + if _u.mutation.KeywordsCleared() { + _spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON) + } + if value, ok := _u.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + } + if value, ok := _u.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPlatforms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value) + }) + } + if _u.mutation.PlatformsCleared() { + _spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON) + } + if value, ok := _u.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + } + if value, ok := _u.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseCode(); ok { + _spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if _u.mutation.ResponseCodeCleared() { + _spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt) + } + if value, ok := _u.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + } + if value, ok := _u.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + } + if _u.mutation.CustomMessageCleared() { + _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{errorpassthroughrule.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ErrorPassthroughRuleUpdateOne is the builder for updating a single ErrorPassthroughRule entity. +type ErrorPassthroughRuleUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetName(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableName(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetEnabled(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetPriority sets the "priority" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPriority(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.ResetPriority() + _u.mutation.SetPriority(v) + return _u +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPriority(*v) + } + return _u +} + +// AddPriority adds value to the "priority" field. +func (_u *ErrorPassthroughRuleUpdateOne) AddPriority(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AddPriority(v) + return _u +} + +// SetErrorCodes sets the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetErrorCodes(v) + return _u +} + +// AppendErrorCodes appends value to the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendErrorCodes(v) + return _u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearErrorCodes() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearErrorCodes() + return _u +} + +// SetKeywords sets the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetKeywords(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetKeywords(v) + return _u +} + +// AppendKeywords appends value to the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendKeywords(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendKeywords(v) + return _u +} + +// ClearKeywords clears the value of the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearKeywords() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearKeywords() + return _u +} + +// SetMatchMode sets the "match_mode" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetMatchMode(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetMatchMode(v) + return _u +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetMatchMode(*v) + } + return _u +} + +// SetPlatforms sets the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPlatforms(v) + return _u +} + +// AppendPlatforms appends value to the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendPlatforms(v) + return _u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearPlatforms() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearPlatforms() + return _u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPassthroughCode(v) + return _u +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPassthroughCode(*v) + } + return _u +} + +// SetResponseCode sets the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetResponseCode(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.ResetResponseCode() + _u.mutation.SetResponseCode(v) + return _u +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetResponseCode(*v) + } + return _u +} + +// AddResponseCode adds value to the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) AddResponseCode(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AddResponseCode(v) + return _u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearResponseCode() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearResponseCode() + return _u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPassthroughBody(v) + return _u +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPassthroughBody(*v) + } + return _u +} + +// SetCustomMessage sets the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetCustomMessage(v) + return _u +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetCustomMessage(*v) + } + return _u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearCustomMessage() + return _u +} + +// SetDescription sets the "description" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearDescription() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearDescription() + return _u +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_u *ErrorPassthroughRuleUpdateOne) Mutation() *ErrorPassthroughRuleMutation { + return _u.mutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder. +func (_u *ErrorPassthroughRuleUpdateOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ErrorPassthroughRuleUpdateOne) Select(field string, fields ...string) *ErrorPassthroughRuleUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ErrorPassthroughRule entity. +func (_u *ErrorPassthroughRuleUpdateOne) Save(ctx context.Context) (*ErrorPassthroughRule, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdateOne) SaveX(ctx context.Context) *ErrorPassthroughRule { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *ErrorPassthroughRuleUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ErrorPassthroughRuleUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ErrorPassthroughRuleUpdateOne) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if v, ok := _u.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + return nil +} + +func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *ErrorPassthroughRule, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ErrorPassthroughRule.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID) + for _, f := range fields { + if !errorpassthroughrule.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != errorpassthroughrule.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPriority(); ok { + _spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedErrorCodes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value) + }) + } + if _u.mutation.ErrorCodesCleared() { + _spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON) + } + if value, ok := _u.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeywords(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldKeywords, value) + }) + } + if _u.mutation.KeywordsCleared() { + _spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON) + } + if value, ok := _u.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + } + if value, ok := _u.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPlatforms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value) + }) + } + if _u.mutation.PlatformsCleared() { + _spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON) + } + if value, ok := _u.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + } + if value, ok := _u.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseCode(); ok { + _spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if _u.mutation.ResponseCodeCleared() { + _spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt) + } + if value, ok := _u.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + } + if value, ok := _u.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + } + if _u.mutation.CustomMessageCleared() { + _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString) + } + _node = &ErrorPassthroughRule{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{errorpassthroughrule.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 1e653c77..1b15685c 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,18 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary +// function as ErrorPassthroughRule mutator. +type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ErrorPassthroughRuleFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ErrorPassthroughRuleMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ErrorPassthroughRuleMutation", m) +} + // The GroupFunc type is an adapter to allow the use of ordinary // function as Group mutator. type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index a37be48f..8ee42db3 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,6 +13,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" @@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. +type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ErrorPassthroughRuleFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q) +} + +// The TraverseErrorPassthroughRule type is an adapter to allow the use of ordinary function as Traverser. +type TraverseErrorPassthroughRule func(context.Context, *ent.ErrorPassthroughRuleQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseErrorPassthroughRule) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseErrorPassthroughRule) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q) +} + // The GroupFunc type is an adapter to allow the use of ordinary function as a Querier. type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error) @@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.ErrorPassthroughRuleQuery: + return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.PromoCodeQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index dc91f6a5..f9e90d73 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -309,6 +309,42 @@ var ( }, }, } + // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. + ErrorPassthroughRulesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "name", Type: field.TypeString, Size: 100}, + {Name: "enabled", Type: field.TypeBool, Default: true}, + {Name: "priority", Type: field.TypeInt, Default: 0}, + {Name: "error_codes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "keywords", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "match_mode", Type: field.TypeString, Size: 10, Default: "any"}, + {Name: "platforms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "passthrough_code", Type: field.TypeBool, Default: true}, + {Name: "response_code", Type: field.TypeInt, Nullable: true}, + {Name: "passthrough_body", Type: field.TypeBool, Default: true}, + {Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, + } + // ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table. + ErrorPassthroughRulesTable = &schema.Table{ + Name: "error_passthrough_rules", + Columns: ErrorPassthroughRulesColumns, + PrimaryKey: []*schema.Column{ErrorPassthroughRulesColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "errorpassthroughrule_enabled", + Unique: false, + Columns: []*schema.Column{ErrorPassthroughRulesColumns[4]}, + }, + { + Name: "errorpassthroughrule_priority", + Unique: false, + Columns: []*schema.Column{ErrorPassthroughRulesColumns[5]}, + }, + }, + } // GroupsColumns holds the columns for the "groups" table. GroupsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -950,6 +986,7 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + ErrorPassthroughRulesTable, GroupsTable, PromoCodesTable, PromoCodeUsagesTable, @@ -989,6 +1026,9 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ + Table: "error_passthrough_rules", + } GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 77d208e1..5c182dea 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" @@ -48,6 +49,7 @@ const ( TypeAccountGroup = "AccountGroup" TypeAnnouncement = "Announcement" TypeAnnouncementRead = "AnnouncementRead" + TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" @@ -5750,6 +5752,1272 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } +// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. +type ErrorPassthroughRuleMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + enabled *bool + priority *int + addpriority *int + error_codes *[]int + appenderror_codes []int + keywords *[]string + appendkeywords []string + match_mode *string + platforms *[]string + appendplatforms []string + passthrough_code *bool + response_code *int + addresponse_code *int + passthrough_body *bool + custom_message *string + description *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*ErrorPassthroughRule, error) + predicates []predicate.ErrorPassthroughRule +} + +var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) + +// errorpassthroughruleOption allows management of the mutation configuration using functional options. +type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) + +// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. +func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { + m := &ErrorPassthroughRuleMutation{ + config: c, + op: op, + typ: TypeErrorPassthroughRule, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withErrorPassthroughRuleID sets the ID field of the mutation. +func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + var ( + err error + once sync.Once + value *ErrorPassthroughRule + ) + m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. +func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m ErrorPassthroughRuleMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetName sets the "name" field. +func (m *ErrorPassthroughRuleMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *ErrorPassthroughRuleMutation) ResetName() { + m.name = nil +} + +// SetEnabled sets the "enabled" field. +func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { + v := m.enabled + if v == nil { + return + } + return *v, true +} + +// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil +} + +// ResetEnabled resets all changes to the "enabled" field. +func (m *ErrorPassthroughRuleMutation) ResetEnabled() { + m.enabled = nil +} + +// SetPriority sets the "priority" field. +func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { + m.priority = &i + m.addpriority = nil +} + +// Priority returns the value of the "priority" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { + v := m.priority + if v == nil { + return + } + return *v, true +} + +// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPriority is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPriority requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPriority: %w", err) + } + return oldValue.Priority, nil +} + +// AddPriority adds i to the "priority" field. +func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { + if m.addpriority != nil { + *m.addpriority += i + } else { + m.addpriority = &i + } +} + +// AddedPriority returns the value that was added to the "priority" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { + v := m.addpriority + if v == nil { + return + } + return *v, true +} + +// ResetPriority resets all changes to the "priority" field. +func (m *ErrorPassthroughRuleMutation) ResetPriority() { + m.priority = nil + m.addpriority = nil +} + +// SetErrorCodes sets the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { + m.error_codes = &i + m.appenderror_codes = nil +} + +// ErrorCodes returns the value of the "error_codes" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { + v := m.error_codes + if v == nil { + return + } + return *v, true +} + +// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorCodes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) + } + return oldValue.ErrorCodes, nil +} + +// AppendErrorCodes adds i to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { + m.appenderror_codes = append(m.appenderror_codes, i...) +} + +// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { + if len(m.appenderror_codes) == 0 { + return nil, false + } + return m.appenderror_codes, true +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} +} + +// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] + return ok +} + +// ResetErrorCodes resets all changes to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) +} + +// SetKeywords sets the "keywords" field. +func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { + m.keywords = &s + m.appendkeywords = nil +} + +// Keywords returns the value of the "keywords" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { + v := m.keywords + if v == nil { + return + } + return *v, true +} + +// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKeywords is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKeywords requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKeywords: %w", err) + } + return oldValue.Keywords, nil +} + +// AppendKeywords adds s to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { + m.appendkeywords = append(m.appendkeywords, s...) +} + +// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { + if len(m.appendkeywords) == 0 { + return nil, false + } + return m.appendkeywords, true +} + +// ClearKeywords clears the value of the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ClearKeywords() { + m.keywords = nil + m.appendkeywords = nil + m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +} + +// KeywordsCleared returns if the "keywords" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] + return ok +} + +// ResetKeywords resets all changes to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ResetKeywords() { + m.keywords = nil + m.appendkeywords = nil + delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +} + +// SetMatchMode sets the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { + m.match_mode = &s +} + +// MatchMode returns the value of the "match_mode" field in the mutation. +func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { + v := m.match_mode + if v == nil { + return + } + return *v, true +} + +// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMatchMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) + } + return oldValue.MatchMode, nil +} + +// ResetMatchMode resets all changes to the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { + m.match_mode = nil +} + +// SetPlatforms sets the "platforms" field. +func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { + m.platforms = &s + m.appendplatforms = nil +} + +// Platforms returns the value of the "platforms" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { + v := m.platforms + if v == nil { + return + } + return *v, true +} + +// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatforms requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) + } + return oldValue.Platforms, nil +} + +// AppendPlatforms adds s to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { + m.appendplatforms = append(m.appendplatforms, s...) +} + +// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { + if len(m.appendplatforms) == 0 { + return nil, false + } + return m.appendplatforms, true +} + +// ClearPlatforms clears the value of the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { + m.platforms = nil + m.appendplatforms = nil + m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +} + +// PlatformsCleared returns if the "platforms" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] + return ok +} + +// ResetPlatforms resets all changes to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { + m.platforms = nil + m.appendplatforms = nil + delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { + m.passthrough_code = &b +} + +// PassthroughCode returns the value of the "passthrough_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { + v := m.passthrough_code + if v == nil { + return + } + return *v, true +} + +// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassthroughCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) + } + return oldValue.PassthroughCode, nil +} + +// ResetPassthroughCode resets all changes to the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { + m.passthrough_code = nil +} + +// SetResponseCode sets the "response_code" field. +func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { + m.response_code = &i + m.addresponse_code = nil +} + +// ResponseCode returns the value of the "response_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { + v := m.response_code + if v == nil { + return + } + return *v, true +} + +// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) + } + return oldValue.ResponseCode, nil +} + +// AddResponseCode adds i to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { + if m.addresponse_code != nil { + *m.addresponse_code += i + } else { + m.addresponse_code = &i + } +} + +// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { + v := m.addresponse_code + if v == nil { + return + } + return *v, true +} + +// ClearResponseCode clears the value of the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { + m.response_code = nil + m.addresponse_code = nil + m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +} + +// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] + return ok +} + +// ResetResponseCode resets all changes to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { + m.response_code = nil + m.addresponse_code = nil + delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { + m.passthrough_body = &b +} + +// PassthroughBody returns the value of the "passthrough_body" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { + v := m.passthrough_body + if v == nil { + return + } + return *v, true +} + +// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassthroughBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) + } + return oldValue.PassthroughBody, nil +} + +// ResetPassthroughBody resets all changes to the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { + m.passthrough_body = nil +} + +// SetCustomMessage sets the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { + m.custom_message = &s +} + +// CustomMessage returns the value of the "custom_message" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { + v := m.custom_message + if v == nil { + return + } + return *v, true +} + +// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCustomMessage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) + } + return oldValue.CustomMessage, nil +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { + m.custom_message = nil + m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} +} + +// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] + return ok +} + +// ResetCustomMessage resets all changes to the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { + m.custom_message = nil + delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) +} + +// SetDescription sets the "description" field. +func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { + m.description = &s +} + +// Description returns the value of the "description" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true +} + +// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. +func (m *ErrorPassthroughRuleMutation) ClearDescription() { + m.description = nil + m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *ErrorPassthroughRuleMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, errorpassthroughrule.FieldDescription) +} + +// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. +func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ErrorPassthroughRule, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *ErrorPassthroughRuleMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (ErrorPassthroughRule). +func (m *ErrorPassthroughRuleMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *ErrorPassthroughRuleMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.created_at != nil { + fields = append(fields, errorpassthroughrule.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, errorpassthroughrule.FieldUpdatedAt) + } + if m.name != nil { + fields = append(fields, errorpassthroughrule.FieldName) + } + if m.enabled != nil { + fields = append(fields, errorpassthroughrule.FieldEnabled) + } + if m.priority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.error_codes != nil { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.keywords != nil { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.match_mode != nil { + fields = append(fields, errorpassthroughrule.FieldMatchMode) + } + if m.platforms != nil { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.passthrough_code != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + } + if m.response_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.passthrough_body != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + } + if m.custom_message != nil { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.description != nil { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.CreatedAt() + case errorpassthroughrule.FieldUpdatedAt: + return m.UpdatedAt() + case errorpassthroughrule.FieldName: + return m.Name() + case errorpassthroughrule.FieldEnabled: + return m.Enabled() + case errorpassthroughrule.FieldPriority: + return m.Priority() + case errorpassthroughrule.FieldErrorCodes: + return m.ErrorCodes() + case errorpassthroughrule.FieldKeywords: + return m.Keywords() + case errorpassthroughrule.FieldMatchMode: + return m.MatchMode() + case errorpassthroughrule.FieldPlatforms: + return m.Platforms() + case errorpassthroughrule.FieldPassthroughCode: + return m.PassthroughCode() + case errorpassthroughrule.FieldResponseCode: + return m.ResponseCode() + case errorpassthroughrule.FieldPassthroughBody: + return m.PassthroughBody() + case errorpassthroughrule.FieldCustomMessage: + return m.CustomMessage() + case errorpassthroughrule.FieldDescription: + return m.Description() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case errorpassthroughrule.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case errorpassthroughrule.FieldName: + return m.OldName(ctx) + case errorpassthroughrule.FieldEnabled: + return m.OldEnabled(ctx) + case errorpassthroughrule.FieldPriority: + return m.OldPriority(ctx) + case errorpassthroughrule.FieldErrorCodes: + return m.OldErrorCodes(ctx) + case errorpassthroughrule.FieldKeywords: + return m.OldKeywords(ctx) + case errorpassthroughrule.FieldMatchMode: + return m.OldMatchMode(ctx) + case errorpassthroughrule.FieldPlatforms: + return m.OldPlatforms(ctx) + case errorpassthroughrule.FieldPassthroughCode: + return m.OldPassthroughCode(ctx) + case errorpassthroughrule.FieldResponseCode: + return m.OldResponseCode(ctx) + case errorpassthroughrule.FieldPassthroughBody: + return m.OldPassthroughBody(ctx) + case errorpassthroughrule.FieldCustomMessage: + return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldDescription: + return m.OldDescription(ctx) + } + return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case errorpassthroughrule.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case errorpassthroughrule.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case errorpassthroughrule.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) + return nil + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPriority(v) + return nil + case errorpassthroughrule.FieldErrorCodes: + v, ok := value.([]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorCodes(v) + return nil + case errorpassthroughrule.FieldKeywords: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeywords(v) + return nil + case errorpassthroughrule.FieldMatchMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMatchMode(v) + return nil + case errorpassthroughrule.FieldPlatforms: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatforms(v) + return nil + case errorpassthroughrule.FieldPassthroughCode: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughCode(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseCode(v) + return nil + case errorpassthroughrule.FieldPassthroughBody: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughBody(v) + return nil + case errorpassthroughrule.FieldCustomMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomMessage(v) + return nil + case errorpassthroughrule.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ErrorPassthroughRuleMutation) AddedFields() []string { + var fields []string + if m.addpriority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.addresponse_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldPriority: + return m.AddedPriority() + case errorpassthroughrule.FieldResponseCode: + return m.AddedResponseCode() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPriority(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseCode(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.FieldCleared(errorpassthroughrule.FieldKeywords) { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.FieldCleared(errorpassthroughrule.FieldDescription) { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { + switch name { + case errorpassthroughrule.FieldErrorCodes: + m.ClearErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ClearKeywords() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ClearPlatforms() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ClearResponseCode() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ClearCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ClearDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case errorpassthroughrule.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case errorpassthroughrule.FieldName: + m.ResetName() + return nil + case errorpassthroughrule.FieldEnabled: + m.ResetEnabled() + return nil + case errorpassthroughrule.FieldPriority: + m.ResetPriority() + return nil + case errorpassthroughrule.FieldErrorCodes: + m.ResetErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ResetKeywords() + return nil + case errorpassthroughrule.FieldMatchMode: + m.ResetMatchMode() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ResetPlatforms() + return nil + case errorpassthroughrule.FieldPassthroughCode: + m.ResetPassthroughCode() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ResetResponseCode() + return nil + case errorpassthroughrule.FieldPassthroughBody: + m.ResetPassthroughBody() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ResetCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ResetDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) +} + // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 613c5913..c12955ef 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,9 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. +type ErrorPassthroughRule func(*sql.Selector) + // Group is the predicate function for group builders. type Group func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index f1fea8cc..4b3c1a4f 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -270,6 +271,61 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() + errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() + _ = errorpassthroughruleMixinFields0 + errorpassthroughruleFields := schema.ErrorPassthroughRule{}.Fields() + _ = errorpassthroughruleFields + // errorpassthroughruleDescCreatedAt is the schema descriptor for created_at field. + errorpassthroughruleDescCreatedAt := errorpassthroughruleMixinFields0[0].Descriptor() + // errorpassthroughrule.DefaultCreatedAt holds the default value on creation for the created_at field. + errorpassthroughrule.DefaultCreatedAt = errorpassthroughruleDescCreatedAt.Default.(func() time.Time) + // errorpassthroughruleDescUpdatedAt is the schema descriptor for updated_at field. + errorpassthroughruleDescUpdatedAt := errorpassthroughruleMixinFields0[1].Descriptor() + // errorpassthroughrule.DefaultUpdatedAt holds the default value on creation for the updated_at field. + errorpassthroughrule.DefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.Default.(func() time.Time) + // errorpassthroughrule.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + errorpassthroughrule.UpdateDefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.UpdateDefault.(func() time.Time) + // errorpassthroughruleDescName is the schema descriptor for name field. + errorpassthroughruleDescName := errorpassthroughruleFields[0].Descriptor() + // errorpassthroughrule.NameValidator is a validator for the "name" field. It is called by the builders before save. + errorpassthroughrule.NameValidator = func() func(string) error { + validators := errorpassthroughruleDescName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(name string) error { + for _, fn := range fns { + if err := fn(name); err != nil { + return err + } + } + return nil + } + }() + // errorpassthroughruleDescEnabled is the schema descriptor for enabled field. + errorpassthroughruleDescEnabled := errorpassthroughruleFields[1].Descriptor() + // errorpassthroughrule.DefaultEnabled holds the default value on creation for the enabled field. + errorpassthroughrule.DefaultEnabled = errorpassthroughruleDescEnabled.Default.(bool) + // errorpassthroughruleDescPriority is the schema descriptor for priority field. + errorpassthroughruleDescPriority := errorpassthroughruleFields[2].Descriptor() + // errorpassthroughrule.DefaultPriority holds the default value on creation for the priority field. + errorpassthroughrule.DefaultPriority = errorpassthroughruleDescPriority.Default.(int) + // errorpassthroughruleDescMatchMode is the schema descriptor for match_mode field. + errorpassthroughruleDescMatchMode := errorpassthroughruleFields[5].Descriptor() + // errorpassthroughrule.DefaultMatchMode holds the default value on creation for the match_mode field. + errorpassthroughrule.DefaultMatchMode = errorpassthroughruleDescMatchMode.Default.(string) + // errorpassthroughrule.MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save. + errorpassthroughrule.MatchModeValidator = errorpassthroughruleDescMatchMode.Validators[0].(func(string) error) + // errorpassthroughruleDescPassthroughCode is the schema descriptor for passthrough_code field. + errorpassthroughruleDescPassthroughCode := errorpassthroughruleFields[7].Descriptor() + // errorpassthroughrule.DefaultPassthroughCode holds the default value on creation for the passthrough_code field. + errorpassthroughrule.DefaultPassthroughCode = errorpassthroughruleDescPassthroughCode.Default.(bool) + // errorpassthroughruleDescPassthroughBody is the schema descriptor for passthrough_body field. + errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor() + // errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field. + errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool) groupMixin := schema.Group{}.Mixin() groupMixinHooks1 := groupMixin[1].Hooks() group.Hooks[0] = groupMixinHooks1[0] diff --git a/backend/ent/schema/error_passthrough_rule.go b/backend/ent/schema/error_passthrough_rule.go new file mode 100644 index 00000000..4a861f38 --- /dev/null +++ b/backend/ent/schema/error_passthrough_rule.go @@ -0,0 +1,121 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ErrorPassthroughRule 定义全局错误透传规则的 schema。 +// +// 错误透传规则用于控制上游错误如何返回给客户端: +// - 匹配条件:错误码 + 关键词组合 +// - 响应行为:透传原始信息 或 自定义错误信息 +// - 响应状态码:可指定返回给客户端的状态码 +// - 平台范围:规则适用的平台(Anthropic、OpenAI、Gemini、Antigravity) +type ErrorPassthroughRule struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (ErrorPassthroughRule) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "error_passthrough_rules"}, + } +} + +// Mixin 返回该 schema 使用的混入组件。 +func (ErrorPassthroughRule) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +// Fields 定义错误透传规则实体的所有字段。 +func (ErrorPassthroughRule) Fields() []ent.Field { + return []ent.Field{ + // name: 规则名称,用于在界面中标识规则 + field.String("name"). + MaxLen(100). + NotEmpty(), + + // enabled: 是否启用该规则 + field.Bool("enabled"). + Default(true), + + // priority: 规则优先级,数值越小优先级越高 + // 匹配时按优先级顺序检查,命中第一个匹配的规则 + field.Int("priority"). + Default(0), + + // error_codes: 匹配的错误码列表(OR关系) + // 例如:[422, 400] 表示匹配 422 或 400 错误码 + field.JSON("error_codes", []int{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // keywords: 匹配的关键词列表(OR关系) + // 例如:["context limit", "model not supported"] + // 关键词匹配不区分大小写 + field.JSON("keywords", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // match_mode: 匹配模式 + // - "any": 错误码匹配 OR 关键词匹配(任一条件满足即可) + // - "all": 错误码匹配 AND 关键词匹配(所有条件都必须满足) + field.String("match_mode"). + MaxLen(10). + Default("any"), + + // platforms: 适用平台列表 + // 例如:["anthropic", "openai", "gemini", "antigravity"] + // 空列表表示适用于所有平台 + field.JSON("platforms", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // passthrough_code: 是否透传上游原始状态码 + // true: 使用上游返回的状态码 + // false: 使用 response_code 指定的状态码 + field.Bool("passthrough_code"). + Default(true), + + // response_code: 自定义响应状态码 + // 当 passthrough_code=false 时使用此状态码 + field.Int("response_code"). + Optional(). + Nillable(), + + // passthrough_body: 是否透传上游原始错误信息 + // true: 使用上游返回的错误信息 + // false: 使用 custom_message 指定的错误信息 + field.Bool("passthrough_body"). + Default(true), + + // custom_message: 自定义错误信息 + // 当 passthrough_body=false 时使用此错误信息 + field.Text("custom_message"). + Optional(). + Nillable(), + + // description: 规则描述,用于说明规则的用途 + field.Text("description"). + Optional(). + Nillable(), + } +} + +// Indexes 定义数据库索引,优化查询性能。 +func (ErrorPassthroughRule) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("enabled"), // 筛选启用的规则 + index.Fields("priority"), // 按优先级排序 + } +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 702bdf90..45d83428 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,6 +24,8 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. + ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // PromoCode is the client for interacting with the PromoCode builders. @@ -186,6 +188,7 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) diff --git a/backend/internal/handler/admin/error_passthrough_handler.go b/backend/internal/handler/admin/error_passthrough_handler.go new file mode 100644 index 00000000..c32db561 --- /dev/null +++ b/backend/internal/handler/admin/error_passthrough_handler.go @@ -0,0 +1,273 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求 +type ErrorPassthroughHandler struct { + service *service.ErrorPassthroughService +} + +// NewErrorPassthroughHandler 创建错误透传规则处理器 +func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler { + return &ErrorPassthroughHandler{service: service} +} + +// CreateErrorPassthroughRuleRequest 创建规则请求 +type CreateErrorPassthroughRuleRequest struct { + Name string `json:"name" binding:"required"` + Enabled *bool `json:"enabled"` + Priority int `json:"priority"` + ErrorCodes []int `json:"error_codes"` + Keywords []string `json:"keywords"` + MatchMode string `json:"match_mode"` + Platforms []string `json:"platforms"` + PassthroughCode *bool `json:"passthrough_code"` + ResponseCode *int `json:"response_code"` + PassthroughBody *bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` + Description *string `json:"description"` +} + +// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选) +type UpdateErrorPassthroughRuleRequest struct { + Name *string `json:"name"` + Enabled *bool `json:"enabled"` + Priority *int `json:"priority"` + ErrorCodes []int `json:"error_codes"` + Keywords []string `json:"keywords"` + MatchMode *string `json:"match_mode"` + Platforms []string `json:"platforms"` + PassthroughCode *bool `json:"passthrough_code"` + ResponseCode *int `json:"response_code"` + PassthroughBody *bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` + Description *string `json:"description"` +} + +// List 获取所有规则 +// GET /api/v1/admin/error-passthrough-rules +func (h *ErrorPassthroughHandler) List(c *gin.Context) { + rules, err := h.service.List(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, rules) +} + +// GetByID 根据 ID 获取规则 +// GET /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + rule, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if rule == nil { + response.NotFound(c, "Rule not found") + return + } + + response.Success(c, rule) +} + +// Create 创建规则 +// POST /api/v1/admin/error-passthrough-rules +func (h *ErrorPassthroughHandler) Create(c *gin.Context) { + var req CreateErrorPassthroughRuleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + rule := &model.ErrorPassthroughRule{ + Name: req.Name, + Priority: req.Priority, + ErrorCodes: req.ErrorCodes, + Keywords: req.Keywords, + Platforms: req.Platforms, + } + + // 设置默认值 + if req.Enabled != nil { + rule.Enabled = *req.Enabled + } else { + rule.Enabled = true + } + if req.MatchMode != "" { + rule.MatchMode = req.MatchMode + } else { + rule.MatchMode = model.MatchModeAny + } + if req.PassthroughCode != nil { + rule.PassthroughCode = *req.PassthroughCode + } else { + rule.PassthroughCode = true + } + if req.PassthroughBody != nil { + rule.PassthroughBody = *req.PassthroughBody + } else { + rule.PassthroughBody = true + } + rule.ResponseCode = req.ResponseCode + rule.CustomMessage = req.CustomMessage + rule.Description = req.Description + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + created, err := h.service.Create(c.Request.Context(), rule) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, created) +} + +// Update 更新规则(支持部分更新) +// PUT /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + var req UpdateErrorPassthroughRuleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + // 先获取现有规则 + existing, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if existing == nil { + response.NotFound(c, "Rule not found") + return + } + + // 部分更新:只更新请求中提供的字段 + rule := &model.ErrorPassthroughRule{ + ID: id, + Name: existing.Name, + Enabled: existing.Enabled, + Priority: existing.Priority, + ErrorCodes: existing.ErrorCodes, + Keywords: existing.Keywords, + MatchMode: existing.MatchMode, + Platforms: existing.Platforms, + PassthroughCode: existing.PassthroughCode, + ResponseCode: existing.ResponseCode, + PassthroughBody: existing.PassthroughBody, + CustomMessage: existing.CustomMessage, + Description: existing.Description, + } + + // 应用请求中提供的更新 + if req.Name != nil { + rule.Name = *req.Name + } + if req.Enabled != nil { + rule.Enabled = *req.Enabled + } + if req.Priority != nil { + rule.Priority = *req.Priority + } + if req.ErrorCodes != nil { + rule.ErrorCodes = req.ErrorCodes + } + if req.Keywords != nil { + rule.Keywords = req.Keywords + } + if req.MatchMode != nil { + rule.MatchMode = *req.MatchMode + } + if req.Platforms != nil { + rule.Platforms = req.Platforms + } + if req.PassthroughCode != nil { + rule.PassthroughCode = *req.PassthroughCode + } + if req.ResponseCode != nil { + rule.ResponseCode = req.ResponseCode + } + if req.PassthroughBody != nil { + rule.PassthroughBody = *req.PassthroughBody + } + if req.CustomMessage != nil { + rule.CustomMessage = req.CustomMessage + } + if req.Description != nil { + rule.Description = req.Description + } + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + updated, err := h.service.Update(c.Request.Context(), rule) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, updated) +} + +// Delete 删除规则 +// DELETE /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + if err := h.service.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Rule deleted successfully"}) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 9aa6b72c..beaddbca 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -33,6 +33,7 @@ type GatewayHandler struct { billingCacheService *service.BillingCacheService usageService *service.UsageService apiKeyService *service.APIKeyService + errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int @@ -48,6 +49,7 @@ func NewGatewayHandler( billingCacheService *service.BillingCacheService, usageService *service.UsageService, apiKeyService *service.APIKeyService, + errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) @@ -70,6 +72,7 @@ func NewGatewayHandler( billingCacheService: billingCacheService, usageService: usageService, apiKeyService: apiKeyService, + errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, @@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 @@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) return } switchCount++ @@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError retryWithFallback := false for { @@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) return } switchCount++ @@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { +func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) + return + } + } + + // 使用默认的错误映射 + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 +func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 787e3760..be634c0c 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 @@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - handleGeminiFailoverExhausted(c, lastFailoverStatus) + h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } account := selection.Account @@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - handleGeminiFailoverExhausted(c, lastFailoverStatus) + lastFailoverErr = failoverErr + h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr switchCount++ log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) continue @@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error return "", "", &pathParseError{"invalid model action path"} } -func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) { +func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) { + if failoverErr == nil { + googleError(c, http.StatusBadGateway, "Upstream request failed") + return + } + + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + googleError(c, respCode, msg) + return + } + } + + // 使用默认的错误映射 status, message := mapGeminiUpstreamError(statusCode) googleError(c, status, message) } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b8f7d417..b2b12c0d 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -24,6 +24,7 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler + ErrorPassthrough *admin.ErrorPassthroughHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index a84679ae..1dcb163b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -22,11 +22,12 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService - billingCacheService *service.BillingCacheService - apiKeyService *service.APIKeyService - concurrencyHelper *ConcurrencyHelper - maxAccountSwitches int + gatewayService *service.OpenAIGatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + errorPassthroughService *service.ErrorPassthroughService + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, apiKeyService *service.APIKeyService, + errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, - billingCacheService: billingCacheService, - apiKeyService: apiKeyService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - maxAccountSwitches: maxAccountSwitches, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + errorPassthroughService: errorPassthroughService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, } } @@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { // Select account supporting the requested model @@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, streamStarted) return } - lastFailoverStatus = failoverErr.StatusCode switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) continue @@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { +func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) + return + } + } + + // 使用默认的错误映射 + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 +func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 48a3794b..7b62149c 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -27,6 +27,7 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, + errorPassthroughHandler *admin.ErrorPassthroughHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -47,6 +48,7 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, + ErrorPassthrough: errorPassthroughHandler, } } @@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, + admin.NewErrorPassthroughHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go new file mode 100644 index 00000000..d4fc16e3 --- /dev/null +++ b/backend/internal/model/error_passthrough_rule.go @@ -0,0 +1,74 @@ +// Package model 定义服务层使用的数据模型。 +package model + +import "time" + +// ErrorPassthroughRule 全局错误透传规则 +// 用于控制上游错误如何返回给客户端 +type ErrorPassthroughRule struct { + ID int64 `json:"id"` + Name string `json:"name"` // 规则名称 + Enabled bool `json:"enabled"` // 是否启用 + Priority int `json:"priority"` // 优先级(数字越小优先级越高) + ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系) + Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系) + MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件) + Platforms []string `json:"platforms"` // 适用平台列表 + PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码 + ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) + PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 + CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) + Description *string `json:"description"` // 规则描述 + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// MatchModeAny 表示任一条件匹配即可 +const MatchModeAny = "any" + +// MatchModeAll 表示所有条件都必须匹配 +const MatchModeAll = "all" + +// 支持的平台常量 +const ( + PlatformAnthropic = "anthropic" + PlatformOpenAI = "openai" + PlatformGemini = "gemini" + PlatformAntigravity = "antigravity" +) + +// AllPlatforms 返回所有支持的平台列表 +func AllPlatforms() []string { + return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity} +} + +// Validate 验证规则配置的有效性 +func (r *ErrorPassthroughRule) Validate() error { + if r.Name == "" { + return &ValidationError{Field: "name", Message: "name is required"} + } + if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll { + return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"} + } + // 至少需要配置一个匹配条件(错误码或关键词) + if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 { + return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"} + } + if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) { + return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"} + } + if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") { + return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"} + } + return nil +} + +// ValidationError 表示验证错误 +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return e.Field + ": " + e.Message +} diff --git a/backend/internal/repository/error_passthrough_cache.go b/backend/internal/repository/error_passthrough_cache.go new file mode 100644 index 00000000..5584ffc8 --- /dev/null +++ b/backend/internal/repository/error_passthrough_cache.go @@ -0,0 +1,128 @@ +package repository + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + errorPassthroughCacheKey = "error_passthrough_rules" + errorPassthroughPubSubKey = "error_passthrough_rules_updated" + errorPassthroughCacheTTL = 24 * time.Hour +) + +type errorPassthroughCache struct { + rdb *redis.Client + localCache []*model.ErrorPassthroughRule + localMu sync.RWMutex +} + +// NewErrorPassthroughCache 创建错误透传规则缓存 +func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache { + return &errorPassthroughCache{ + rdb: rdb, + } +} + +// Get 从缓存获取规则列表 +func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { + // 先检查本地缓存 + c.localMu.RLock() + if c.localCache != nil { + rules := c.localCache + c.localMu.RUnlock() + return rules, true + } + c.localMu.RUnlock() + + // 从 Redis 获取 + data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes() + if err != nil { + if err != redis.Nil { + log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err) + } + return nil, false + } + + var rules []*model.ErrorPassthroughRule + if err := json.Unmarshal(data, &rules); err != nil { + log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err) + return nil, false + } + + // 更新本地缓存 + c.localMu.Lock() + c.localCache = rules + c.localMu.Unlock() + + return rules, true +} + +// Set 设置缓存 +func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { + data, err := json.Marshal(rules) + if err != nil { + return err + } + + if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil { + return err + } + + // 更新本地缓存 + c.localMu.Lock() + c.localCache = rules + c.localMu.Unlock() + + return nil +} + +// Invalidate 使缓存失效 +func (c *errorPassthroughCache) Invalidate(ctx context.Context) error { + // 清除本地缓存 + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + // 清除 Redis 缓存 + return c.rdb.Del(ctx, errorPassthroughCacheKey).Err() +} + +// NotifyUpdate 通知其他实例刷新缓存 +func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error { + return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err() +} + +// SubscribeUpdates 订阅缓存更新通知 +func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { + go func() { + sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey) + defer func() { _ = sub.Close() }() + + ch := sub.Channel() + for { + select { + case <-ctx.Done(): + return + case msg := <-ch: + if msg == nil { + return + } + // 清除本地缓存,下次访问时会从 Redis 或数据库重新加载 + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + // 调用处理函数 + handler() + } + } + }() +} diff --git a/backend/internal/repository/error_passthrough_repo.go b/backend/internal/repository/error_passthrough_repo.go new file mode 100644 index 00000000..a58ab60f --- /dev/null +++ b/backend/internal/repository/error_passthrough_repo.go @@ -0,0 +1,178 @@ +package repository + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type errorPassthroughRepository struct { + client *ent.Client +} + +// NewErrorPassthroughRepository 创建错误透传规则仓库 +func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository { + return &errorPassthroughRepository{client: client} +} + +// List 获取所有规则 +func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + rules, err := r.client.ErrorPassthroughRule.Query(). + Order(ent.Asc(errorpassthroughrule.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + + result := make([]*model.ErrorPassthroughRule, len(rules)) + for i, rule := range rules { + result[i] = r.toModel(rule) + } + return result, nil +} + +// GetByID 根据 ID 获取规则 +func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + rule, err := r.client.ErrorPassthroughRule.Get(ctx, id) + if err != nil { + if ent.IsNotFound(err) { + return nil, nil + } + return nil, err + } + return r.toModel(rule), nil +} + +// Create 创建规则 +func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + builder := r.client.ErrorPassthroughRule.Create(). + SetName(rule.Name). + SetEnabled(rule.Enabled). + SetPriority(rule.Priority). + SetMatchMode(rule.MatchMode). + SetPassthroughCode(rule.PassthroughCode). + SetPassthroughBody(rule.PassthroughBody) + + if len(rule.ErrorCodes) > 0 { + builder.SetErrorCodes(rule.ErrorCodes) + } + if len(rule.Keywords) > 0 { + builder.SetKeywords(rule.Keywords) + } + if len(rule.Platforms) > 0 { + builder.SetPlatforms(rule.Platforms) + } + if rule.ResponseCode != nil { + builder.SetResponseCode(*rule.ResponseCode) + } + if rule.CustomMessage != nil { + builder.SetCustomMessage(*rule.CustomMessage) + } + if rule.Description != nil { + builder.SetDescription(*rule.Description) + } + + created, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(created), nil +} + +// Update 更新规则 +func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID). + SetName(rule.Name). + SetEnabled(rule.Enabled). + SetPriority(rule.Priority). + SetMatchMode(rule.MatchMode). + SetPassthroughCode(rule.PassthroughCode). + SetPassthroughBody(rule.PassthroughBody) + + // 处理可选字段 + if len(rule.ErrorCodes) > 0 { + builder.SetErrorCodes(rule.ErrorCodes) + } else { + builder.ClearErrorCodes() + } + if len(rule.Keywords) > 0 { + builder.SetKeywords(rule.Keywords) + } else { + builder.ClearKeywords() + } + if len(rule.Platforms) > 0 { + builder.SetPlatforms(rule.Platforms) + } else { + builder.ClearPlatforms() + } + if rule.ResponseCode != nil { + builder.SetResponseCode(*rule.ResponseCode) + } else { + builder.ClearResponseCode() + } + if rule.CustomMessage != nil { + builder.SetCustomMessage(*rule.CustomMessage) + } else { + builder.ClearCustomMessage() + } + if rule.Description != nil { + builder.SetDescription(*rule.Description) + } else { + builder.ClearDescription() + } + + updated, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(updated), nil +} + +// Delete 删除规则 +func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error { + return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx) +} + +// toModel 将 Ent 实体转换为服务模型 +func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule { + rule := &model.ErrorPassthroughRule{ + ID: int64(e.ID), + Name: e.Name, + Enabled: e.Enabled, + Priority: e.Priority, + ErrorCodes: e.ErrorCodes, + Keywords: e.Keywords, + MatchMode: e.MatchMode, + Platforms: e.Platforms, + PassthroughCode: e.PassthroughCode, + PassthroughBody: e.PassthroughBody, + CreatedAt: e.CreatedAt, + UpdatedAt: e.UpdatedAt, + } + + if e.ResponseCode != nil { + rule.ResponseCode = e.ResponseCode + } + if e.CustomMessage != nil { + rule.CustomMessage = e.CustomMessage + } + if e.Description != nil { + rule.Description = e.Description + } + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + return rule +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 5437de35..3aed9d9c 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -67,6 +67,7 @@ var ProviderSet = wire.NewSet( NewUserAttributeDefinitionRepository, NewUserAttributeValueRepository, NewUserGroupRateRepository, + NewErrorPassthroughRepository, // Cache implementations NewGatewayCache, @@ -87,6 +88,7 @@ var ProviderSet = wire.NewSet( NewProxyLatencyCache, NewTotpCache, NewRefreshTokenCache, + NewErrorPassthroughCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index ca9d627e..a1c27b00 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -67,6 +67,9 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) + + // 错误透传规则管理 + registerErrorPassthroughRoutes(admin, h) } } @@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } + +func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + rules := admin.Group("/error-passthrough-rules") + { + rules.GET("", h.Admin.ErrorPassthrough.List) + rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID) + rules.POST("", h.Admin.ErrorPassthrough.Create) + rules.PUT("/:id", h.Admin.ErrorPassthrough.Update) + rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index cf7e35fc..4ca32829 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) @@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 处理错误响应 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + contentType := resp.Header.Get("Content-Type") // 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。 _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) @@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps} } - - contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/json" } diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go new file mode 100644 index 00000000..99dc70e3 --- /dev/null +++ b/backend/internal/service/error_passthrough_service.go @@ -0,0 +1,300 @@ +package service + +import ( + "context" + "log" + "sort" + "strings" + "sync" + + "github.com/Wei-Shaw/sub2api/internal/model" +) + +// ErrorPassthroughRepository 定义错误透传规则的数据访问接口 +type ErrorPassthroughRepository interface { + // List 获取所有规则 + List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) + // GetByID 根据 ID 获取规则 + GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) + // Create 创建规则 + Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) + // Update 更新规则 + Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) + // Delete 删除规则 + Delete(ctx context.Context, id int64) error +} + +// ErrorPassthroughCache 定义错误透传规则的缓存接口 +type ErrorPassthroughCache interface { + // Get 从缓存获取规则列表 + Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) + // Set 设置缓存 + Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error + // Invalidate 使缓存失效 + Invalidate(ctx context.Context) error + // NotifyUpdate 通知其他实例刷新缓存 + NotifyUpdate(ctx context.Context) error + // SubscribeUpdates 订阅缓存更新通知 + SubscribeUpdates(ctx context.Context, handler func()) +} + +// ErrorPassthroughService 错误透传规则服务 +type ErrorPassthroughService struct { + repo ErrorPassthroughRepository + cache ErrorPassthroughCache + + // 本地内存缓存,用于快速匹配 + localCache []*model.ErrorPassthroughRule + localCacheMu sync.RWMutex +} + +// NewErrorPassthroughService 创建错误透传规则服务 +func NewErrorPassthroughService( + repo ErrorPassthroughRepository, + cache ErrorPassthroughCache, +) *ErrorPassthroughService { + svc := &ErrorPassthroughService{ + repo: repo, + cache: cache, + } + + // 启动时加载规则到本地缓存 + ctx := context.Background() + if err := svc.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err) + } + + // 订阅缓存更新通知 + if cache != nil { + cache.SubscribeUpdates(ctx, func() { + if err := svc.refreshLocalCache(context.Background()); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) + } + }) + } + + return svc +} + +// List 获取所有规则 +func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + return s.repo.List(ctx) +} + +// GetByID 根据 ID 获取规则 +func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + return s.repo.GetByID(ctx, id) +} + +// Create 创建规则 +func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if err := rule.Validate(); err != nil { + return nil, err + } + + created, err := s.repo.Create(ctx, rule) + if err != nil { + return nil, err + } + + // 刷新缓存 + s.invalidateAndNotify(ctx) + + return created, nil +} + +// Update 更新规则 +func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if err := rule.Validate(); err != nil { + return nil, err + } + + updated, err := s.repo.Update(ctx, rule) + if err != nil { + return nil, err + } + + // 刷新缓存 + s.invalidateAndNotify(ctx) + + return updated, nil +} + +// Delete 删除规则 +func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error { + if err := s.repo.Delete(ctx, id); err != nil { + return err + } + + // 刷新缓存 + s.invalidateAndNotify(ctx) + + return nil +} + +// MatchRule 匹配透传规则 +// 返回第一个匹配的规则,如果没有匹配则返回 nil +func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule { + rules := s.getCachedRules() + if len(rules) == 0 { + return nil + } + + bodyStr := strings.ToLower(string(body)) + + for _, rule := range rules { + if !rule.Enabled { + continue + } + if !s.platformMatches(rule, platform) { + continue + } + if s.ruleMatches(rule, statusCode, bodyStr) { + return rule + } + } + + return nil +} + +// getCachedRules 获取缓存的规则列表(按优先级排序) +func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { + s.localCacheMu.RLock() + rules := s.localCache + s.localCacheMu.RUnlock() + + if rules != nil { + return rules + } + + // 如果本地缓存为空,尝试刷新 + ctx := context.Background() + if err := s.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err) + return nil + } + + s.localCacheMu.RLock() + defer s.localCacheMu.RUnlock() + return s.localCache +} + +// refreshLocalCache 刷新本地缓存 +func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error { + // 先尝试从 Redis 缓存获取 + if s.cache != nil { + if rules, ok := s.cache.Get(ctx); ok { + s.setLocalCache(rules) + return nil + } + } + + // 从数据库加载(repo.List 已按 priority 排序) + rules, err := s.repo.List(ctx) + if err != nil { + return err + } + + // 更新 Redis 缓存 + if s.cache != nil { + if err := s.cache.Set(ctx, rules); err != nil { + log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err) + } + } + + // 更新本地缓存(setLocalCache 内部会确保排序) + s.setLocalCache(rules) + + return nil +} + +// setLocalCache 设置本地缓存 +func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { + // 按优先级排序 + sorted := make([]*model.ErrorPassthroughRule, len(rules)) + copy(sorted, rules) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Priority < sorted[j].Priority + }) + + s.localCacheMu.Lock() + s.localCache = sorted + s.localCacheMu.Unlock() +} + +// invalidateAndNotify 使缓存失效并通知其他实例 +func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { + // 刷新本地缓存 + if err := s.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + } + + // 通知其他实例 + if s.cache != nil { + if err := s.cache.NotifyUpdate(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err) + } + } +} + +// platformMatches 检查平台是否匹配 +func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { + // 如果没有配置平台限制,则匹配所有平台 + if len(rule.Platforms) == 0 { + return true + } + + platform = strings.ToLower(platform) + for _, p := range rule.Platforms { + if strings.ToLower(p) == platform { + return true + } + } + + return false +} + +// ruleMatches 检查规则是否匹配 +func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { + hasErrorCodes := len(rule.ErrorCodes) > 0 + hasKeywords := len(rule.Keywords) > 0 + + // 如果没有配置任何条件,不匹配 + if !hasErrorCodes && !hasKeywords { + return false + } + + codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) + keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords) + + if rule.MatchMode == model.MatchModeAll { + // "all" 模式:所有配置的条件都必须满足 + return codeMatch && keywordMatch + } + + // "any" 模式:任一条件满足即可 + if hasErrorCodes && hasKeywords { + return codeMatch || keywordMatch + } + return codeMatch && keywordMatch +} + +// containsInt 检查切片是否包含指定整数 +func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { + for _, v := range slice { + if v == val { + return true + } + } + return false +} + +// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写) +func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool { + for _, kw := range keywords { + if strings.Contains(bodyLower, strings.ToLower(kw)) { + return true + } + } + return false +} diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go new file mode 100644 index 00000000..205b4ec4 --- /dev/null +++ b/backend/internal/service/error_passthrough_service_test.go @@ -0,0 +1,755 @@ +//go:build unit + +package service + +import ( + "context" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockErrorPassthroughRepo 用于测试的 mock repository +type mockErrorPassthroughRepo struct { + rules []*model.ErrorPassthroughRule +} + +func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + return m.rules, nil +} + +func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + for _, r := range m.rules { + if r.ID == id { + return r, nil + } + } + return nil, nil +} + +func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + rule.ID = int64(len(m.rules) + 1) + m.rules = append(m.rules, rule) + return rule, nil +} + +func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + for i, r := range m.rules { + if r.ID == rule.ID { + m.rules[i] = rule + return rule, nil + } + } + return rule, nil +} + +func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { + for i, r := range m.rules { + if r.ID == id { + m.rules = append(m.rules[:i], m.rules[i+1:]...) + return nil + } + } + return nil +} + +// newTestService 创建测试用的服务实例 +func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService { + repo := &mockErrorPassthroughRepo{rules: rules} + svc := &ErrorPassthroughService{ + repo: repo, + cache: nil, // 不使用缓存 + } + // 直接设置本地缓存,避免调用 refreshLocalCache + svc.setLocalCache(rules) + return svc +} + +// ============================================================================= +// 测试 ruleMatches 核心匹配逻辑 +// ============================================================================= + +func TestRuleMatches_NoConditions(t *testing.T) { + // 没有配置任何条件时,不应该匹配 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{}, + Keywords: []string{}, + MatchMode: model.MatchModeAny, + } + + assert.False(t, svc.ruleMatches(rule, 422, "some error message"), + "没有配置条件时不应该匹配") +} + +func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + }{ + {"状态码匹配 422", 422, "any message", true}, + {"状态码匹配 400", 400, "any message", true}, + {"状态码不匹配 500", 500, "any message", false}, + {"状态码不匹配 429", 429, "any message", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{}, + Keywords: []string{"context limit", "model not supported"}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + }{ + {"关键词匹配 context limit", 500, "error: context limit reached", true}, + {"关键词匹配 model not supported", 400, "the model not supported here", true}, + {"关键词不匹配", 422, "some other error", false}, + // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 + // 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches + {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 MatchRule 的行为:先转换为小写 + bodyLower := strings.ToLower(tt.body) + result := svc.ruleMatches(rule, tt.statusCode, bodyLower) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { + // any 模式:错误码 OR 关键词 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + reason string + }{ + { + name: "状态码和关键词都匹配", + statusCode: 422, + body: "context limit reached", + expected: true, + reason: "both match", + }, + { + name: "只有状态码匹配", + statusCode: 422, + body: "some other error", + expected: true, + reason: "code matches, keyword doesn't - OR mode should match", + }, + { + name: "只有关键词匹配", + statusCode: 500, + body: "context limit exceeded", + expected: true, + reason: "keyword matches, code doesn't - OR mode should match", + }, + { + name: "都不匹配", + statusCode: 500, + body: "some other error", + expected: false, + reason: "neither matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +func TestRuleMatches_BothConditions_AllMode(t *testing.T) { + // all 模式:错误码 AND 关键词 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAll, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + reason string + }{ + { + name: "状态码和关键词都匹配", + statusCode: 422, + body: "context limit reached", + expected: true, + reason: "both match - AND mode should match", + }, + { + name: "只有状态码匹配", + statusCode: 422, + body: "some other error", + expected: false, + reason: "code matches but keyword doesn't - AND mode should NOT match", + }, + { + name: "只有关键词匹配", + statusCode: 500, + body: "context limit exceeded", + expected: false, + reason: "keyword matches but code doesn't - AND mode should NOT match", + }, + { + name: "都不匹配", + statusCode: 500, + body: "some other error", + expected: false, + reason: "neither matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +// ============================================================================= +// 测试 platformMatches 平台匹配逻辑 +// ============================================================================= + +func TestPlatformMatches(t *testing.T) { + svc := newTestService(nil) + + tests := []struct { + name string + rulePlatforms []string + requestPlatform string + expected bool + }{ + { + name: "空平台列表匹配所有", + rulePlatforms: []string{}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "nil平台列表匹配所有", + rulePlatforms: nil, + requestPlatform: "openai", + expected: true, + }, + { + name: "精确匹配 anthropic", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "精确匹配 openai", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "openai", + expected: true, + }, + { + name: "不匹配 gemini", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "gemini", + expected: false, + }, + { + name: "大小写不敏感", + rulePlatforms: []string{"Anthropic", "OpenAI"}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "匹配 antigravity", + rulePlatforms: []string{"antigravity"}, + requestPlatform: "antigravity", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := &model.ErrorPassthroughRule{ + Platforms: tt.rulePlatforms, + } + result := svc.platformMatches(rule, tt.requestPlatform) + assert.Equal(t, tt.expected, result) + }) + } +} + +// ============================================================================= +// 测试 MatchRule 完整匹配流程 +// ============================================================================= + +func TestMatchRule_Priority(t *testing.T) { + // 测试规则按优先级排序,优先级小的先匹配 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Low Priority", + Enabled: true, + Priority: 10, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "High Priority", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则") + assert.Equal(t, "High Priority", matched.Name) +} + +func TestMatchRule_DisabledRule(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Disabled Rule", + Enabled: false, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "Enabled Rule", + Enabled: true, + Priority: 10, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则") +} + +func TestMatchRule_PlatformFilter(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Anthropic Only", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + Platforms: []string{"anthropic"}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "OpenAI Only", + Enabled: true, + Priority: 2, + ErrorCodes: []int{422}, + Platforms: []string{"openai"}, + MatchMode: model.MatchModeAny, + }, + { + ID: 3, + Name: "All Platforms", + Enabled: true, + Priority: 3, + ErrorCodes: []int{422}, + Platforms: []string{}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + + t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) { + matched := svc.MatchRule("anthropic", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(1), matched.ID) + }) + + t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) { + matched := svc.MatchRule("openai", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID) + }) + + t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) { + matched := svc.MatchRule("gemini", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(3), matched.ID) + }) + + t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) { + matched := svc.MatchRule("antigravity", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(3), matched.ID) + }) +} + +func TestMatchRule_NoMatch(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Rule for 422", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 500, []byte("error")) + + assert.Nil(t, matched, "不匹配任何规则时应返回 nil") +} + +func TestMatchRule_EmptyRules(t *testing.T) { + svc := newTestService([]*model.ErrorPassthroughRule{}) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + assert.Nil(t, matched, "没有规则时应返回 nil") +} + +func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Context Limit", + Enabled: true, + Priority: 1, + Keywords: []string{"Context Limit"}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + + tests := []struct { + name string + body string + expected bool + }{ + {"完全匹配", "Context Limit reached", true}, + {"小写匹配", "context limit reached", true}, + {"大写匹配", "CONTEXT LIMIT REACHED", true}, + {"混合大小写", "ConTeXt LiMiT error", true}, + {"不匹配", "some other error", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matched := svc.MatchRule("anthropic", 500, []byte(tt.body)) + if tt.expected { + assert.NotNil(t, matched) + } else { + assert.Nil(t, matched) + } + }) + } +} + +// ============================================================================= +// 测试真实场景 +// ============================================================================= + +func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { + // 场景:上游返回 422 + "context limit has been reached",需要透传给客户端 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Context Limit Passthrough", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAll, // 必须同时满足 + Platforms: []string{"anthropic", "antigravity"}, + PassthroughCode: true, + PassthroughBody: true, + }, + } + + svc := newTestService(rules) + + // 测试 Anthropic 平台 + t.Run("Anthropic 422 with context limit", func(t *testing.T) { + body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`) + matched := svc.MatchRule("anthropic", 422, body) + require.NotNil(t, matched) + assert.True(t, matched.PassthroughCode) + assert.True(t, matched.PassthroughBody) + }) + + // 测试 Antigravity 平台 + t.Run("Antigravity 422 with context limit", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("antigravity", 422, body) + require.NotNil(t, matched) + }) + + // 测试 OpenAI 平台(不在规则的平台列表中) + t.Run("OpenAI should not match", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("openai", 422, body) + assert.Nil(t, matched, "OpenAI 不在规则的平台列表中") + }) + + // 测试状态码不匹配 + t.Run("Wrong status code", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("anthropic", 400, body) + assert.Nil(t, matched, "状态码不匹配") + }) + + // 测试关键词不匹配 + t.Run("Wrong keyword", func(t *testing.T) { + body := []byte(`{"error":"rate limit exceeded"}`) + matched := svc.MatchRule("anthropic", 422, body) + assert.Nil(t, matched, "关键词不匹配") + }) +} + +func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) { + // 场景:某些错误需要返回自定义消息,隐藏上游详细信息 + customMsg := "Service temporarily unavailable, please try again later" + responseCode := 503 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Hide Internal Errors", + Enabled: true, + Priority: 1, + ErrorCodes: []int{500, 502, 503}, + MatchMode: model.MatchModeAny, + PassthroughCode: false, + ResponseCode: &responseCode, + PassthroughBody: false, + CustomMessage: &customMsg, + }, + } + + svc := newTestService(rules) + + matched := svc.MatchRule("anthropic", 500, []byte("internal server error")) + require.NotNil(t, matched) + assert.False(t, matched.PassthroughCode) + assert.Equal(t, 503, *matched.ResponseCode) + assert.False(t, matched.PassthroughBody) + assert.Equal(t, customMsg, *matched.CustomMessage) +} + +// ============================================================================= +// 测试 model.Validate +// ============================================================================= + +func TestErrorPassthroughRule_Validate(t *testing.T) { + tests := []struct { + name string + rule *model.ErrorPassthroughRule + expectError bool + errorField string + }{ + { + name: "有效规则 - 透传模式(含错误码)", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: false, + }, + { + name: "有效规则 - 透传模式(含关键词)", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAny, + Keywords: []string{"context limit"}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: false, + }, + { + name: "有效规则 - 自定义响应", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAll, + ErrorCodes: []int{500}, + Keywords: []string{"internal error"}, + PassthroughCode: false, + ResponseCode: testIntPtr(503), + PassthroughBody: false, + CustomMessage: testStrPtr("Custom error"), + }, + expectError: false, + }, + { + name: "缺少名称", + rule: &model.ErrorPassthroughRule{ + Name: "", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "name", + }, + { + name: "无效的匹配模式", + rule: &model.ErrorPassthroughRule{ + Name: "Invalid Mode", + MatchMode: "invalid", + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "match_mode", + }, + { + name: "缺少匹配条件(错误码和关键词都为空)", + rule: &model.ErrorPassthroughRule{ + Name: "No Conditions", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{}, + Keywords: []string{}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "conditions", + }, + { + name: "缺少匹配条件(nil切片)", + rule: &model.ErrorPassthroughRule{ + Name: "Nil Conditions", + MatchMode: model.MatchModeAny, + ErrorCodes: nil, + Keywords: nil, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "conditions", + }, + { + name: "自定义状态码但未提供值", + rule: &model.ErrorPassthroughRule{ + Name: "Missing Code", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: false, + ResponseCode: nil, + PassthroughBody: true, + }, + expectError: true, + errorField: "response_code", + }, + { + name: "自定义消息但未提供值", + rule: &model.ErrorPassthroughRule{ + Name: "Missing Message", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: false, + CustomMessage: nil, + }, + expectError: true, + errorField: "custom_message", + }, + { + name: "自定义消息为空字符串", + rule: &model.ErrorPassthroughRule{ + Name: "Empty Message", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: false, + CustomMessage: testStrPtr(""), + }, + expectError: true, + errorField: "custom_message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.rule.Validate() + if tt.expectError { + require.Error(t, err) + validationErr, ok := err.(*model.ValidationError) + require.True(t, ok, "应该返回 ValidationError") + assert.Equal(t, tt.errorField, validationErr.Field) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Helper functions +func testIntPtr(i int) *int { return &i } +func testStrPtr(s string) *string { return &s } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9036955a..9aecce22 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -370,7 +370,8 @@ type ForwardResult struct { // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { - StatusCode int + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 } func (e *UpstreamFailoverError) Error() string { @@ -3284,7 +3285,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -3314,10 +3315,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - - // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { // 可选:对部分 400 触发 failover(默认关闭以保持语义) if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { @@ -3361,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A log.Printf("Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } } return s.handleErrorResponse(ctx, resp, c, account) @@ -3758,6 +3757,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { return false } +// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 +// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} +func ExtractUpstreamErrorMessage(body []byte) string { + return extractUpstreamErrorMessage(body) +} + func extractUpstreamErrorMessage(body []byte) string { // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { @@ -3825,7 +3830,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index bd322991..eecb88f6 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) @@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) @@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} } respBody = unwrapIfNeeded(isOAuth, respBody) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 4658c694..564ffa4d 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }) s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleErrorResponse(ctx, resp, c, account) } @@ -1131,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht Detail: upstreamDetail, }) if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // Return appropriate error response diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 4b721bb6..05371022 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet( NewUserAttributeService, NewUsageCache, NewTotpService, + NewErrorPassthroughService, ) diff --git a/backend/migrations/048_add_error_passthrough_rules.sql b/backend/migrations/048_add_error_passthrough_rules.sql new file mode 100644 index 00000000..bf2a9117 --- /dev/null +++ b/backend/migrations/048_add_error_passthrough_rules.sql @@ -0,0 +1,24 @@ +-- Error Passthrough Rules table +-- Allows administrators to configure how upstream errors are passed through to clients + +CREATE TABLE IF NOT EXISTS error_passthrough_rules ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT true, + priority INTEGER NOT NULL DEFAULT 0, + error_codes JSONB DEFAULT '[]', + keywords JSONB DEFAULT '[]', + match_mode VARCHAR(10) NOT NULL DEFAULT 'any', + platforms JSONB DEFAULT '[]', + passthrough_code BOOLEAN NOT NULL DEFAULT true, + response_code INTEGER, + passthrough_body BOOLEAN NOT NULL DEFAULT true, + custom_message TEXT, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Indexes for efficient queries +CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_enabled ON error_passthrough_rules (enabled); +CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_priority ON error_passthrough_rules (priority); diff --git a/frontend/src/api/admin/errorPassthrough.ts b/frontend/src/api/admin/errorPassthrough.ts new file mode 100644 index 00000000..4c545ad5 --- /dev/null +++ b/frontend/src/api/admin/errorPassthrough.ts @@ -0,0 +1,134 @@ +/** + * Admin Error Passthrough Rules API endpoints + * Handles error passthrough rule management for administrators + */ + +import { apiClient } from '../client' + +/** + * Error passthrough rule interface + */ +export interface ErrorPassthroughRule { + id: number + name: string + enabled: boolean + priority: number + error_codes: number[] + keywords: string[] + match_mode: 'any' | 'all' + platforms: string[] + passthrough_code: boolean + response_code: number | null + passthrough_body: boolean + custom_message: string | null + description: string | null + created_at: string + updated_at: string +} + +/** + * Create rule request + */ +export interface CreateRuleRequest { + name: string + enabled?: boolean + priority?: number + error_codes?: number[] + keywords?: string[] + match_mode?: 'any' | 'all' + platforms?: string[] + passthrough_code?: boolean + response_code?: number | null + passthrough_body?: boolean + custom_message?: string | null + description?: string | null +} + +/** + * Update rule request + */ +export interface UpdateRuleRequest { + name?: string + enabled?: boolean + priority?: number + error_codes?: number[] + keywords?: string[] + match_mode?: 'any' | 'all' + platforms?: string[] + passthrough_code?: boolean + response_code?: number | null + passthrough_body?: boolean + custom_message?: string | null + description?: string | null +} + +/** + * List all error passthrough rules + * @returns List of all rules sorted by priority + */ +export async function list(): Promise { + const { data } = await apiClient.get('/admin/error-passthrough-rules') + return data +} + +/** + * Get rule by ID + * @param id - Rule ID + * @returns Rule details + */ +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/error-passthrough-rules/${id}`) + return data +} + +/** + * Create new rule + * @param ruleData - Rule data + * @returns Created rule + */ +export async function create(ruleData: CreateRuleRequest): Promise { + const { data } = await apiClient.post('/admin/error-passthrough-rules', ruleData) + return data +} + +/** + * Update rule + * @param id - Rule ID + * @param updates - Fields to update + * @returns Updated rule + */ +export async function update(id: number, updates: UpdateRuleRequest): Promise { + const { data } = await apiClient.put(`/admin/error-passthrough-rules/${id}`, updates) + return data +} + +/** + * Delete rule + * @param id - Rule ID + * @returns Success confirmation + */ +export async function deleteRule(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/error-passthrough-rules/${id}`) + return data +} + +/** + * Toggle rule enabled status + * @param id - Rule ID + * @param enabled - New enabled status + * @returns Updated rule + */ +export async function toggleEnabled(id: number, enabled: boolean): Promise { + return update(id, { enabled }) +} + +export const errorPassthroughAPI = { + list, + getById, + create, + update, + delete: deleteRule, + toggleEnabled +} + +export default errorPassthroughAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 9a8a4195..ffb9b179 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -19,6 +19,7 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' +import errorPassthroughAPI from './errorPassthrough' /** * Unified admin API object for convenient access @@ -39,7 +40,8 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI + ops: opsAPI, + errorPassthrough: errorPassthroughAPI } export { @@ -58,10 +60,12 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI + opsAPI, + errorPassthroughAPI } export default adminAPI // Re-export types used by components export type { BalanceHistoryItem } from './users' +export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' diff --git a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue new file mode 100644 index 00000000..b93319c5 --- /dev/null +++ b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue @@ -0,0 +1,623 @@ +