From 979114db45fe817a2741d63f971c2287533b0bc0 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 5 Feb 2026 13:57:02 +0800 Subject: [PATCH 01/16] =?UTF-8?q?fix(gemini):=20=E4=BF=AE=E5=A4=8D?= =?UTF-8?q?=E5=B7=B2=E6=B3=A8=E5=86=8C=E7=94=A8=E6=88=B7=20OAuth=20?= =?UTF-8?q?=E6=8E=88=E6=9D=83=E6=97=B6=E9=94=99=E8=AF=AF=E8=B0=83=E7=94=A8?= =?UTF-8?q?=20onboardUser=20=E7=9A=84=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 问题:Google One Ultra 等已注册用户在 OAuth 授权时,如果 LoadCodeAssist 返回了 currentTier/paidTier 但没有返回 cloudaicompanionProject,之前的 逻辑会继续调用 onboardUser,导致 INVALID_ARGUMENT 错误。 修复:对齐 Gemini CLI 的处理逻辑: - 当检测到用户已注册(有 currentTier/paidTier)时,不再调用 onboardUser - 先尝试从 Cloud Resource Manager 获取可用项目 - 如果仍无法获取,返回友好的错误提示,引导用户手动填写 Project ID 这个修复解决了 Google One 订阅用户无法正常授权的问题。 --- .../internal/service/gemini_oauth_service.go | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/backend/internal/service/gemini_oauth_service.go b/backend/internal/service/gemini_oauth_service.go index bc84baeb..fd2932e6 100644 --- a/backend/internal/service/gemini_oauth_service.go +++ b/backend/internal/service/gemini_oauth_service.go @@ -944,6 +944,32 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil } + // 关键逻辑:对齐 Gemini CLI 对“已注册用户”的处理方式。 + // 当 LoadCodeAssist 返回了 currentTier / paidTier(表示账号已注册)但没有返回 cloudaicompanionProject 时: + // - 不要再调用 onboardUser(通常不会再分配 project_id,且可能触发 INVALID_ARGUMENT) + // - 先尝试从 Cloud Resource Manager 获取可用项目;仍失败则提示用户手动填写 project_id + if loadResp != nil { + registeredTierID := strings.TrimSpace(loadResp.GetTier()) + if registeredTierID != "" { + // 已注册但未返回 cloudaicompanionProject,这在 Google One 用户中较常见:需要用户自行提供 project_id。 + log.Printf("[GeminiOAuth] User has tier (%s) but no cloudaicompanionProject, trying Cloud Resource Manager...", registeredTierID) + + // Try to get project from Cloud Resource Manager + fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL) + if fbErr == nil && strings.TrimSpace(fallback) != "" { + log.Printf("[GeminiOAuth] Found project from Cloud Resource Manager: %s", fallback) + return strings.TrimSpace(fallback), tierID, nil + } + + // No project found - user must provide project_id manually + log.Printf("[GeminiOAuth] No project found from Cloud Resource Manager, user must provide project_id manually") + return "", tierID, fmt.Errorf("user is registered (tier: %s) but no project_id available. Please provide Project ID manually in the authorization form, or create a project at https://console.cloud.google.com", registeredTierID) + } + } + + // 未检测到 currentTier/paidTier,视为新用户,继续调用 onboardUser + log.Printf("[GeminiOAuth] No currentTier/paidTier found, proceeding with onboardUser (tierID: %s)", tierID) + req := &geminicli.OnboardUserRequest{ TierID: tierID, Metadata: geminicli.LoadCodeAssistMetadata{ From 2b192f7dcab70187999c0e04743ff33e024e37f7 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 16:00:34 +0800 Subject: [PATCH 02/16] =?UTF-8?q?feat:=20=E6=94=AF=E6=8C=81=E7=94=A8?= =?UTF-8?q?=E6=88=B7=E4=B8=93=E5=B1=9E=E5=88=86=E7=BB=84=E5=80=8D=E7=8E=87?= =?UTF-8?q?=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/cmd/server/wire_gen.go | 7 +- backend/go.sum | 8 + .../internal/handler/admin/user_handler.go | 4 + backend/internal/handler/api_key_handler.go | 18 + backend/internal/handler/dto/mappers.go | 5 +- backend/internal/handler/dto/types.go | 3 + .../repository/user_group_rate_repo.go | 113 ++++++ backend/internal/repository/wire.go | 1 + backend/internal/server/api_contract_test.go | 4 +- .../middleware/api_key_auth_google_test.go | 2 + .../server/middleware/api_key_auth_test.go | 8 +- backend/internal/server/routes/user.go | 1 + backend/internal/service/admin_service.go | 41 ++- backend/internal/service/api_key_service.go | 46 ++- .../service/api_key_service_cache_test.go | 18 +- backend/internal/service/gateway_service.go | 21 +- backend/internal/service/user.go | 26 +- backend/internal/service/user_group_rate.go | 25 ++ .../047_add_user_group_rate_multipliers.sql | 19 + frontend/src/api/groups.ts | 12 +- .../admin/user/UserAllowedGroupsModal.vue | 337 ++++++++++++++++-- frontend/src/components/common/GroupBadge.vue | 29 +- .../src/components/common/GroupOptionItem.vue | 5 +- frontend/src/i18n/locales/en.ts | 10 + frontend/src/i18n/locales/zh.ts | 10 + frontend/src/types/index.ts | 5 + frontend/src/views/user/KeysView.vue | 16 + 27 files changed, 705 insertions(+), 89 deletions(-) create mode 100644 backend/internal/repository/user_group_rate_repo.go create mode 100644 backend/internal/service/user_group_rate.go create mode 100644 backend/migrations/047_add_user_group_rate_multipliers.sql diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 47b1e8ac..3ca86f91 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -59,8 +59,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, configConfig) apiKeyRepository := repository.NewAPIKeyRepository(client) groupRepository := repository.NewGroupRepository(client, db) + userGroupRateRepository := repository.NewUserGroupRateRepository(db) apiKeyCache := repository.NewAPIKeyCache(redisClient) - apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, apiKeyCache, configConfig) + apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) promoService := service.NewPromoService(promoCodeRepository, userRepository, billingCacheService, client, apiKeyAuthCacheInvalidator) authService := service.NewAuthService(userRepository, redeemCodeRepository, refreshTokenCache, configConfig, settingService, emailService, turnstileService, emailQueueService, promoService) @@ -100,7 +101,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyRepository := repository.NewProxyRepository(client, db) proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator) adminUserHandler := admin.NewUserHandler(adminService) groupHandler := admin.NewGroupHandler(adminService) claudeOAuthClient := repository.NewClaudeOAuthClient() @@ -153,7 +154,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { identityService := service.NewIdentityService(identityCache) deferredService := service.ProvideDeferredService(accountRepository, timingWheelService) claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService) - gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) + gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache) openAITokenProvider := service.NewOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService) openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, userRepository, userSubscriptionRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) diff --git a/backend/go.sum b/backend/go.sum index 171995c7..3000eb38 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -170,6 +170,8 @@ github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -203,6 +205,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -230,6 +234,8 @@ github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkr github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -252,6 +258,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index ac76689d..1c772e7d 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -45,6 +45,9 @@ type UpdateUserRequest struct { Concurrency *int `json:"concurrency"` Status string `json:"status" binding:"omitempty,oneof=active disabled"` AllowedGroups *[]int64 `json:"allowed_groups"` + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 `json:"group_rates"` } // UpdateBalanceRequest represents balance update request @@ -183,6 +186,7 @@ func (h *UserHandler) Update(c *gin.Context) { Concurrency: req.Concurrency, Status: req.Status, AllowedGroups: req.AllowedGroups, + GroupRates: req.GroupRates, }) if err != nil { response.ErrorFrom(c, err) diff --git a/backend/internal/handler/api_key_handler.go b/backend/internal/handler/api_key_handler.go index 9717194b..f1a18ad2 100644 --- a/backend/internal/handler/api_key_handler.go +++ b/backend/internal/handler/api_key_handler.go @@ -243,3 +243,21 @@ func (h *APIKeyHandler) GetAvailableGroups(c *gin.Context) { } response.Success(c, out) } + +// GetUserGroupRates 获取当前用户的专属分组倍率配置 +// GET /api/v1/groups/rates +func (h *APIKeyHandler) GetUserGroupRates(c *gin.Context) { + subject, ok := middleware2.GetAuthSubjectFromContext(c) + if !ok { + response.Unauthorized(c, "User not authenticated") + return + } + + rates, err := h.apiKeyService.GetUserGroupRates(c.Request.Context(), subject.UserID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, rates) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 4f8d1eeb..da0e9fc6 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -58,8 +58,9 @@ func UserFromServiceAdmin(u *service.User) *AdminUser { return nil } return &AdminUser{ - User: *base, - Notes: u.Notes, + User: *base, + Notes: u.Notes, + GroupRates: u.GroupRates, } } diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index 8e6faf02..71bb1ed4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -29,6 +29,9 @@ type AdminUser struct { User Notes string `json:"notes"` + // GroupRates 用户专属分组倍率配置 + // map[groupID]rateMultiplier + GroupRates map[int64]float64 `json:"group_rates,omitempty"` } type APIKey struct { diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go new file mode 100644 index 00000000..eb65403b --- /dev/null +++ b/backend/internal/repository/user_group_rate_repo.go @@ -0,0 +1,113 @@ +package repository + +import ( + "context" + "database/sql" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type userGroupRateRepository struct { + sql sqlExecutor +} + +// NewUserGroupRateRepository 创建用户专属分组倍率仓储 +func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { + return &userGroupRateRepository{sql: sqlDB} +} + +// GetByUserID 获取用户的所有专属分组倍率 +func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { + query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` + rows, err := r.sql.QueryContext(ctx, query, userID) + if err != nil { + return nil, err + } + defer func() { _ = rows.Close() }() + + result := make(map[int64]float64) + for rows.Next() { + var groupID int64 + var rate float64 + if err := rows.Scan(&groupID, &rate); err != nil { + return nil, err + } + result[groupID] = rate + } + if err := rows.Err(); err != nil { + return nil, err + } + return result, nil +} + +// GetByUserAndGroup 获取用户在特定分组的专属倍率 +func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { + query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` + var rate float64 + err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + return &rate, nil +} + +// SyncUserGroupRates 同步用户的分组专属倍率 +func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { + if len(rates) == 0 { + // 如果传入空 map,删除该用户的所有专属倍率 + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + return err + } + + // 分离需要删除和需要 upsert 的记录 + var toDelete []int64 + toUpsert := make(map[int64]float64) + for groupID, rate := range rates { + if rate == nil { + toDelete = append(toDelete, groupID) + } else { + toUpsert[groupID] = *rate + } + } + + // 删除指定的记录 + for _, groupID := range toDelete { + _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2`, + userID, groupID) + if err != nil { + return err + } + } + + // Upsert 记录 + now := time.Now() + for groupID, rate := range toUpsert { + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) + VALUES ($1, $2, $3, $4, $4) + ON CONFLICT (user_id, group_id) DO UPDATE SET rate_multiplier = $3, updated_at = $4 + `, userID, groupID, rate, now) + if err != nil { + return err + } + } + + return nil +} + +// DeleteByGroupID 删除指定分组的所有用户专属倍率 +func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) + return err +} + +// DeleteByUserID 删除指定用户的所有专属倍率 +func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { + _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + return err +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 857ce3e8..5437de35 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -66,6 +66,7 @@ var ProviderSet = wire.NewSet( NewUserSubscriptionRepository, NewUserAttributeDefinitionRepository, NewUserAttributeValueRepository, + NewUserGroupRateRepository, // Cache implementations NewGatewayCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index e197b776..f5f8cda7 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -593,7 +593,7 @@ func newContractDeps(t *testing.T) *contractDeps { } userService := service.NewUserService(userRepo, nil) - apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, apiKeyCache, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, userRepo, groupRepo, userSubRepo, nil, apiKeyCache, cfg) usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) @@ -607,7 +607,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/server/middleware/api_key_auth_google_test.go b/backend/internal/server/middleware/api_key_auth_google_test.go index c14582bd..38b93cb2 100644 --- a/backend/internal/server/middleware/api_key_auth_google_test.go +++ b/backend/internal/server/middleware/api_key_auth_google_test.go @@ -93,6 +93,7 @@ func newTestAPIKeyService(repo service.APIKeyRepository) *service.APIKeyService nil, // userRepo (unused in GetByKey) nil, // groupRepo nil, // userSubRepo + nil, // userGroupRateRepo nil, // cache &config.Config{}, ) @@ -187,6 +188,7 @@ func TestApiKeyAuthWithSubscriptionGoogleSetsGroupContext(t *testing.T) { nil, nil, nil, + nil, &config.Config{RunMode: config.RunModeSimple}, ) diff --git a/backend/internal/server/middleware/api_key_auth_test.go b/backend/internal/server/middleware/api_key_auth_test.go index a03f6168..9d514818 100644 --- a/backend/internal/server/middleware/api_key_auth_test.go +++ b/backend/internal/server/middleware/api_key_auth_test.go @@ -59,7 +59,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("simple_mode_bypasses_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) subscriptionService := service.NewSubscriptionService(nil, &stubUserSubscriptionRepo{}, nil) router := newAuthTestRouter(apiKeyService, subscriptionService, cfg) @@ -73,7 +73,7 @@ func TestSimpleModeBypassesQuotaCheck(t *testing.T) { t.Run("standard_mode_enforces_quota_check", func(t *testing.T) { cfg := &config.Config{RunMode: config.RunModeStandard} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) now := time.Now() sub := &service.UserSubscription{ @@ -150,7 +150,7 @@ func TestAPIKeyAuthSetsGroupContext(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) router.GET("/t", func(c *gin.Context) { @@ -208,7 +208,7 @@ func TestAPIKeyAuthOverwritesInvalidContextGroup(t *testing.T) { } cfg := &config.Config{RunMode: config.RunModeSimple} - apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, cfg) + apiKeyService := service.NewAPIKeyService(apiKeyRepo, nil, nil, nil, nil, nil, cfg) router := gin.New() router.Use(gin.HandlerFunc(NewAPIKeyAuthMiddleware(apiKeyService, nil, cfg))) diff --git a/backend/internal/server/routes/user.go b/backend/internal/server/routes/user.go index 5581e1e1..d0ed2489 100644 --- a/backend/internal/server/routes/user.go +++ b/backend/internal/server/routes/user.go @@ -49,6 +49,7 @@ func RegisterUserRoutes( groups := authenticated.Group("/groups") { groups.GET("/available", h.APIKey.GetAvailableGroups) + groups.GET("/rates", h.APIKey.GetUserGroupRates) } // 使用记录 diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index c512f235..f215f82e 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -93,6 +93,9 @@ type UpdateUserInput struct { Concurrency *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" + // GroupRates 用户专属分组倍率配置 + // map[groupID]*rate,nil 表示删除该分组的专属倍率 + GroupRates map[int64]*float64 } type CreateGroupInput struct { @@ -293,6 +296,7 @@ type adminServiceImpl struct { proxyRepo ProxyRepository apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository + userGroupRateRepo UserGroupRateRepository billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache @@ -307,6 +311,7 @@ func NewAdminService( proxyRepo ProxyRepository, apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, + userGroupRateRepo UserGroupRateRepository, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, @@ -319,6 +324,7 @@ func NewAdminService( proxyRepo: proxyRepo, apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, + userGroupRateRepo: userGroupRateRepo, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, @@ -333,11 +339,35 @@ func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, fi if err != nil { return nil, 0, err } + // 批量加载用户专属分组倍率 + if s.userGroupRateRepo != nil && len(users) > 0 { + for i := range users { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, users[i].ID) + if err != nil { + log.Printf("failed to load user group rates: user_id=%d err=%v", users[i].ID, err) + continue + } + users[i].GroupRates = rates + } + } return users, result.Total, nil } func (s *adminServiceImpl) GetUser(ctx context.Context, id int64) (*User, error) { - return s.userRepo.GetByID(ctx, id) + user, err := s.userRepo.GetByID(ctx, id) + if err != nil { + return nil, err + } + // 加载用户专属分组倍率 + if s.userGroupRateRepo != nil { + rates, err := s.userGroupRateRepo.GetByUserID(ctx, id) + if err != nil { + log.Printf("failed to load user group rates: user_id=%d err=%v", id, err) + } else { + user.GroupRates = rates + } + } + return user, nil } func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInput) (*User, error) { @@ -406,6 +436,14 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda if err := s.userRepo.Update(ctx, user); err != nil { return nil, err } + + // 同步用户专属分组倍率 + if input.GroupRates != nil && s.userGroupRateRepo != nil { + if err := s.userGroupRateRepo.SyncUserGroupRates(ctx, user.ID, input.GroupRates); err != nil { + log.Printf("failed to sync user group rates: user_id=%d err=%v", user.ID, err) + } + } + if s.authCacheInvalidator != nil { if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) @@ -941,6 +979,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error { if err != nil { return err } + // 注意:user_group_rate_multipliers 表通过外键 ON DELETE CASCADE 自动清理 // 事务成功后,异步失效受影响用户的订阅缓存 if len(affectedUserIDs) > 0 && s.billingCacheService != nil { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index b27682f3..cb1dd60a 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -115,15 +115,16 @@ type UpdateAPIKeyRequest struct { // APIKeyService API Key服务 type APIKeyService struct { - apiKeyRepo APIKeyRepository - userRepo UserRepository - groupRepo GroupRepository - userSubRepo UserSubscriptionRepository - cache APIKeyCache - cfg *config.Config - authCacheL1 *ristretto.Cache - authCfg apiKeyAuthCacheConfig - authGroup singleflight.Group + apiKeyRepo APIKeyRepository + userRepo UserRepository + groupRepo GroupRepository + userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository + cache APIKeyCache + cfg *config.Config + authCacheL1 *ristretto.Cache + authCfg apiKeyAuthCacheConfig + authGroup singleflight.Group } // NewAPIKeyService 创建API Key服务实例 @@ -132,16 +133,18 @@ func NewAPIKeyService( userRepo UserRepository, groupRepo GroupRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache APIKeyCache, cfg *config.Config, ) *APIKeyService { svc := &APIKeyService{ - apiKeyRepo: apiKeyRepo, - userRepo: userRepo, - groupRepo: groupRepo, - userSubRepo: userSubRepo, - cache: cache, - cfg: cfg, + apiKeyRepo: apiKeyRepo, + userRepo: userRepo, + groupRepo: groupRepo, + userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, + cache: cache, + cfg: cfg, } svc.initAuthCache(cfg) return svc @@ -627,6 +630,19 @@ func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword return keys, nil } +// GetUserGroupRates 获取用户的专属分组倍率配置 +// 返回 map[groupID]rateMultiplier +func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) { + if s.userGroupRateRepo == nil { + return nil, nil + } + rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID) + if err != nil { + return nil, fmt.Errorf("get user group rates: %w", err) + } + return rates, nil +} + // CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted) // Returns nil if valid, error if invalid func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error { diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 1099b1d2..14ecbf39 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -167,7 +167,7 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) groupID := int64(9) cacheEntry := &APIKeyAuthCacheEntry{ @@ -223,7 +223,7 @@ func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return &APIKeyAuthCacheEntry{NotFound: true}, nil } @@ -256,7 +256,7 @@ func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return nil, redis.Nil } @@ -293,7 +293,7 @@ func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) { L1TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) require.NotNil(t, svc.authCacheL1) _, err := svc.GetByKey(context.Background(), "k-l1") @@ -320,7 +320,7 @@ func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByUserID(context.Background(), 7) require.Len(t, cache.deleteAuthKeys, 2) @@ -338,7 +338,7 @@ func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) { L2TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByGroupID(context.Background(), 9) require.Len(t, cache.deleteAuthKeys, 2) @@ -356,7 +356,7 @@ func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) { L2TTLSeconds: 60, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) svc.InvalidateAuthCacheByKey(context.Background(), "k1") require.Len(t, cache.deleteAuthKeys, 1) @@ -375,7 +375,7 @@ func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) { NegativeTTLSeconds: 30, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) { return nil, redis.Nil } @@ -411,7 +411,7 @@ func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) { Singleflight: true, }, } - svc := NewAPIKeyService(repo, nil, nil, nil, cache, cfg) + svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg) start := make(chan struct{}) wg := sync.WaitGroup{} diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8c88c0a9..9036955a 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -384,6 +384,7 @@ type GatewayService struct { usageLogRepo UsageLogRepository userRepo UserRepository userSubRepo UserSubscriptionRepository + userGroupRateRepo UserGroupRateRepository cache GatewayCache cfg *config.Config schedulerSnapshot *SchedulerSnapshotService @@ -405,6 +406,7 @@ func NewGatewayService( usageLogRepo UsageLogRepository, userRepo UserRepository, userSubRepo UserSubscriptionRepository, + userGroupRateRepo UserGroupRateRepository, cache GatewayCache, cfg *config.Config, schedulerSnapshot *SchedulerSnapshotService, @@ -424,6 +426,7 @@ func NewGatewayService( usageLogRepo: usageLogRepo, userRepo: userRepo, userSubRepo: userSubRepo, + userGroupRateRepo: userGroupRateRepo, cache: cache, cfg: cfg, schedulerSnapshot: schedulerSnapshot, @@ -4609,10 +4612,17 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu account := input.Account subscription := input.Subscription - // 获取费率倍数 + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier + + // 检查用户专属倍率 + if s.userGroupRateRepo != nil { + if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { + multiplier = *userRate + } + } } var cost *CostBreakdown @@ -4773,10 +4783,17 @@ func (s *GatewayService) RecordUsageWithLongContext(ctx context.Context, input * account := input.Account subscription := input.Subscription - // 获取费率倍数 + // 获取费率倍数(优先级:用户专属 > 分组默认 > 系统默认) multiplier := s.cfg.Default.RateMultiplier if apiKey.GroupID != nil && apiKey.Group != nil { multiplier = apiKey.Group.RateMultiplier + + // 检查用户专属倍率 + if s.userGroupRateRepo != nil { + if userRate, err := s.userGroupRateRepo.GetByUserAndGroup(ctx, user.ID, *apiKey.GroupID); err == nil && userRate != nil { + multiplier = *userRate + } + } } var cost *CostBreakdown diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 0f589eb3..e56d83bf 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -21,6 +21,10 @@ type User struct { CreatedAt time.Time UpdatedAt time.Time + // GroupRates 用户专属分组倍率配置 + // map[groupID]rateMultiplier + GroupRates map[int64]float64 + // TOTP 双因素认证字段 TotpSecretEncrypted *string // AES-256-GCM 加密的 TOTP 密钥 TotpEnabled bool // 是否启用 TOTP @@ -40,18 +44,20 @@ func (u *User) IsActive() bool { // CanBindGroup checks whether a user can bind to a given group. // For standard groups: -// - If AllowedGroups is non-empty, only allow binding to IDs in that list. -// - If AllowedGroups is empty (nil or length 0), allow binding to any non-exclusive group. +// - Public groups (non-exclusive): all users can bind +// - Exclusive groups: only users with the group in AllowedGroups can bind func (u *User) CanBindGroup(groupID int64, isExclusive bool) bool { - if len(u.AllowedGroups) > 0 { - for _, id := range u.AllowedGroups { - if id == groupID { - return true - } - } - return false + // 公开分组(非专属):所有用户都可以绑定 + if !isExclusive { + return true } - return !isExclusive + // 专属分组:需要在 AllowedGroups 中 + for _, id := range u.AllowedGroups { + if id == groupID { + return true + } + } + return false } func (u *User) SetPassword(password string) error { diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go new file mode 100644 index 00000000..9eb5f067 --- /dev/null +++ b/backend/internal/service/user_group_rate.go @@ -0,0 +1,25 @@ +package service + +import "context" + +// UserGroupRateRepository 用户专属分组倍率仓储接口 +// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +type UserGroupRateRepository interface { + // GetByUserID 获取用户的所有专属分组倍率 + // 返回 map[groupID]rateMultiplier + GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) + + // GetByUserAndGroup 获取用户在特定分组的专属倍率 + // 如果未设置专属倍率,返回 nil + GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) + + // SyncUserGroupRates 同步用户的分组专属倍率 + // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 + SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error + + // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) + DeleteByGroupID(ctx context.Context, groupID int64) error + + // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用) + DeleteByUserID(ctx context.Context, userID int64) error +} diff --git a/backend/migrations/047_add_user_group_rate_multipliers.sql b/backend/migrations/047_add_user_group_rate_multipliers.sql new file mode 100644 index 00000000..a37d5bcd --- /dev/null +++ b/backend/migrations/047_add_user_group_rate_multipliers.sql @@ -0,0 +1,19 @@ +-- 用户专属分组倍率表 +-- 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +CREATE TABLE IF NOT EXISTS user_group_rate_multipliers ( + user_id BIGINT NOT NULL REFERENCES users(id) ON DELETE CASCADE, + group_id BIGINT NOT NULL REFERENCES groups(id) ON DELETE CASCADE, + rate_multiplier DECIMAL(10,4) NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + PRIMARY KEY (user_id, group_id) +); + +-- 按 group_id 查询索引(删除分组时清理关联记录) +CREATE INDEX IF NOT EXISTS idx_user_group_rate_multipliers_group_id + ON user_group_rate_multipliers(group_id); + +COMMENT ON TABLE user_group_rate_multipliers IS '用户专属分组倍率配置'; +COMMENT ON COLUMN user_group_rate_multipliers.user_id IS '用户ID'; +COMMENT ON COLUMN user_group_rate_multipliers.group_id IS '分组ID'; +COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率(覆盖分组默认倍率)'; diff --git a/frontend/src/api/groups.ts b/frontend/src/api/groups.ts index 0f366d51..0963a7a6 100644 --- a/frontend/src/api/groups.ts +++ b/frontend/src/api/groups.ts @@ -18,8 +18,18 @@ export async function getAvailable(): Promise { return data } +/** + * Get current user's custom group rate multipliers + * @returns Map of group_id to custom rate_multiplier + */ +export async function getUserGroupRates(): Promise> { + const { data } = await apiClient.get | null>('/groups/rates') + return data || {} +} + export const userGroupsAPI = { - getAvailable + getAvailable, + getUserGroupRates } export default userGroupsAPI diff --git a/frontend/src/components/admin/user/UserAllowedGroupsModal.vue b/frontend/src/components/admin/user/UserAllowedGroupsModal.vue index 825d2be5..bccc22c7 100644 --- a/frontend/src/components/admin/user/UserAllowedGroupsModal.vue +++ b/frontend/src/components/admin/user/UserAllowedGroupsModal.vue @@ -1,59 +1,328 @@ + + diff --git a/frontend/src/components/common/GroupBadge.vue b/frontend/src/components/common/GroupBadge.vue index 239d0452..83f4b8aa 100644 --- a/frontend/src/components/common/GroupBadge.vue +++ b/frontend/src/components/common/GroupBadge.vue @@ -11,7 +11,14 @@ {{ name }} - {{ labelText }} + + @@ -27,6 +34,7 @@ interface Props { platform?: GroupPlatform subscriptionType?: SubscriptionType rateMultiplier?: number + userRateMultiplier?: number | null // 用户专属倍率 showRate?: boolean daysRemaining?: number | null // 剩余天数(订阅类型时使用) } @@ -34,20 +42,31 @@ interface Props { const props = withDefaults(defineProps(), { subscriptionType: 'standard', showRate: true, - daysRemaining: null + daysRemaining: null, + userRateMultiplier: null }) const { t } = useI18n() const isSubscription = computed(() => props.subscriptionType === 'subscription') +// 是否有专属倍率(且与默认倍率不同) +const hasCustomRate = computed(() => { + return ( + props.userRateMultiplier !== null && + props.userRateMultiplier !== undefined && + props.rateMultiplier !== undefined && + props.userRateMultiplier !== props.rateMultiplier + ) +}) + // 是否显示右侧标签 const showLabel = computed(() => { if (!props.showRate) return false // 订阅类型:显示天数或"订阅" if (isSubscription.value) return true - // 标准类型:显示倍率 - return props.rateMultiplier !== undefined + // 标准类型:显示倍率(包括专属倍率) + return props.rateMultiplier !== undefined || hasCustomRate.value }) // Label text @@ -71,7 +90,7 @@ const labelClass = computed(() => { const base = 'px-1.5 py-0.5 rounded text-[10px] font-semibold' if (!isSubscription.value) { - // Standard: subtle background + // Standard: subtle background (不再为专属倍率使用不同的背景色) return `${base} bg-black/10 dark:bg-white/10` } diff --git a/frontend/src/components/common/GroupOptionItem.vue b/frontend/src/components/common/GroupOptionItem.vue index 3283c330..44750350 100644 --- a/frontend/src/components/common/GroupOptionItem.vue +++ b/frontend/src/components/common/GroupOptionItem.vue @@ -9,6 +9,7 @@ :platform="platform" :subscription-type="subscriptionType" :rate-multiplier="rateMultiplier" + :user-rate-multiplier="userRateMultiplier" /> (), { subscriptionType: 'standard', selected: false, - showCheckmark: true + showCheckmark: true, + userRateMultiplier: null }) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index fb255c1a..a4571b10 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -849,6 +849,16 @@ export default { allowedGroupsUpdated: 'Allowed groups updated successfully', failedToLoadGroups: 'Failed to load groups', failedToUpdateAllowedGroups: 'Failed to update allowed groups', + // User Group Configuration + groupConfig: 'User Group Configuration', + groupConfigHint: 'Configure custom rate multipliers for user {email} (overrides group defaults)', + exclusiveGroups: 'Exclusive Groups', + publicGroups: 'Public Groups (Default Available)', + defaultRate: 'Default Rate', + customRate: 'Custom Rate', + useDefaultRate: 'Use Default', + customRatePlaceholder: 'Leave empty for default', + groupConfigUpdated: 'Group configuration updated successfully', deposit: 'Deposit', withdraw: 'Withdraw', depositAmount: 'Deposit Amount', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index e964aae2..8c6b1d91 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -910,6 +910,16 @@ export default { allowedGroupsUpdated: '允许分组更新成功', failedToLoadGroups: '加载分组列表失败', failedToUpdateAllowedGroups: '更新允许分组失败', + // 用户分组配置 + groupConfig: '用户分组配置', + groupConfigHint: '为用户 {email} 配置专属分组倍率(覆盖分组默认倍率)', + exclusiveGroups: '专属分组', + publicGroups: '公开分组(默认可用)', + defaultRate: '默认倍率', + customRate: '专属倍率', + useDefaultRate: '使用默认', + customRatePlaceholder: '留空使用默认', + groupConfigUpdated: '分组配置更新成功', deposit: '充值', withdraw: '退款', depositAmount: '充值金额', diff --git a/frontend/src/types/index.ts b/frontend/src/types/index.ts index eb53de44..a87ae4ca 100644 --- a/frontend/src/types/index.ts +++ b/frontend/src/types/index.ts @@ -41,6 +41,8 @@ export interface User { export interface AdminUser extends User { // 管理员备注(普通用户接口不返回) notes: string + // 用户专属分组倍率配置 (group_id -> rate_multiplier) + group_rates?: Record } export interface LoginRequest { @@ -966,6 +968,9 @@ export interface UpdateUserRequest { concurrency?: number status?: 'active' | 'disabled' allowed_groups?: number[] | null + // 用户专属分组倍率配置 (group_id -> rate_multiplier | null) + // null 表示删除该分组的专属倍率 + group_rates?: Record } export interface ChangePasswordRequest { diff --git a/frontend/src/views/user/KeysView.vue b/frontend/src/views/user/KeysView.vue index 51b015fa..80a64f2e 100644 --- a/frontend/src/views/user/KeysView.vue +++ b/frontend/src/views/user/KeysView.vue @@ -73,6 +73,7 @@ :platform="row.group.platform" :subscription-type="row.group.subscription_type" :rate-multiplier="row.group.rate_multiplier" + :user-rate-multiplier="userGroupRates[row.group.id]" /> {{ t('keys.noGroup') @@ -272,6 +273,7 @@ :platform="(option as unknown as GroupOption).platform" :subscription-type="(option as unknown as GroupOption).subscriptionType" :rate-multiplier="(option as unknown as GroupOption).rate" + :user-rate-multiplier="(option as unknown as GroupOption).userRate" /> {{ t('keys.selectGroup') }} @@ -281,6 +283,7 @@ :platform="(option as unknown as GroupOption).platform" :subscription-type="(option as unknown as GroupOption).subscriptionType" :rate-multiplier="(option as unknown as GroupOption).rate" + :user-rate-multiplier="(option as unknown as GroupOption).userRate" :description="(option as unknown as GroupOption).description" :selected="selected" /> @@ -667,6 +670,7 @@ :platform="option.platform" :subscription-type="option.subscriptionType" :rate-multiplier="option.rate" + :user-rate-multiplier="option.userRate" :description="option.description" :selected=" selectedKeyForGroup?.group_id === option.value || @@ -718,6 +722,7 @@ interface GroupOption { label: string description: string | null rate: number + userRate: number | null subscriptionType: SubscriptionType platform: GroupPlatform } @@ -742,6 +747,7 @@ const groups = ref([]) const loading = ref(false) const submitting = ref(false) const usageStats = ref>({}) +const userGroupRates = ref>({}) const pagination = ref({ page: 1, @@ -825,6 +831,7 @@ const groupOptions = computed(() => label: group.name, description: group.description, rate: group.rate_multiplier, + userRate: userGroupRates.value[group.id] ?? null, subscriptionType: group.subscription_type, platform: group.platform })) @@ -899,6 +906,14 @@ const loadGroups = async () => { } } +const loadUserGroupRates = async () => { + try { + userGroupRates.value = await userGroupsAPI.getUserGroupRates() + } catch (error) { + console.error('Failed to load user group rates:', error) + } +} + const loadPublicSettings = async () => { try { publicSettings.value = await authAPI.getPublicSettings() @@ -1268,6 +1283,7 @@ const closeCcsClientSelect = () => { onMounted(() => { loadApiKeys() loadGroups() + loadUserGroupRates() loadPublicSettings() document.addEventListener('click', closeGroupSelector) }) From 1d8b686446cc5374ddb1b192116a9afd8395b197 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 16:17:11 +0800 Subject: [PATCH 03/16] =?UTF-8?q?chore:=20=E7=A7=BB=E9=99=A4=E6=97=A0?= =?UTF-8?q?=E5=85=B3=E7=9A=84md=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- PR_DESCRIPTION.md | 164 ---------------------------------------------- 1 file changed, 164 deletions(-) delete mode 100644 PR_DESCRIPTION.md diff --git a/PR_DESCRIPTION.md b/PR_DESCRIPTION.md deleted file mode 100644 index b240f45c..00000000 --- a/PR_DESCRIPTION.md +++ /dev/null @@ -1,164 +0,0 @@ -## 概述 - -全面增强运维监控系统(Ops)的错误日志管理和告警静默功能,优化前端 UI 组件代码质量和用户体验。本次更新重构了核心服务层和数据访问层,提升系统可维护性和运维效率。 - -## 主要改动 - -### 1. 错误日志查询优化 - -**功能特性:** -- 新增 GetErrorLogByID 接口,支持按 ID 精确查询错误详情 -- 优化错误日志过滤逻辑,支持多维度筛选(平台、阶段、来源、所有者等) -- 改进查询参数处理,简化代码结构 -- 增强错误分类和标准化处理 -- 支持错误解决状态追踪(resolved 字段) - -**技术实现:** -- `ops_handler.go` - 新增单条错误日志查询接口 -- `ops_repo.go` - 优化数据查询和过滤条件构建 -- `ops_models.go` - 扩展错误日志数据模型 -- 前端 API 接口同步更新 - -### 2. 告警静默功能 - -**功能特性:** -- 支持按规则、平台、分组、区域等维度静默告警 -- 可设置静默时长和原因说明 -- 静默记录可追溯,记录创建人和创建时间 -- 自动过期机制,避免永久静默 - -**技术实现:** -- `037_ops_alert_silences.sql` - 新增告警静默表 -- `ops_alerts.go` - 告警静默逻辑实现 -- `ops_alerts_handler.go` - 告警静默 API 接口 -- `OpsAlertEventsCard.vue` - 前端告警静默操作界面 - -**数据库结构:** - -| 字段 | 类型 | 说明 | -|------|------|------| -| rule_id | BIGINT | 告警规则 ID | -| platform | VARCHAR(64) | 平台标识 | -| group_id | BIGINT | 分组 ID(可选) | -| region | VARCHAR(64) | 区域(可选) | -| until | TIMESTAMPTZ | 静默截止时间 | -| reason | TEXT | 静默原因 | -| created_by | BIGINT | 创建人 ID | - -### 3. 错误分类标准化 - -**功能特性:** -- 统一错误阶段分类(request|auth|routing|upstream|network|internal) -- 规范错误归属分类(client|provider|platform) -- 标准化错误来源分类(client_request|upstream_http|gateway) -- 自动迁移历史数据到新分类体系 - -**技术实现:** -- `038_ops_errors_resolution_retry_results_and_standardize_classification.sql` - 分类标准化迁移 -- 自动映射历史遗留分类到新标准 -- 自动解决已恢复的上游错误(客户端状态码 < 400) - -### 4. Gateway 服务集成 - -**功能特性:** -- 完善各 Gateway 服务的 Ops 集成 -- 统一错误日志记录接口 -- 增强上游错误追踪能力 - -**涉及服务:** -- `antigravity_gateway_service.go` - Antigravity 网关集成 -- `gateway_service.go` - 通用网关集成 -- `gemini_messages_compat_service.go` - Gemini 兼容层集成 -- `openai_gateway_service.go` - OpenAI 网关集成 - -### 5. 前端 UI 优化 - -**代码重构:** -- 大幅简化错误详情模态框代码(从 828 行优化到 450 行) -- 优化错误日志表格组件,提升可读性 -- 清理未使用的 i18n 翻译,减少冗余 -- 统一组件代码风格和格式 -- 优化骨架屏组件,更好匹配实际看板布局 - -**布局改进:** -- 修复模态框内容溢出和滚动问题 -- 优化表格布局,使用 flex 布局确保正确显示 -- 改进看板头部布局和交互 -- 提升响应式体验 -- 骨架屏支持全屏模式适配 - -**交互优化:** -- 优化告警事件卡片功能和展示 -- 改进错误详情展示逻辑 -- 增强请求详情模态框 -- 完善运行时设置卡片 -- 改进加载动画效果 - -### 6. 国际化完善 - -**文案补充:** -- 补充错误日志相关的英文翻译 -- 添加告警静默功能的中英文文案 -- 完善提示文本和错误信息 -- 统一术语翻译标准 - -## 文件变更 - -**后端(26 个文件):** -- `backend/internal/handler/admin/ops_alerts_handler.go` - 告警接口增强 -- `backend/internal/handler/admin/ops_handler.go` - 错误日志接口优化 -- `backend/internal/handler/ops_error_logger.go` - 错误记录器增强 -- `backend/internal/repository/ops_repo.go` - 数据访问层重构 -- `backend/internal/repository/ops_repo_alerts.go` - 告警数据访问增强 -- `backend/internal/service/ops_*.go` - 核心服务层重构(10 个文件) -- `backend/internal/service/*_gateway_service.go` - Gateway 集成(4 个文件) -- `backend/internal/server/routes/admin.go` - 路由配置更新 -- `backend/migrations/*.sql` - 数据库迁移(2 个文件) -- 测试文件更新(5 个文件) - -**前端(13 个文件):** -- `frontend/src/views/admin/ops/OpsDashboard.vue` - 看板主页优化 -- `frontend/src/views/admin/ops/components/*.vue` - 组件重构(10 个文件) -- `frontend/src/api/admin/ops.ts` - API 接口扩展 -- `frontend/src/i18n/locales/*.ts` - 国际化文本(2 个文件) - -## 代码统计 - -- 44 个文件修改 -- 3733 行新增 -- 995 行删除 -- 净增加 2738 行 - -## 核心改进 - -**可维护性提升:** -- 重构核心服务层,职责更清晰 -- 简化前端组件代码,降低复杂度 -- 统一代码风格和命名规范 -- 清理冗余代码和未使用的翻译 -- 标准化错误分类体系 - -**功能完善:** -- 告警静默功能,减少告警噪音 -- 错误日志查询优化,提升运维效率 -- Gateway 服务集成完善,统一监控能力 -- 错误解决状态追踪,便于问题管理 - -**用户体验优化:** -- 修复多个 UI 布局问题 -- 优化交互流程 -- 完善国际化支持 -- 提升响应式体验 -- 改进加载状态展示 - -## 测试验证 - -- ✅ 错误日志查询和过滤功能 -- ✅ 告警静默创建和自动过期 -- ✅ 错误分类标准化迁移 -- ✅ Gateway 服务错误日志记录 -- ✅ 前端组件布局和交互 -- ✅ 骨架屏全屏模式适配 -- ✅ 国际化文本完整性 -- ✅ API 接口功能正确性 -- ✅ 数据库迁移执行成功 From d2527e36eb8a6b57bbf1f9c1629da206bc067900 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 5 Feb 2026 20:13:06 +0800 Subject: [PATCH 04/16] =?UTF-8?q?feat(gemini):=20=E5=A2=9E=E5=BC=BA=20API?= =?UTF-8?q?=20=E6=8E=88=E6=9D=83=E9=94=99=E8=AF=AF=E5=A4=84=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E8=87=AA=E5=8A=A8=E6=8F=90=E5=8F=96=E5=B9=B6=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E6=BF=80=E6=B4=BB=20URL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 当 Gemini for Google Cloud API 未启用时(SERVICE_DISABLED 错误), 系统现在会: - 自动检测 403 PERMISSION_DENIED 错误 - 从错误响应中提取 API 激活 URL - 向用户显示清晰的错误消息和可点击的激活链接 - 提供操作指引(启用后等待几分钟) 新增文件: - internal/pkg/googleapi/error.go: Google API 错误解析器 - internal/pkg/googleapi/error_test.go: 完整的测试覆盖 - GEMINI_API_ERROR_HANDLING.md: 实现文档 修改文件: - internal/repository/geminicli_codeassist_client.go: 在 LoadCodeAssist 和 OnboardUser 中增强错误处理 这大大改善了用户体验,用户不再需要手动从错误日志中查找激活 URL。 --- backend/internal/pkg/googleapi/error.go | 109 +++++++++++++ backend/internal/pkg/googleapi/error_test.go | 143 ++++++++++++++++++ .../repository/geminicli_codeassist_client.go | 35 ++++- 3 files changed, 281 insertions(+), 6 deletions(-) create mode 100644 backend/internal/pkg/googleapi/error.go create mode 100644 backend/internal/pkg/googleapi/error_test.go diff --git a/backend/internal/pkg/googleapi/error.go b/backend/internal/pkg/googleapi/error.go new file mode 100644 index 00000000..b6374e02 --- /dev/null +++ b/backend/internal/pkg/googleapi/error.go @@ -0,0 +1,109 @@ +// Package googleapi provides helpers for Google-style API responses. +package googleapi + +import ( + "encoding/json" + "fmt" + "strings" +) + +// ErrorResponse represents a Google API error response +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +// ErrorDetail contains the error details from Google API +type ErrorDetail struct { + Code int `json:"code"` + Message string `json:"message"` + Status string `json:"status"` + Details []json.RawMessage `json:"details,omitempty"` +} + +// ErrorDetailInfo contains additional error information +type ErrorDetailInfo struct { + Type string `json:"@type"` + Reason string `json:"reason,omitempty"` + Domain string `json:"domain,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// ErrorHelp contains help links +type ErrorHelp struct { + Type string `json:"@type"` + Links []HelpLink `json:"links,omitempty"` +} + +// HelpLink represents a help link +type HelpLink struct { + Description string `json:"description"` + URL string `json:"url"` +} + +// ParseError parses a Google API error response and extracts key information +func ParseError(body string) (*ErrorResponse, error) { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return nil, fmt.Errorf("failed to parse error response: %w", err) + } + return &errResp, nil +} + +// ExtractActivationURL extracts the API activation URL from error details +func ExtractActivationURL(body string) string { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return "" + } + + // Check error details for activation URL + for _, detailRaw := range errResp.Error.Details { + // Parse as ErrorDetailInfo + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if info.Metadata != nil { + if activationURL, ok := info.Metadata["activationUrl"]; ok && activationURL != "" { + return activationURL + } + } + } + + // Parse as ErrorHelp + var help ErrorHelp + if err := json.Unmarshal(detailRaw, &help); err == nil { + for _, link := range help.Links { + if strings.Contains(link.Description, "activation") || + strings.Contains(link.Description, "API activation") || + strings.Contains(link.URL, "/apis/api/") { + return link.URL + } + } + } + } + + return "" +} + +// IsServiceDisabledError checks if the error is a SERVICE_DISABLED error +func IsServiceDisabledError(body string) bool { + var errResp ErrorResponse + if err := json.Unmarshal([]byte(body), &errResp); err != nil { + return false + } + + // Check if it's a 403 PERMISSION_DENIED with SERVICE_DISABLED reason + if errResp.Error.Code != 403 || errResp.Error.Status != "PERMISSION_DENIED" { + return false + } + + for _, detailRaw := range errResp.Error.Details { + var info ErrorDetailInfo + if err := json.Unmarshal(detailRaw, &info); err == nil { + if info.Reason == "SERVICE_DISABLED" { + return true + } + } + } + + return false +} diff --git a/backend/internal/pkg/googleapi/error_test.go b/backend/internal/pkg/googleapi/error_test.go new file mode 100644 index 00000000..992dcf85 --- /dev/null +++ b/backend/internal/pkg/googleapi/error_test.go @@ -0,0 +1,143 @@ +package googleapi + +import ( + "testing" +) + +func TestExtractActivationURL(t *testing.T) { + // Test case from the user's error message + errorBody := `{ + "error": { + "code": 403, + "message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry.", + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "SERVICE_DISABLED", + "domain": "googleapis.com", + "metadata": { + "service": "cloudaicompanion.googleapis.com", + "activationUrl": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843", + "consumer": "projects/project-6eca5881-ab73-4736-843", + "serviceTitle": "Gemini for Google Cloud API", + "containerInfo": "project-6eca5881-ab73-4736-843" + } + }, + { + "@type": "type.googleapis.com/google.rpc.LocalizedMessage", + "locale": "en-US", + "message": "Gemini for Google Cloud API has not been used in project project-6eca5881-ab73-4736-843 before or it is disabled. Enable it by visiting https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843 then retry. If you enabled this API recently, wait a few minutes for the action to propagate to our systems and retry." + }, + { + "@type": "type.googleapis.com/google.rpc.Help", + "links": [ + { + "description": "Google developers console API activation", + "url": "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843" + } + ] + } + ] + } + }` + + activationURL := ExtractActivationURL(errorBody) + expectedURL := "https://console.developers.google.com/apis/api/cloudaicompanion.googleapis.com/overview?project=project-6eca5881-ab73-4736-843" + + if activationURL != expectedURL { + t.Errorf("Expected activation URL %s, got %s", expectedURL, activationURL) + } +} + +func TestIsServiceDisabledError(t *testing.T) { + tests := []struct { + name string + body string + expected bool + }{ + { + name: "SERVICE_DISABLED error", + body: `{ + "error": { + "code": 403, + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "SERVICE_DISABLED" + } + ] + } + }`, + expected: true, + }, + { + name: "Other 403 error", + body: `{ + "error": { + "code": 403, + "status": "PERMISSION_DENIED", + "details": [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "OTHER_REASON" + } + ] + } + }`, + expected: false, + }, + { + name: "404 error", + body: `{ + "error": { + "code": 404, + "status": "NOT_FOUND" + } + }`, + expected: false, + }, + { + name: "Invalid JSON", + body: `invalid json`, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsServiceDisabledError(tt.body) + if result != tt.expected { + t.Errorf("Expected %v, got %v", tt.expected, result) + } + }) + } +} + +func TestParseError(t *testing.T) { + errorBody := `{ + "error": { + "code": 403, + "message": "API not enabled", + "status": "PERMISSION_DENIED" + } + }` + + errResp, err := ParseError(errorBody) + if err != nil { + t.Fatalf("Failed to parse error: %v", err) + } + + if errResp.Error.Code != 403 { + t.Errorf("Expected code 403, got %d", errResp.Error.Code) + } + + if errResp.Error.Status != "PERMISSION_DENIED" { + t.Errorf("Expected status PERMISSION_DENIED, got %s", errResp.Error.Status) + } + + if errResp.Error.Message != "API not enabled" { + t.Errorf("Expected message 'API not enabled', got %s", errResp.Error.Message) + } +} diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index d7f54e85..b63be1ad 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -6,6 +6,7 @@ import ( "time" "github.com/Wei-Shaw/sub2api/internal/pkg/geminicli" + "github.com/Wei-Shaw/sub2api/internal/pkg/googleapi" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/imroc/req/v3" @@ -38,9 +39,20 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccessState() { - body := geminicli.SanitizeBodyForLogs(resp.String()) - fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, body) - return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, body) + body := resp.String() + sanitizedBody := geminicli.SanitizeBodyForLogs(body) + fmt.Printf("[CodeAssist] LoadCodeAssist failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody) + + // Check if this is a SERVICE_DISABLED error and extract activation URL + if googleapi.IsServiceDisabledError(body) { + activationURL := googleapi.ExtractActivationURL(body) + if activationURL != "" { + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + } + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + } + + return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody) } fmt.Printf("[CodeAssist] LoadCodeAssist success: status %d, response: %+v\n", resp.StatusCode, out) return &out, nil @@ -67,9 +79,20 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken return nil, fmt.Errorf("request failed: %w", err) } if !resp.IsSuccessState() { - body := geminicli.SanitizeBodyForLogs(resp.String()) - fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, body) - return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, body) + body := resp.String() + sanitizedBody := geminicli.SanitizeBodyForLogs(body) + fmt.Printf("[CodeAssist] OnboardUser failed: status %d, body: %s\n", resp.StatusCode, sanitizedBody) + + // Check if this is a SERVICE_DISABLED error and extract activation URL + if googleapi.IsServiceDisabledError(body) { + activationURL := googleapi.ExtractActivationURL(body) + if activationURL != "" { + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + } + return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + } + + return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody) } fmt.Printf("[CodeAssist] OnboardUser success: status %d, response: %+v\n", resp.StatusCode, out) return &out, nil From 7b46bbb6286afd1f07d9b9fe39e6051ce0e5f0e3 Mon Sep 17 00:00:00 2001 From: ianshaw Date: Thu, 5 Feb 2026 20:47:15 +0800 Subject: [PATCH 05/16] =?UTF-8?q?fix(lint):=20=E4=BF=AE=E5=A4=8D=E9=94=99?= =?UTF-8?q?=E8=AF=AF=E6=B6=88=E6=81=AF=E5=A4=A7=E5=86=99=E9=97=AE=E9=A2=98?= =?UTF-8?q?=E4=BB=A5=E7=AC=A6=E5=90=88=20Go=20=E6=83=AF=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../internal/repository/geminicli_codeassist_client.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/internal/repository/geminicli_codeassist_client.go b/backend/internal/repository/geminicli_codeassist_client.go index b63be1ad..4f63280d 100644 --- a/backend/internal/repository/geminicli_codeassist_client.go +++ b/backend/internal/repository/geminicli_codeassist_client.go @@ -47,9 +47,9 @@ func (c *geminiCliCodeAssistClient) LoadCodeAssist(ctx context.Context, accessTo if googleapi.IsServiceDisabledError(body) { activationURL := googleapi.ExtractActivationURL(body) if activationURL != "" { - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) } - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") } return nil, fmt.Errorf("loadCodeAssist failed: status %d, body: %s", resp.StatusCode, sanitizedBody) @@ -87,9 +87,9 @@ func (c *geminiCliCodeAssistClient) OnboardUser(ctx context.Context, accessToken if googleapi.IsServiceDisabledError(body) { activationURL := googleapi.ExtractActivationURL(body) if activationURL != "" { - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it by visiting: %s\n\nAfter enabling the API, wait a few minutes for the changes to propagate, then try again", activationURL) } - return nil, fmt.Errorf("Gemini for Google Cloud API is not enabled for this project. Please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") + return nil, fmt.Errorf("gemini API not enabled for this project, please enable it in the Google Cloud Console at: https://console.cloud.google.com/apis/library/cloudaicompanion.googleapis.com") } return nil, fmt.Errorf("onboardUser failed: status %d, body: %s", resp.StatusCode, sanitizedBody) From 39e05a2dad412b161d9feec39cd7e9bbb17e3213 Mon Sep 17 00:00:00 2001 From: shaw Date: Thu, 5 Feb 2026 21:52:54 +0800 Subject: [PATCH 06/16] =?UTF-8?q?feat:=20=E6=96=B0=E5=A2=9E=E5=85=A8?= =?UTF-8?q?=E5=B1=80=E9=94=99=E8=AF=AF=E9=80=8F=E4=BC=A0=E8=A7=84=E5=88=99?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 支持管理员配置上游错误如何返回给客户端: - 新增 ErrorPassthroughRule 数据模型和 Ent Schema - 实现规则的 CRUD API(/admin/error-passthrough-rules) - 支持按错误码、关键词匹配,支持 any/all 匹配模式 - 支持按平台过滤(anthropic/openai/gemini/antigravity) - 支持透传或自定义响应状态码和错误消息 - 实现两级缓存(Redis + 本地内存)和多实例同步 - 集成到 gateway_handler 的错误处理流程 - 新增前端管理界面组件 - 新增单元测试覆盖核心匹配逻辑 优化: - 移除 refreshLocalCache 中的冗余排序(数据库已排序) - 后端 Validate() 增加匹配条件非空校验 --- backend/cmd/server/wire_gen.go | 10 +- backend/ent/client.go | 171 +- backend/ent/ent.go | 2 + backend/ent/errorpassthroughrule.go | 269 ++++ .../errorpassthroughrule.go | 161 ++ backend/ent/errorpassthroughrule/where.go | 635 ++++++++ backend/ent/errorpassthroughrule_create.go | 1382 +++++++++++++++++ backend/ent/errorpassthroughrule_delete.go | 88 ++ backend/ent/errorpassthroughrule_query.go | 564 +++++++ backend/ent/errorpassthroughrule_update.go | 823 ++++++++++ backend/ent/hook/hook.go | 12 + backend/ent/intercept/intercept.go | 30 + backend/ent/migrate/schema.go | 40 + backend/ent/mutation.go | 1268 +++++++++++++++ backend/ent/predicate/predicate.go | 3 + backend/ent/runtime/runtime.go | 56 + backend/ent/schema/error_passthrough_rule.go | 121 ++ backend/ent/tx.go | 3 + .../admin/error_passthrough_handler.go | 273 ++++ backend/internal/handler/gateway_handler.go | 59 +- .../internal/handler/gemini_v1beta_handler.go | 41 +- backend/internal/handler/handler.go | 1 + .../handler/openai_gateway_handler.go | 68 +- backend/internal/handler/wire.go | 3 + .../internal/model/error_passthrough_rule.go | 74 + .../repository/error_passthrough_cache.go | 128 ++ .../repository/error_passthrough_repo.go | 178 +++ backend/internal/repository/wire.go | 2 + backend/internal/server/routes/admin.go | 14 + .../service/antigravity_gateway_service.go | 7 +- .../service/error_passthrough_service.go | 300 ++++ .../service/error_passthrough_service_test.go | 755 +++++++++ backend/internal/service/gateway_service.go | 19 +- .../service/gemini_messages_compat_service.go | 8 +- .../service/openai_gateway_service.go | 4 +- backend/internal/service/wire.go | 1 + .../048_add_error_passthrough_rules.sql | 24 + frontend/src/api/admin/errorPassthrough.ts | 134 ++ frontend/src/api/admin/index.ts | 8 +- .../admin/ErrorPassthroughRulesModal.vue | 623 ++++++++ frontend/src/i18n/locales/en.ts | 74 + frontend/src/i18n/locales/zh.ts | 74 + frontend/src/views/admin/AccountsView.vue | 13 + 43 files changed, 8456 insertions(+), 67 deletions(-) create mode 100644 backend/ent/errorpassthroughrule.go create mode 100644 backend/ent/errorpassthroughrule/errorpassthroughrule.go create mode 100644 backend/ent/errorpassthroughrule/where.go create mode 100644 backend/ent/errorpassthroughrule_create.go create mode 100644 backend/ent/errorpassthroughrule_delete.go create mode 100644 backend/ent/errorpassthroughrule_query.go create mode 100644 backend/ent/errorpassthroughrule_update.go create mode 100644 backend/ent/schema/error_passthrough_rule.go create mode 100644 backend/internal/handler/admin/error_passthrough_handler.go create mode 100644 backend/internal/model/error_passthrough_rule.go create mode 100644 backend/internal/repository/error_passthrough_cache.go create mode 100644 backend/internal/repository/error_passthrough_repo.go create mode 100644 backend/internal/service/error_passthrough_service.go create mode 100644 backend/internal/service/error_passthrough_service_test.go create mode 100644 backend/migrations/048_add_error_passthrough_rules.sql create mode 100644 frontend/src/api/admin/errorPassthrough.ts create mode 100644 frontend/src/components/admin/ErrorPassthroughRulesModal.vue diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 3ca86f91..8184bc1c 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -174,9 +174,13 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { userAttributeValueRepository := repository.NewUserAttributeValueRepository(client) userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository) userAttributeHandler := admin.NewUserAttributeHandler(userAttributeService) - adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler) - gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, configConfig) - openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, configConfig) + errorPassthroughRepository := repository.NewErrorPassthroughRepository(client) + errorPassthroughCache := repository.NewErrorPassthroughCache(redisClient) + errorPassthroughService := service.NewErrorPassthroughService(errorPassthroughRepository, errorPassthroughCache) + errorPassthroughHandler := admin.NewErrorPassthroughHandler(errorPassthroughService) + adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler) + gatewayHandler := handler.NewGatewayHandler(gatewayService, geminiMessagesCompatService, antigravityGatewayService, userService, concurrencyService, billingCacheService, usageService, apiKeyService, errorPassthroughService, configConfig) + openAIGatewayHandler := handler.NewOpenAIGatewayHandler(openAIGatewayService, concurrencyService, billingCacheService, apiKeyService, errorPassthroughService, configConfig) handlerSettingHandler := handler.ProvideSettingHandler(settingService, buildInfo) totpHandler := handler.NewTotpHandler(totpService) handlers := handler.ProvideHandlers(authHandler, userHandler, apiKeyHandler, usageHandler, redeemHandler, subscriptionHandler, announcementHandler, adminHandlers, gatewayHandler, openAIGatewayHandler, handlerSettingHandler, totpHandler) diff --git a/backend/ent/client.go b/backend/ent/client.go index a17721da..a791c081 100644 --- a/backend/ent/client.go +++ b/backend/ent/client.go @@ -20,6 +20,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -52,6 +53,8 @@ type Client struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. + ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // PromoCode is the client for interacting with the PromoCode builders. @@ -94,6 +97,7 @@ func (c *Client) init() { c.AccountGroup = NewAccountGroupClient(c.config) c.Announcement = NewAnnouncementClient(c.config) c.AnnouncementRead = NewAnnouncementReadClient(c.config) + c.ErrorPassthroughRule = NewErrorPassthroughRuleClient(c.config) c.Group = NewGroupClient(c.config) c.PromoCode = NewPromoCodeClient(c.config) c.PromoCodeUsage = NewPromoCodeUsageClient(c.config) @@ -204,6 +208,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) { AccountGroup: NewAccountGroupClient(cfg), Announcement: NewAnnouncementClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), @@ -241,6 +246,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) AccountGroup: NewAccountGroupClient(cfg), Announcement: NewAnnouncementClient(cfg), AnnouncementRead: NewAnnouncementReadClient(cfg), + ErrorPassthroughRule: NewErrorPassthroughRuleClient(cfg), Group: NewGroupClient(cfg), PromoCode: NewPromoCodeClient(cfg), PromoCodeUsage: NewPromoCodeUsageClient(cfg), @@ -284,9 +290,10 @@ func (c *Client) Close() error { func (c *Client) Use(hooks ...Hook) { for _, n := range []interface{ Use(...Hook) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, + c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Use(hooks...) } @@ -297,9 +304,10 @@ func (c *Client) Use(hooks ...Hook) { func (c *Client) Intercept(interceptors ...Interceptor) { for _, n := range []interface{ Intercept(...Interceptor) }{ c.APIKey, c.Account, c.AccountGroup, c.Announcement, c.AnnouncementRead, - c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, c.RedeemCode, c.Setting, - c.UsageCleanupTask, c.UsageLog, c.User, c.UserAllowedGroup, - c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription, + c.ErrorPassthroughRule, c.Group, c.PromoCode, c.PromoCodeUsage, c.Proxy, + c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User, + c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue, + c.UserSubscription, } { n.Intercept(interceptors...) } @@ -318,6 +326,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) { return c.Announcement.mutate(ctx, m) case *AnnouncementReadMutation: return c.AnnouncementRead.mutate(ctx, m) + case *ErrorPassthroughRuleMutation: + return c.ErrorPassthroughRule.mutate(ctx, m) case *GroupMutation: return c.Group.mutate(ctx, m) case *PromoCodeMutation: @@ -1161,6 +1171,139 @@ func (c *AnnouncementReadClient) mutate(ctx context.Context, m *AnnouncementRead } } +// ErrorPassthroughRuleClient is a client for the ErrorPassthroughRule schema. +type ErrorPassthroughRuleClient struct { + config +} + +// NewErrorPassthroughRuleClient returns a client for the ErrorPassthroughRule from the given config. +func NewErrorPassthroughRuleClient(c config) *ErrorPassthroughRuleClient { + return &ErrorPassthroughRuleClient{config: c} +} + +// Use adds a list of mutation hooks to the hooks stack. +// A call to `Use(f, g, h)` equals to `errorpassthroughrule.Hooks(f(g(h())))`. +func (c *ErrorPassthroughRuleClient) Use(hooks ...Hook) { + c.hooks.ErrorPassthroughRule = append(c.hooks.ErrorPassthroughRule, hooks...) +} + +// Intercept adds a list of query interceptors to the interceptors stack. +// A call to `Intercept(f, g, h)` equals to `errorpassthroughrule.Intercept(f(g(h())))`. +func (c *ErrorPassthroughRuleClient) Intercept(interceptors ...Interceptor) { + c.inters.ErrorPassthroughRule = append(c.inters.ErrorPassthroughRule, interceptors...) +} + +// Create returns a builder for creating a ErrorPassthroughRule entity. +func (c *ErrorPassthroughRuleClient) Create() *ErrorPassthroughRuleCreate { + mutation := newErrorPassthroughRuleMutation(c.config, OpCreate) + return &ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// CreateBulk returns a builder for creating a bulk of ErrorPassthroughRule entities. +func (c *ErrorPassthroughRuleClient) CreateBulk(builders ...*ErrorPassthroughRuleCreate) *ErrorPassthroughRuleCreateBulk { + return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders} +} + +// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates +// a builder and applies setFunc on it. +func (c *ErrorPassthroughRuleClient) MapCreateBulk(slice any, setFunc func(*ErrorPassthroughRuleCreate, int)) *ErrorPassthroughRuleCreateBulk { + rv := reflect.ValueOf(slice) + if rv.Kind() != reflect.Slice { + return &ErrorPassthroughRuleCreateBulk{err: fmt.Errorf("calling to ErrorPassthroughRuleClient.MapCreateBulk with wrong type %T, need slice", slice)} + } + builders := make([]*ErrorPassthroughRuleCreate, rv.Len()) + for i := 0; i < rv.Len(); i++ { + builders[i] = c.Create() + setFunc(builders[i], i) + } + return &ErrorPassthroughRuleCreateBulk{config: c.config, builders: builders} +} + +// Update returns an update builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Update() *ErrorPassthroughRuleUpdate { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdate) + return &ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOne returns an update builder for the given entity. +func (c *ErrorPassthroughRuleClient) UpdateOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRule(_m)) + return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// UpdateOneID returns an update builder for the given id. +func (c *ErrorPassthroughRuleClient) UpdateOneID(id int64) *ErrorPassthroughRuleUpdateOne { + mutation := newErrorPassthroughRuleMutation(c.config, OpUpdateOne, withErrorPassthroughRuleID(id)) + return &ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// Delete returns a delete builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Delete() *ErrorPassthroughRuleDelete { + mutation := newErrorPassthroughRuleMutation(c.config, OpDelete) + return &ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: mutation} +} + +// DeleteOne returns a builder for deleting the given entity. +func (c *ErrorPassthroughRuleClient) DeleteOne(_m *ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne { + return c.DeleteOneID(_m.ID) +} + +// DeleteOneID returns a builder for deleting the given entity by its id. +func (c *ErrorPassthroughRuleClient) DeleteOneID(id int64) *ErrorPassthroughRuleDeleteOne { + builder := c.Delete().Where(errorpassthroughrule.ID(id)) + builder.mutation.id = &id + builder.mutation.op = OpDeleteOne + return &ErrorPassthroughRuleDeleteOne{builder} +} + +// Query returns a query builder for ErrorPassthroughRule. +func (c *ErrorPassthroughRuleClient) Query() *ErrorPassthroughRuleQuery { + return &ErrorPassthroughRuleQuery{ + config: c.config, + ctx: &QueryContext{Type: TypeErrorPassthroughRule}, + inters: c.Interceptors(), + } +} + +// Get returns a ErrorPassthroughRule entity by its id. +func (c *ErrorPassthroughRuleClient) Get(ctx context.Context, id int64) (*ErrorPassthroughRule, error) { + return c.Query().Where(errorpassthroughrule.ID(id)).Only(ctx) +} + +// GetX is like Get, but panics if an error occurs. +func (c *ErrorPassthroughRuleClient) GetX(ctx context.Context, id int64) *ErrorPassthroughRule { + obj, err := c.Get(ctx, id) + if err != nil { + panic(err) + } + return obj +} + +// Hooks returns the client hooks. +func (c *ErrorPassthroughRuleClient) Hooks() []Hook { + return c.hooks.ErrorPassthroughRule +} + +// Interceptors returns the client interceptors. +func (c *ErrorPassthroughRuleClient) Interceptors() []Interceptor { + return c.inters.ErrorPassthroughRule +} + +func (c *ErrorPassthroughRuleClient) mutate(ctx context.Context, m *ErrorPassthroughRuleMutation) (Value, error) { + switch m.Op() { + case OpCreate: + return (&ErrorPassthroughRuleCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdate: + return (&ErrorPassthroughRuleUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpUpdateOne: + return (&ErrorPassthroughRuleUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx) + case OpDelete, OpDeleteOne: + return (&ErrorPassthroughRuleDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx) + default: + return nil, fmt.Errorf("ent: unknown ErrorPassthroughRule mutation op: %q", m.Op()) + } +} + // GroupClient is a client for the Group schema. type GroupClient struct { config @@ -3462,16 +3605,16 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription // hooks and interceptors per client, for fast access. type ( hooks struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Hook + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, + ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook } inters struct { - APIKey, Account, AccountGroup, Announcement, AnnouncementRead, Group, PromoCode, - PromoCodeUsage, Proxy, RedeemCode, Setting, UsageCleanupTask, UsageLog, User, - UserAllowedGroup, UserAttributeDefinition, UserAttributeValue, - UserSubscription []ent.Interceptor + APIKey, Account, AccountGroup, Announcement, AnnouncementRead, + ErrorPassthroughRule, Group, PromoCode, PromoCodeUsage, Proxy, RedeemCode, + Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup, + UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor } ) diff --git a/backend/ent/ent.go b/backend/ent/ent.go index 05e30ba7..5767a167 100644 --- a/backend/ent/ent.go +++ b/backend/ent/ent.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -95,6 +96,7 @@ func checkColumn(t, c string) error { accountgroup.Table: accountgroup.ValidColumn, announcement.Table: announcement.ValidColumn, announcementread.Table: announcementread.ValidColumn, + errorpassthroughrule.Table: errorpassthroughrule.ValidColumn, group.Table: group.ValidColumn, promocode.Table: promocode.ValidColumn, promocodeusage.Table: promocodeusage.ValidColumn, diff --git a/backend/ent/errorpassthroughrule.go b/backend/ent/errorpassthroughrule.go new file mode 100644 index 00000000..1932f626 --- /dev/null +++ b/backend/ent/errorpassthroughrule.go @@ -0,0 +1,269 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" +) + +// ErrorPassthroughRule is the model entity for the ErrorPassthroughRule schema. +type ErrorPassthroughRule struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Name holds the value of the "name" field. + Name string `json:"name,omitempty"` + // Enabled holds the value of the "enabled" field. + Enabled bool `json:"enabled,omitempty"` + // Priority holds the value of the "priority" field. + Priority int `json:"priority,omitempty"` + // ErrorCodes holds the value of the "error_codes" field. + ErrorCodes []int `json:"error_codes,omitempty"` + // Keywords holds the value of the "keywords" field. + Keywords []string `json:"keywords,omitempty"` + // MatchMode holds the value of the "match_mode" field. + MatchMode string `json:"match_mode,omitempty"` + // Platforms holds the value of the "platforms" field. + Platforms []string `json:"platforms,omitempty"` + // PassthroughCode holds the value of the "passthrough_code" field. + PassthroughCode bool `json:"passthrough_code,omitempty"` + // ResponseCode holds the value of the "response_code" field. + ResponseCode *int `json:"response_code,omitempty"` + // PassthroughBody holds the value of the "passthrough_body" field. + PassthroughBody bool `json:"passthrough_body,omitempty"` + // CustomMessage holds the value of the "custom_message" field. + CustomMessage *string `json:"custom_message,omitempty"` + // Description holds the value of the "description" field. + Description *string `json:"description,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*ErrorPassthroughRule) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case errorpassthroughrule.FieldErrorCodes, errorpassthroughrule.FieldKeywords, errorpassthroughrule.FieldPlatforms: + values[i] = new([]byte) + case errorpassthroughrule.FieldEnabled, errorpassthroughrule.FieldPassthroughCode, errorpassthroughrule.FieldPassthroughBody: + values[i] = new(sql.NullBool) + case errorpassthroughrule.FieldID, errorpassthroughrule.FieldPriority, errorpassthroughrule.FieldResponseCode: + values[i] = new(sql.NullInt64) + case errorpassthroughrule.FieldName, errorpassthroughrule.FieldMatchMode, errorpassthroughrule.FieldCustomMessage, errorpassthroughrule.FieldDescription: + values[i] = new(sql.NullString) + case errorpassthroughrule.FieldCreatedAt, errorpassthroughrule.FieldUpdatedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the ErrorPassthroughRule fields. +func (_m *ErrorPassthroughRule) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case errorpassthroughrule.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case errorpassthroughrule.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case errorpassthroughrule.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case errorpassthroughrule.FieldName: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field name", values[i]) + } else if value.Valid { + _m.Name = value.String + } + case errorpassthroughrule.FieldEnabled: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field enabled", values[i]) + } else if value.Valid { + _m.Enabled = value.Bool + } + case errorpassthroughrule.FieldPriority: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field priority", values[i]) + } else if value.Valid { + _m.Priority = int(value.Int64) + } + case errorpassthroughrule.FieldErrorCodes: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field error_codes", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.ErrorCodes); err != nil { + return fmt.Errorf("unmarshal field error_codes: %w", err) + } + } + case errorpassthroughrule.FieldKeywords: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field keywords", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Keywords); err != nil { + return fmt.Errorf("unmarshal field keywords: %w", err) + } + } + case errorpassthroughrule.FieldMatchMode: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field match_mode", values[i]) + } else if value.Valid { + _m.MatchMode = value.String + } + case errorpassthroughrule.FieldPlatforms: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field platforms", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Platforms); err != nil { + return fmt.Errorf("unmarshal field platforms: %w", err) + } + } + case errorpassthroughrule.FieldPassthroughCode: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field passthrough_code", values[i]) + } else if value.Valid { + _m.PassthroughCode = value.Bool + } + case errorpassthroughrule.FieldResponseCode: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field response_code", values[i]) + } else if value.Valid { + _m.ResponseCode = new(int) + *_m.ResponseCode = int(value.Int64) + } + case errorpassthroughrule.FieldPassthroughBody: + if value, ok := values[i].(*sql.NullBool); !ok { + return fmt.Errorf("unexpected type %T for field passthrough_body", values[i]) + } else if value.Valid { + _m.PassthroughBody = value.Bool + } + case errorpassthroughrule.FieldCustomMessage: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field custom_message", values[i]) + } else if value.Valid { + _m.CustomMessage = new(string) + *_m.CustomMessage = value.String + } + case errorpassthroughrule.FieldDescription: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field description", values[i]) + } else if value.Valid { + _m.Description = new(string) + *_m.Description = value.String + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the ErrorPassthroughRule. +// This includes values selected through modifiers, order, etc. +func (_m *ErrorPassthroughRule) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this ErrorPassthroughRule. +// Note that you need to call ErrorPassthroughRule.Unwrap() before calling this method if this ErrorPassthroughRule +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *ErrorPassthroughRule) Update() *ErrorPassthroughRuleUpdateOne { + return NewErrorPassthroughRuleClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the ErrorPassthroughRule entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *ErrorPassthroughRule) Unwrap() *ErrorPassthroughRule { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: ErrorPassthroughRule is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. +func (_m *ErrorPassthroughRule) String() string { + var builder strings.Builder + builder.WriteString("ErrorPassthroughRule(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("name=") + builder.WriteString(_m.Name) + builder.WriteString(", ") + builder.WriteString("enabled=") + builder.WriteString(fmt.Sprintf("%v", _m.Enabled)) + builder.WriteString(", ") + builder.WriteString("priority=") + builder.WriteString(fmt.Sprintf("%v", _m.Priority)) + builder.WriteString(", ") + builder.WriteString("error_codes=") + builder.WriteString(fmt.Sprintf("%v", _m.ErrorCodes)) + builder.WriteString(", ") + builder.WriteString("keywords=") + builder.WriteString(fmt.Sprintf("%v", _m.Keywords)) + builder.WriteString(", ") + builder.WriteString("match_mode=") + builder.WriteString(_m.MatchMode) + builder.WriteString(", ") + builder.WriteString("platforms=") + builder.WriteString(fmt.Sprintf("%v", _m.Platforms)) + builder.WriteString(", ") + builder.WriteString("passthrough_code=") + builder.WriteString(fmt.Sprintf("%v", _m.PassthroughCode)) + builder.WriteString(", ") + if v := _m.ResponseCode; v != nil { + builder.WriteString("response_code=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + builder.WriteString("passthrough_body=") + builder.WriteString(fmt.Sprintf("%v", _m.PassthroughBody)) + builder.WriteString(", ") + if v := _m.CustomMessage; v != nil { + builder.WriteString("custom_message=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.Description; v != nil { + builder.WriteString("description=") + builder.WriteString(*v) + } + builder.WriteByte(')') + return builder.String() +} + +// ErrorPassthroughRules is a parsable slice of ErrorPassthroughRule. +type ErrorPassthroughRules []*ErrorPassthroughRule diff --git a/backend/ent/errorpassthroughrule/errorpassthroughrule.go b/backend/ent/errorpassthroughrule/errorpassthroughrule.go new file mode 100644 index 00000000..d7be4f03 --- /dev/null +++ b/backend/ent/errorpassthroughrule/errorpassthroughrule.go @@ -0,0 +1,161 @@ +// Code generated by ent, DO NOT EDIT. + +package errorpassthroughrule + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the errorpassthroughrule type in the database. + Label = "error_passthrough_rule" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldName holds the string denoting the name field in the database. + FieldName = "name" + // FieldEnabled holds the string denoting the enabled field in the database. + FieldEnabled = "enabled" + // FieldPriority holds the string denoting the priority field in the database. + FieldPriority = "priority" + // FieldErrorCodes holds the string denoting the error_codes field in the database. + FieldErrorCodes = "error_codes" + // FieldKeywords holds the string denoting the keywords field in the database. + FieldKeywords = "keywords" + // FieldMatchMode holds the string denoting the match_mode field in the database. + FieldMatchMode = "match_mode" + // FieldPlatforms holds the string denoting the platforms field in the database. + FieldPlatforms = "platforms" + // FieldPassthroughCode holds the string denoting the passthrough_code field in the database. + FieldPassthroughCode = "passthrough_code" + // FieldResponseCode holds the string denoting the response_code field in the database. + FieldResponseCode = "response_code" + // FieldPassthroughBody holds the string denoting the passthrough_body field in the database. + FieldPassthroughBody = "passthrough_body" + // FieldCustomMessage holds the string denoting the custom_message field in the database. + FieldCustomMessage = "custom_message" + // FieldDescription holds the string denoting the description field in the database. + FieldDescription = "description" + // Table holds the table name of the errorpassthroughrule in the database. + Table = "error_passthrough_rules" +) + +// Columns holds all SQL columns for errorpassthroughrule fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldName, + FieldEnabled, + FieldPriority, + FieldErrorCodes, + FieldKeywords, + FieldMatchMode, + FieldPlatforms, + FieldPassthroughCode, + FieldResponseCode, + FieldPassthroughBody, + FieldCustomMessage, + FieldDescription, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // NameValidator is a validator for the "name" field. It is called by the builders before save. + NameValidator func(string) error + // DefaultEnabled holds the default value on creation for the "enabled" field. + DefaultEnabled bool + // DefaultPriority holds the default value on creation for the "priority" field. + DefaultPriority int + // DefaultMatchMode holds the default value on creation for the "match_mode" field. + DefaultMatchMode string + // MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save. + MatchModeValidator func(string) error + // DefaultPassthroughCode holds the default value on creation for the "passthrough_code" field. + DefaultPassthroughCode bool + // DefaultPassthroughBody holds the default value on creation for the "passthrough_body" field. + DefaultPassthroughBody bool +) + +// OrderOption defines the ordering options for the ErrorPassthroughRule queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByName orders the results by the name field. +func ByName(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldName, opts...).ToFunc() +} + +// ByEnabled orders the results by the enabled field. +func ByEnabled(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldEnabled, opts...).ToFunc() +} + +// ByPriority orders the results by the priority field. +func ByPriority(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPriority, opts...).ToFunc() +} + +// ByMatchMode orders the results by the match_mode field. +func ByMatchMode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldMatchMode, opts...).ToFunc() +} + +// ByPassthroughCode orders the results by the passthrough_code field. +func ByPassthroughCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassthroughCode, opts...).ToFunc() +} + +// ByResponseCode orders the results by the response_code field. +func ByResponseCode(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldResponseCode, opts...).ToFunc() +} + +// ByPassthroughBody orders the results by the passthrough_body field. +func ByPassthroughBody(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldPassthroughBody, opts...).ToFunc() +} + +// ByCustomMessage orders the results by the custom_message field. +func ByCustomMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCustomMessage, opts...).ToFunc() +} + +// ByDescription orders the results by the description field. +func ByDescription(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDescription, opts...).ToFunc() +} diff --git a/backend/ent/errorpassthroughrule/where.go b/backend/ent/errorpassthroughrule/where.go new file mode 100644 index 00000000..56839d52 --- /dev/null +++ b/backend/ent/errorpassthroughrule/where.go @@ -0,0 +1,635 @@ +// Code generated by ent, DO NOT EDIT. + +package errorpassthroughrule + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Name applies equality check predicate on the "name" field. It's identical to NameEQ. +func Name(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v)) +} + +// Enabled applies equality check predicate on the "enabled" field. It's identical to EnabledEQ. +func Enabled(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v)) +} + +// Priority applies equality check predicate on the "priority" field. It's identical to PriorityEQ. +func Priority(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v)) +} + +// MatchMode applies equality check predicate on the "match_mode" field. It's identical to MatchModeEQ. +func MatchMode(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v)) +} + +// PassthroughCode applies equality check predicate on the "passthrough_code" field. It's identical to PassthroughCodeEQ. +func PassthroughCode(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v)) +} + +// ResponseCode applies equality check predicate on the "response_code" field. It's identical to ResponseCodeEQ. +func ResponseCode(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v)) +} + +// PassthroughBody applies equality check predicate on the "passthrough_body" field. It's identical to PassthroughBodyEQ. +func PassthroughBody(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v)) +} + +// CustomMessage applies equality check predicate on the "custom_message" field. It's identical to CustomMessageEQ. +func CustomMessage(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) +} + +// Description applies equality check predicate on the "description" field. It's identical to DescriptionEQ. +func Description(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. +func UpdatedAtGT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// NameEQ applies the EQ predicate on the "name" field. +func NameEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldName, v)) +} + +// NameNEQ applies the NEQ predicate on the "name" field. +func NameNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldName, v)) +} + +// NameIn applies the In predicate on the "name" field. +func NameIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldName, vs...)) +} + +// NameNotIn applies the NotIn predicate on the "name" field. +func NameNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldName, vs...)) +} + +// NameGT applies the GT predicate on the "name" field. +func NameGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldName, v)) +} + +// NameGTE applies the GTE predicate on the "name" field. +func NameGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldName, v)) +} + +// NameLT applies the LT predicate on the "name" field. +func NameLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldName, v)) +} + +// NameLTE applies the LTE predicate on the "name" field. +func NameLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldName, v)) +} + +// NameContains applies the Contains predicate on the "name" field. +func NameContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldName, v)) +} + +// NameHasPrefix applies the HasPrefix predicate on the "name" field. +func NameHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldName, v)) +} + +// NameHasSuffix applies the HasSuffix predicate on the "name" field. +func NameHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldName, v)) +} + +// NameEqualFold applies the EqualFold predicate on the "name" field. +func NameEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldName, v)) +} + +// NameContainsFold applies the ContainsFold predicate on the "name" field. +func NameContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldName, v)) +} + +// EnabledEQ applies the EQ predicate on the "enabled" field. +func EnabledEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldEnabled, v)) +} + +// EnabledNEQ applies the NEQ predicate on the "enabled" field. +func EnabledNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldEnabled, v)) +} + +// PriorityEQ applies the EQ predicate on the "priority" field. +func PriorityEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPriority, v)) +} + +// PriorityNEQ applies the NEQ predicate on the "priority" field. +func PriorityNEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPriority, v)) +} + +// PriorityIn applies the In predicate on the "priority" field. +func PriorityIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldPriority, vs...)) +} + +// PriorityNotIn applies the NotIn predicate on the "priority" field. +func PriorityNotIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldPriority, vs...)) +} + +// PriorityGT applies the GT predicate on the "priority" field. +func PriorityGT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldPriority, v)) +} + +// PriorityGTE applies the GTE predicate on the "priority" field. +func PriorityGTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldPriority, v)) +} + +// PriorityLT applies the LT predicate on the "priority" field. +func PriorityLT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldPriority, v)) +} + +// PriorityLTE applies the LTE predicate on the "priority" field. +func PriorityLTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldPriority, v)) +} + +// ErrorCodesIsNil applies the IsNil predicate on the "error_codes" field. +func ErrorCodesIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldErrorCodes)) +} + +// ErrorCodesNotNil applies the NotNil predicate on the "error_codes" field. +func ErrorCodesNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldErrorCodes)) +} + +// KeywordsIsNil applies the IsNil predicate on the "keywords" field. +func KeywordsIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldKeywords)) +} + +// KeywordsNotNil applies the NotNil predicate on the "keywords" field. +func KeywordsNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldKeywords)) +} + +// MatchModeEQ applies the EQ predicate on the "match_mode" field. +func MatchModeEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldMatchMode, v)) +} + +// MatchModeNEQ applies the NEQ predicate on the "match_mode" field. +func MatchModeNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldMatchMode, v)) +} + +// MatchModeIn applies the In predicate on the "match_mode" field. +func MatchModeIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldMatchMode, vs...)) +} + +// MatchModeNotIn applies the NotIn predicate on the "match_mode" field. +func MatchModeNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldMatchMode, vs...)) +} + +// MatchModeGT applies the GT predicate on the "match_mode" field. +func MatchModeGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldMatchMode, v)) +} + +// MatchModeGTE applies the GTE predicate on the "match_mode" field. +func MatchModeGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldMatchMode, v)) +} + +// MatchModeLT applies the LT predicate on the "match_mode" field. +func MatchModeLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldMatchMode, v)) +} + +// MatchModeLTE applies the LTE predicate on the "match_mode" field. +func MatchModeLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldMatchMode, v)) +} + +// MatchModeContains applies the Contains predicate on the "match_mode" field. +func MatchModeContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldMatchMode, v)) +} + +// MatchModeHasPrefix applies the HasPrefix predicate on the "match_mode" field. +func MatchModeHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldMatchMode, v)) +} + +// MatchModeHasSuffix applies the HasSuffix predicate on the "match_mode" field. +func MatchModeHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldMatchMode, v)) +} + +// MatchModeEqualFold applies the EqualFold predicate on the "match_mode" field. +func MatchModeEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldMatchMode, v)) +} + +// MatchModeContainsFold applies the ContainsFold predicate on the "match_mode" field. +func MatchModeContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldMatchMode, v)) +} + +// PlatformsIsNil applies the IsNil predicate on the "platforms" field. +func PlatformsIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldPlatforms)) +} + +// PlatformsNotNil applies the NotNil predicate on the "platforms" field. +func PlatformsNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldPlatforms)) +} + +// PassthroughCodeEQ applies the EQ predicate on the "passthrough_code" field. +func PassthroughCodeEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughCode, v)) +} + +// PassthroughCodeNEQ applies the NEQ predicate on the "passthrough_code" field. +func PassthroughCodeNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughCode, v)) +} + +// ResponseCodeEQ applies the EQ predicate on the "response_code" field. +func ResponseCodeEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldResponseCode, v)) +} + +// ResponseCodeNEQ applies the NEQ predicate on the "response_code" field. +func ResponseCodeNEQ(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldResponseCode, v)) +} + +// ResponseCodeIn applies the In predicate on the "response_code" field. +func ResponseCodeIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldResponseCode, vs...)) +} + +// ResponseCodeNotIn applies the NotIn predicate on the "response_code" field. +func ResponseCodeNotIn(vs ...int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldResponseCode, vs...)) +} + +// ResponseCodeGT applies the GT predicate on the "response_code" field. +func ResponseCodeGT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldResponseCode, v)) +} + +// ResponseCodeGTE applies the GTE predicate on the "response_code" field. +func ResponseCodeGTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldResponseCode, v)) +} + +// ResponseCodeLT applies the LT predicate on the "response_code" field. +func ResponseCodeLT(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldResponseCode, v)) +} + +// ResponseCodeLTE applies the LTE predicate on the "response_code" field. +func ResponseCodeLTE(v int) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldResponseCode, v)) +} + +// ResponseCodeIsNil applies the IsNil predicate on the "response_code" field. +func ResponseCodeIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldResponseCode)) +} + +// ResponseCodeNotNil applies the NotNil predicate on the "response_code" field. +func ResponseCodeNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldResponseCode)) +} + +// PassthroughBodyEQ applies the EQ predicate on the "passthrough_body" field. +func PassthroughBodyEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldPassthroughBody, v)) +} + +// PassthroughBodyNEQ applies the NEQ predicate on the "passthrough_body" field. +func PassthroughBodyNEQ(v bool) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldPassthroughBody, v)) +} + +// CustomMessageEQ applies the EQ predicate on the "custom_message" field. +func CustomMessageEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldCustomMessage, v)) +} + +// CustomMessageNEQ applies the NEQ predicate on the "custom_message" field. +func CustomMessageNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldCustomMessage, v)) +} + +// CustomMessageIn applies the In predicate on the "custom_message" field. +func CustomMessageIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldCustomMessage, vs...)) +} + +// CustomMessageNotIn applies the NotIn predicate on the "custom_message" field. +func CustomMessageNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldCustomMessage, vs...)) +} + +// CustomMessageGT applies the GT predicate on the "custom_message" field. +func CustomMessageGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldCustomMessage, v)) +} + +// CustomMessageGTE applies the GTE predicate on the "custom_message" field. +func CustomMessageGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldCustomMessage, v)) +} + +// CustomMessageLT applies the LT predicate on the "custom_message" field. +func CustomMessageLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldCustomMessage, v)) +} + +// CustomMessageLTE applies the LTE predicate on the "custom_message" field. +func CustomMessageLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldCustomMessage, v)) +} + +// CustomMessageContains applies the Contains predicate on the "custom_message" field. +func CustomMessageContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldCustomMessage, v)) +} + +// CustomMessageHasPrefix applies the HasPrefix predicate on the "custom_message" field. +func CustomMessageHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldCustomMessage, v)) +} + +// CustomMessageHasSuffix applies the HasSuffix predicate on the "custom_message" field. +func CustomMessageHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldCustomMessage, v)) +} + +// CustomMessageIsNil applies the IsNil predicate on the "custom_message" field. +func CustomMessageIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldCustomMessage)) +} + +// CustomMessageNotNil applies the NotNil predicate on the "custom_message" field. +func CustomMessageNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldCustomMessage)) +} + +// CustomMessageEqualFold applies the EqualFold predicate on the "custom_message" field. +func CustomMessageEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldCustomMessage, v)) +} + +// CustomMessageContainsFold applies the ContainsFold predicate on the "custom_message" field. +func CustomMessageContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldCustomMessage, v)) +} + +// DescriptionEQ applies the EQ predicate on the "description" field. +func DescriptionEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEQ(FieldDescription, v)) +} + +// DescriptionNEQ applies the NEQ predicate on the "description" field. +func DescriptionNEQ(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNEQ(FieldDescription, v)) +} + +// DescriptionIn applies the In predicate on the "description" field. +func DescriptionIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIn(FieldDescription, vs...)) +} + +// DescriptionNotIn applies the NotIn predicate on the "description" field. +func DescriptionNotIn(vs ...string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotIn(FieldDescription, vs...)) +} + +// DescriptionGT applies the GT predicate on the "description" field. +func DescriptionGT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGT(FieldDescription, v)) +} + +// DescriptionGTE applies the GTE predicate on the "description" field. +func DescriptionGTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldGTE(FieldDescription, v)) +} + +// DescriptionLT applies the LT predicate on the "description" field. +func DescriptionLT(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLT(FieldDescription, v)) +} + +// DescriptionLTE applies the LTE predicate on the "description" field. +func DescriptionLTE(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldLTE(FieldDescription, v)) +} + +// DescriptionContains applies the Contains predicate on the "description" field. +func DescriptionContains(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContains(FieldDescription, v)) +} + +// DescriptionHasPrefix applies the HasPrefix predicate on the "description" field. +func DescriptionHasPrefix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasPrefix(FieldDescription, v)) +} + +// DescriptionHasSuffix applies the HasSuffix predicate on the "description" field. +func DescriptionHasSuffix(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldHasSuffix(FieldDescription, v)) +} + +// DescriptionIsNil applies the IsNil predicate on the "description" field. +func DescriptionIsNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldIsNull(FieldDescription)) +} + +// DescriptionNotNil applies the NotNil predicate on the "description" field. +func DescriptionNotNil() predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldNotNull(FieldDescription)) +} + +// DescriptionEqualFold applies the EqualFold predicate on the "description" field. +func DescriptionEqualFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldEqualFold(FieldDescription, v)) +} + +// DescriptionContainsFold applies the ContainsFold predicate on the "description" field. +func DescriptionContainsFold(v string) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.FieldContainsFold(FieldDescription, v)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.ErrorPassthroughRule) predicate.ErrorPassthroughRule { + return predicate.ErrorPassthroughRule(sql.NotPredicates(p)) +} diff --git a/backend/ent/errorpassthroughrule_create.go b/backend/ent/errorpassthroughrule_create.go new file mode 100644 index 00000000..4dc08dce --- /dev/null +++ b/backend/ent/errorpassthroughrule_create.go @@ -0,0 +1,1382 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" +) + +// ErrorPassthroughRuleCreate is the builder for creating a ErrorPassthroughRule entity. +type ErrorPassthroughRuleCreate struct { + config + mutation *ErrorPassthroughRuleMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *ErrorPassthroughRuleCreate) SetCreatedAt(v time.Time) *ErrorPassthroughRuleCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableCreatedAt(v *time.Time) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *ErrorPassthroughRuleCreate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableUpdatedAt(v *time.Time) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetName sets the "name" field. +func (_c *ErrorPassthroughRuleCreate) SetName(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetName(v) + return _c +} + +// SetEnabled sets the "enabled" field. +func (_c *ErrorPassthroughRuleCreate) SetEnabled(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetEnabled(v) + return _c +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetEnabled(*v) + } + return _c +} + +// SetPriority sets the "priority" field. +func (_c *ErrorPassthroughRuleCreate) SetPriority(v int) *ErrorPassthroughRuleCreate { + _c.mutation.SetPriority(v) + return _c +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePriority(v *int) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPriority(*v) + } + return _c +} + +// SetErrorCodes sets the "error_codes" field. +func (_c *ErrorPassthroughRuleCreate) SetErrorCodes(v []int) *ErrorPassthroughRuleCreate { + _c.mutation.SetErrorCodes(v) + return _c +} + +// SetKeywords sets the "keywords" field. +func (_c *ErrorPassthroughRuleCreate) SetKeywords(v []string) *ErrorPassthroughRuleCreate { + _c.mutation.SetKeywords(v) + return _c +} + +// SetMatchMode sets the "match_mode" field. +func (_c *ErrorPassthroughRuleCreate) SetMatchMode(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetMatchMode(v) + return _c +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetMatchMode(*v) + } + return _c +} + +// SetPlatforms sets the "platforms" field. +func (_c *ErrorPassthroughRuleCreate) SetPlatforms(v []string) *ErrorPassthroughRuleCreate { + _c.mutation.SetPlatforms(v) + return _c +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_c *ErrorPassthroughRuleCreate) SetPassthroughCode(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetPassthroughCode(v) + return _c +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPassthroughCode(*v) + } + return _c +} + +// SetResponseCode sets the "response_code" field. +func (_c *ErrorPassthroughRuleCreate) SetResponseCode(v int) *ErrorPassthroughRuleCreate { + _c.mutation.SetResponseCode(v) + return _c +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetResponseCode(*v) + } + return _c +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_c *ErrorPassthroughRuleCreate) SetPassthroughBody(v bool) *ErrorPassthroughRuleCreate { + _c.mutation.SetPassthroughBody(v) + return _c +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetPassthroughBody(*v) + } + return _c +} + +// SetCustomMessage sets the "custom_message" field. +func (_c *ErrorPassthroughRuleCreate) SetCustomMessage(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetCustomMessage(v) + return _c +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetCustomMessage(*v) + } + return _c +} + +// SetDescription sets the "description" field. +func (_c *ErrorPassthroughRuleCreate) SetDescription(v string) *ErrorPassthroughRuleCreate { + _c.mutation.SetDescription(v) + return _c +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_c *ErrorPassthroughRuleCreate) SetNillableDescription(v *string) *ErrorPassthroughRuleCreate { + if v != nil { + _c.SetDescription(*v) + } + return _c +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_c *ErrorPassthroughRuleCreate) Mutation() *ErrorPassthroughRuleMutation { + return _c.mutation +} + +// Save creates the ErrorPassthroughRule in the database. +func (_c *ErrorPassthroughRuleCreate) Save(ctx context.Context) (*ErrorPassthroughRule, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *ErrorPassthroughRuleCreate) SaveX(ctx context.Context) *ErrorPassthroughRule { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ErrorPassthroughRuleCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_c *ErrorPassthroughRuleCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := errorpassthroughrule.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.Enabled(); !ok { + v := errorpassthroughrule.DefaultEnabled + _c.mutation.SetEnabled(v) + } + if _, ok := _c.mutation.Priority(); !ok { + v := errorpassthroughrule.DefaultPriority + _c.mutation.SetPriority(v) + } + if _, ok := _c.mutation.MatchMode(); !ok { + v := errorpassthroughrule.DefaultMatchMode + _c.mutation.SetMatchMode(v) + } + if _, ok := _c.mutation.PassthroughCode(); !ok { + v := errorpassthroughrule.DefaultPassthroughCode + _c.mutation.SetPassthroughCode(v) + } + if _, ok := _c.mutation.PassthroughBody(); !ok { + v := errorpassthroughrule.DefaultPassthroughBody + _c.mutation.SetPassthroughBody(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *ErrorPassthroughRuleCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "ErrorPassthroughRule.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "ErrorPassthroughRule.updated_at"`)} + } + if _, ok := _c.mutation.Name(); !ok { + return &ValidationError{Name: "name", err: errors.New(`ent: missing required field "ErrorPassthroughRule.name"`)} + } + if v, ok := _c.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if _, ok := _c.mutation.Enabled(); !ok { + return &ValidationError{Name: "enabled", err: errors.New(`ent: missing required field "ErrorPassthroughRule.enabled"`)} + } + if _, ok := _c.mutation.Priority(); !ok { + return &ValidationError{Name: "priority", err: errors.New(`ent: missing required field "ErrorPassthroughRule.priority"`)} + } + if _, ok := _c.mutation.MatchMode(); !ok { + return &ValidationError{Name: "match_mode", err: errors.New(`ent: missing required field "ErrorPassthroughRule.match_mode"`)} + } + if v, ok := _c.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + if _, ok := _c.mutation.PassthroughCode(); !ok { + return &ValidationError{Name: "passthrough_code", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_code"`)} + } + if _, ok := _c.mutation.PassthroughBody(); !ok { + return &ValidationError{Name: "passthrough_body", err: errors.New(`ent: missing required field "ErrorPassthroughRule.passthrough_body"`)} + } + return nil +} + +func (_c *ErrorPassthroughRuleCreate) sqlSave(ctx context.Context) (*ErrorPassthroughRule, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *ErrorPassthroughRuleCreate) createSpec() (*ErrorPassthroughRule, *sqlgraph.CreateSpec) { + var ( + _node = &ErrorPassthroughRule{config: _c.config} + _spec = sqlgraph.NewCreateSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + _node.Name = value + } + if value, ok := _c.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + _node.Enabled = value + } + if value, ok := _c.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + _node.Priority = value + } + if value, ok := _c.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + _node.ErrorCodes = value + } + if value, ok := _c.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + _node.Keywords = value + } + if value, ok := _c.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + _node.MatchMode = value + } + if value, ok := _c.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + _node.Platforms = value + } + if value, ok := _c.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + _node.PassthroughCode = value + } + if value, ok := _c.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + _node.ResponseCode = &value + } + if value, ok := _c.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + _node.PassthroughBody = value + } + if value, ok := _c.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + _node.CustomMessage = &value + } + if value, ok := _c.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + _node.Description = &value + } + return _node, _spec +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ErrorPassthroughRule.Create(). +// SetCreatedAt(v). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ErrorPassthroughRuleUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreate) OnConflict(opts ...sql.ConflictOption) *ErrorPassthroughRuleUpsertOne { + _c.conflict = opts + return &ErrorPassthroughRuleUpsertOne{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreate) OnConflictColumns(columns ...string) *ErrorPassthroughRuleUpsertOne { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ErrorPassthroughRuleUpsertOne{ + create: _c, + } +} + +type ( + // ErrorPassthroughRuleUpsertOne is the builder for "upsert"-ing + // one ErrorPassthroughRule node. + ErrorPassthroughRuleUpsertOne struct { + create *ErrorPassthroughRuleCreate + } + + // ErrorPassthroughRuleUpsert is the "OnConflict" setter. + ErrorPassthroughRuleUpsert struct { + *sql.UpdateSet + } +) + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsert) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldUpdatedAt, v) + return u +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateUpdatedAt() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldUpdatedAt) + return u +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsert) SetName(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldName, v) + return u +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateName() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldName) + return u +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsert) SetEnabled(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldEnabled, v) + return u +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateEnabled() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldEnabled) + return u +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsert) SetPriority(v int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPriority, v) + return u +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePriority() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPriority) + return u +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsert) AddPriority(v int) *ErrorPassthroughRuleUpsert { + u.Add(errorpassthroughrule.FieldPriority, v) + return u +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsert) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldErrorCodes, v) + return u +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateErrorCodes() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldErrorCodes) + return u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsert) ClearErrorCodes() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldErrorCodes) + return u +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsert) SetKeywords(v []string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldKeywords, v) + return u +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateKeywords() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldKeywords) + return u +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsert) ClearKeywords() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldKeywords) + return u +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsert) SetMatchMode(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldMatchMode, v) + return u +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateMatchMode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldMatchMode) + return u +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsert) SetPlatforms(v []string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPlatforms, v) + return u +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePlatforms() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPlatforms) + return u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsert) ClearPlatforms() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldPlatforms) + return u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsert) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPassthroughCode, v) + return u +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePassthroughCode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPassthroughCode) + return u +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) SetResponseCode(v int) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldResponseCode, v) + return u +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateResponseCode() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldResponseCode) + return u +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) AddResponseCode(v int) *ErrorPassthroughRuleUpsert { + u.Add(errorpassthroughrule.FieldResponseCode, v) + return u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsert) ClearResponseCode() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldResponseCode) + return u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsert) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldPassthroughBody, v) + return u +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdatePassthroughBody() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldPassthroughBody) + return u +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsert) SetCustomMessage(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldCustomMessage, v) + return u +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateCustomMessage() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldCustomMessage) + return u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsert) ClearCustomMessage() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldCustomMessage) + return u +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsert) SetDescription(v string) *ErrorPassthroughRuleUpsert { + u.Set(errorpassthroughrule.FieldDescription, v) + return u +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsert) UpdateDescription() *ErrorPassthroughRuleUpsert { + u.SetExcluded(errorpassthroughrule.FieldDescription) + return u +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsert) ClearDescription() *ErrorPassthroughRuleUpsert { + u.SetNull(errorpassthroughrule.FieldDescription) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertOne) UpdateNewValues() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(errorpassthroughrule.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertOne) Ignore() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ErrorPassthroughRuleUpsertOne) DoNothing() *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ErrorPassthroughRuleCreate.OnConflict +// documentation for more info. +func (u *ErrorPassthroughRuleUpsertOne) Update(set func(*ErrorPassthroughRuleUpsert)) *ErrorPassthroughRuleUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ErrorPassthroughRuleUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsertOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateUpdatedAt() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsertOne) SetName(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateName() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsertOne) SetEnabled(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateEnabled() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateEnabled() + }) +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPriority(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPriority(v) + }) +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsertOne) AddPriority(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddPriority(v) + }) +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePriority() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePriority() + }) +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetErrorCodes(v) + }) +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateErrorCodes() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateErrorCodes() + }) +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearErrorCodes() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearErrorCodes() + }) +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsertOne) SetKeywords(v []string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetKeywords(v) + }) +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateKeywords() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateKeywords() + }) +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearKeywords() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearKeywords() + }) +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsertOne) SetMatchMode(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetMatchMode(v) + }) +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateMatchMode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateMatchMode() + }) +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPlatforms(v) + }) +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePlatforms() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePlatforms() + }) +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearPlatforms() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearPlatforms() + }) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughCode(v) + }) +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePassthroughCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughCode() + }) +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) SetResponseCode(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetResponseCode(v) + }) +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) AddResponseCode(v int) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddResponseCode(v) + }) +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateResponseCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateResponseCode() + }) +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearResponseCode() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearResponseCode() + }) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsertOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughBody(v) + }) +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdatePassthroughBody() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughBody() + }) +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetCustomMessage(v) + }) +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateCustomMessage() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateCustomMessage() + }) +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearCustomMessage() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearCustomMessage() + }) +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsertOne) SetDescription(v string) *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertOne) UpdateDescription() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsertOne) ClearDescription() *ErrorPassthroughRuleUpsertOne { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearDescription() + }) +} + +// Exec executes the query. +func (u *ErrorPassthroughRuleUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ErrorPassthroughRuleCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. +func (u *ErrorPassthroughRuleUpsertOne) ID(ctx context.Context) (id int64, err error) { + node, err := u.create.Save(ctx) + if err != nil { + return id, err + } + return node.ID, nil +} + +// IDX is like ID, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertOne) IDX(ctx context.Context) int64 { + id, err := u.ID(ctx) + if err != nil { + panic(err) + } + return id +} + +// ErrorPassthroughRuleCreateBulk is the builder for creating many ErrorPassthroughRule entities in bulk. +type ErrorPassthroughRuleCreateBulk struct { + config + err error + builders []*ErrorPassthroughRuleCreate + conflict []sql.ConflictOption +} + +// Save creates the ErrorPassthroughRule entities in the database. +func (_c *ErrorPassthroughRuleCreateBulk) Save(ctx context.Context) ([]*ErrorPassthroughRule, error) { + if _c.err != nil { + return nil, _c.err + } + specs := make([]*sqlgraph.CreateSpec, len(_c.builders)) + nodes := make([]*ErrorPassthroughRule, len(_c.builders)) + mutators := make([]Mutator, len(_c.builders)) + for i := range _c.builders { + func(i int, root context.Context) { + builder := _c.builders[i] + builder.defaults() + var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) { + mutation, ok := m.(*ErrorPassthroughRuleMutation) + if !ok { + return nil, fmt.Errorf("unexpected mutation type %T", m) + } + if err := builder.check(); err != nil { + return nil, err + } + builder.mutation = mutation + var err error + nodes[i], specs[i] = builder.createSpec() + if i < len(mutators)-1 { + _, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation) + } else { + spec := &sqlgraph.BatchCreateSpec{Nodes: specs} + spec.OnConflict = _c.conflict + // Invoke the actual operation on the latest mutation in the chain. + if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + } + } + if err != nil { + return nil, err + } + mutation.id = &nodes[i].ID + if specs[i].ID.Value != nil { + id := specs[i].ID.Value.(int64) + nodes[i].ID = int64(id) + } + mutation.done = true + return nodes[i], nil + }) + for i := len(builder.hooks) - 1; i >= 0; i-- { + mut = builder.hooks[i](mut) + } + mutators[i] = mut + }(i, ctx) + } + if len(mutators) > 0 { + if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil { + return nil, err + } + } + return nodes, nil +} + +// SaveX is like Save, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreateBulk) SaveX(ctx context.Context) []*ErrorPassthroughRule { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *ErrorPassthroughRuleCreateBulk) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *ErrorPassthroughRuleCreateBulk) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause +// of the `INSERT` statement. For example: +// +// client.ErrorPassthroughRule.CreateBulk(builders...). +// OnConflict( +// // Update the row with the new values +// // the was proposed for insertion. +// sql.ResolveWithNewValues(), +// ). +// // Override some of the fields with custom +// // update values. +// Update(func(u *ent.ErrorPassthroughRuleUpsert) { +// SetCreatedAt(v+v). +// }). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreateBulk) OnConflict(opts ...sql.ConflictOption) *ErrorPassthroughRuleUpsertBulk { + _c.conflict = opts + return &ErrorPassthroughRuleUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *ErrorPassthroughRuleCreateBulk) OnConflictColumns(columns ...string) *ErrorPassthroughRuleUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &ErrorPassthroughRuleUpsertBulk{ + create: _c, + } +} + +// ErrorPassthroughRuleUpsertBulk is the builder for "upsert"-ing +// a bulk of ErrorPassthroughRule nodes. +type ErrorPassthroughRuleUpsertBulk struct { + create *ErrorPassthroughRuleCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertBulk) UpdateNewValues() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(errorpassthroughrule.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.ErrorPassthroughRule.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *ErrorPassthroughRuleUpsertBulk) Ignore() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *ErrorPassthroughRuleUpsertBulk) DoNothing() *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the ErrorPassthroughRuleCreateBulk.OnConflict +// documentation for more info. +func (u *ErrorPassthroughRuleUpsertBulk) Update(set func(*ErrorPassthroughRuleUpsert)) *ErrorPassthroughRuleUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&ErrorPassthroughRuleUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateUpdatedAt() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetName sets the "name" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetName(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetName(v) + }) +} + +// UpdateName sets the "name" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateName() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateName() + }) +} + +// SetEnabled sets the "enabled" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetEnabled(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetEnabled(v) + }) +} + +// UpdateEnabled sets the "enabled" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateEnabled() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateEnabled() + }) +} + +// SetPriority sets the "priority" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPriority(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPriority(v) + }) +} + +// AddPriority adds v to the "priority" field. +func (u *ErrorPassthroughRuleUpsertBulk) AddPriority(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddPriority(v) + }) +} + +// UpdatePriority sets the "priority" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePriority() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePriority() + }) +} + +// SetErrorCodes sets the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetErrorCodes(v []int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetErrorCodes(v) + }) +} + +// UpdateErrorCodes sets the "error_codes" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateErrorCodes() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateErrorCodes() + }) +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearErrorCodes() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearErrorCodes() + }) +} + +// SetKeywords sets the "keywords" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetKeywords(v []string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetKeywords(v) + }) +} + +// UpdateKeywords sets the "keywords" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateKeywords() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateKeywords() + }) +} + +// ClearKeywords clears the value of the "keywords" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearKeywords() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearKeywords() + }) +} + +// SetMatchMode sets the "match_mode" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetMatchMode(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetMatchMode(v) + }) +} + +// UpdateMatchMode sets the "match_mode" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateMatchMode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateMatchMode() + }) +} + +// SetPlatforms sets the "platforms" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPlatforms(v []string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPlatforms(v) + }) +} + +// UpdatePlatforms sets the "platforms" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePlatforms() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePlatforms() + }) +} + +// ClearPlatforms clears the value of the "platforms" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearPlatforms() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearPlatforms() + }) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughCode(v) + }) +} + +// UpdatePassthroughCode sets the "passthrough_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePassthroughCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughCode() + }) +} + +// SetResponseCode sets the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetResponseCode(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetResponseCode(v) + }) +} + +// AddResponseCode adds v to the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) AddResponseCode(v int) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.AddResponseCode(v) + }) +} + +// UpdateResponseCode sets the "response_code" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateResponseCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateResponseCode() + }) +} + +// ClearResponseCode clears the value of the "response_code" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearResponseCode() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearResponseCode() + }) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetPassthroughBody(v) + }) +} + +// UpdatePassthroughBody sets the "passthrough_body" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdatePassthroughBody() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdatePassthroughBody() + }) +} + +// SetCustomMessage sets the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetCustomMessage(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetCustomMessage(v) + }) +} + +// UpdateCustomMessage sets the "custom_message" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateCustomMessage() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateCustomMessage() + }) +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearCustomMessage() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearCustomMessage() + }) +} + +// SetDescription sets the "description" field. +func (u *ErrorPassthroughRuleUpsertBulk) SetDescription(v string) *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.SetDescription(v) + }) +} + +// UpdateDescription sets the "description" field to the value that was provided on create. +func (u *ErrorPassthroughRuleUpsertBulk) UpdateDescription() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.UpdateDescription() + }) +} + +// ClearDescription clears the value of the "description" field. +func (u *ErrorPassthroughRuleUpsertBulk) ClearDescription() *ErrorPassthroughRuleUpsertBulk { + return u.Update(func(s *ErrorPassthroughRuleUpsert) { + s.ClearDescription() + }) +} + +// Exec executes the query. +func (u *ErrorPassthroughRuleUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the ErrorPassthroughRuleCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for ErrorPassthroughRuleCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *ErrorPassthroughRuleUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/errorpassthroughrule_delete.go b/backend/ent/errorpassthroughrule_delete.go new file mode 100644 index 00000000..943c7e2b --- /dev/null +++ b/backend/ent/errorpassthroughrule_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleDelete is the builder for deleting a ErrorPassthroughRule entity. +type ErrorPassthroughRuleDelete struct { + config + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleDelete builder. +func (_d *ErrorPassthroughRuleDelete) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDelete { + _d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query and returns how many vertices were deleted. +func (_d *ErrorPassthroughRuleDelete) Exec(ctx context.Context) (int, error) { + return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks) +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ErrorPassthroughRuleDelete) ExecX(ctx context.Context) int { + n, err := _d.Exec(ctx) + if err != nil { + panic(err) + } + return n +} + +func (_d *ErrorPassthroughRuleDelete) sqlExec(ctx context.Context) (int, error) { + _spec := sqlgraph.NewDeleteSpec(errorpassthroughrule.Table, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + if ps := _d.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec) + if err != nil && sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + _d.mutation.done = true + return affected, err +} + +// ErrorPassthroughRuleDeleteOne is the builder for deleting a single ErrorPassthroughRule entity. +type ErrorPassthroughRuleDeleteOne struct { + _d *ErrorPassthroughRuleDelete +} + +// Where appends a list predicates to the ErrorPassthroughRuleDelete builder. +func (_d *ErrorPassthroughRuleDeleteOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleDeleteOne { + _d._d.mutation.Where(ps...) + return _d +} + +// Exec executes the deletion query. +func (_d *ErrorPassthroughRuleDeleteOne) Exec(ctx context.Context) error { + n, err := _d._d.Exec(ctx) + switch { + case err != nil: + return err + case n == 0: + return &NotFoundError{errorpassthroughrule.Label} + default: + return nil + } +} + +// ExecX is like Exec, but panics if an error occurs. +func (_d *ErrorPassthroughRuleDeleteOne) ExecX(ctx context.Context) { + if err := _d.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/errorpassthroughrule_query.go b/backend/ent/errorpassthroughrule_query.go new file mode 100644 index 00000000..bfab5bd8 --- /dev/null +++ b/backend/ent/errorpassthroughrule_query.go @@ -0,0 +1,564 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "fmt" + "math" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleQuery is the builder for querying ErrorPassthroughRule entities. +type ErrorPassthroughRuleQuery struct { + config + ctx *QueryContext + order []errorpassthroughrule.OrderOption + inters []Interceptor + predicates []predicate.ErrorPassthroughRule + modifiers []func(*sql.Selector) + // intermediate query (i.e. traversal path). + sql *sql.Selector + path func(context.Context) (*sql.Selector, error) +} + +// Where adds a new predicate for the ErrorPassthroughRuleQuery builder. +func (_q *ErrorPassthroughRuleQuery) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *ErrorPassthroughRuleQuery) Limit(limit int) *ErrorPassthroughRuleQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *ErrorPassthroughRuleQuery) Offset(offset int) *ErrorPassthroughRuleQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *ErrorPassthroughRuleQuery) Unique(unique bool) *ErrorPassthroughRuleQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *ErrorPassthroughRuleQuery) Order(o ...errorpassthroughrule.OrderOption) *ErrorPassthroughRuleQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first ErrorPassthroughRule entity from the query. +// Returns a *NotFoundError when no ErrorPassthroughRule was found. +func (_q *ErrorPassthroughRuleQuery) First(ctx context.Context) (*ErrorPassthroughRule, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{errorpassthroughrule.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) FirstX(ctx context.Context) *ErrorPassthroughRule { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first ErrorPassthroughRule ID from the query. +// Returns a *NotFoundError when no ErrorPassthroughRule ID was found. +func (_q *ErrorPassthroughRuleQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{errorpassthroughrule.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single ErrorPassthroughRule entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one ErrorPassthroughRule entity is found. +// Returns a *NotFoundError when no ErrorPassthroughRule entities are found. +func (_q *ErrorPassthroughRuleQuery) Only(ctx context.Context) (*ErrorPassthroughRule, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{errorpassthroughrule.Label} + default: + return nil, &NotSingularError{errorpassthroughrule.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) OnlyX(ctx context.Context) *ErrorPassthroughRule { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only ErrorPassthroughRule ID in the query. +// Returns a *NotSingularError when more than one ErrorPassthroughRule ID is found. +// Returns a *NotFoundError when no entities are found. +func (_q *ErrorPassthroughRuleQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{errorpassthroughrule.Label} + default: + err = &NotSingularError{errorpassthroughrule.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of ErrorPassthroughRules. +func (_q *ErrorPassthroughRuleQuery) All(ctx context.Context) ([]*ErrorPassthroughRule, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*ErrorPassthroughRule, *ErrorPassthroughRuleQuery]() + return withInterceptors[[]*ErrorPassthroughRule](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) AllX(ctx context.Context) []*ErrorPassthroughRule { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of ErrorPassthroughRule IDs. +func (_q *ErrorPassthroughRuleQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(errorpassthroughrule.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *ErrorPassthroughRuleQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*ErrorPassthroughRuleQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *ErrorPassthroughRuleQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *ErrorPassthroughRuleQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the ErrorPassthroughRuleQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. +func (_q *ErrorPassthroughRuleQuery) Clone() *ErrorPassthroughRuleQuery { + if _q == nil { + return nil + } + return &ErrorPassthroughRuleQuery{ + config: _q.config, + ctx: _q.ctx.Clone(), + order: append([]errorpassthroughrule.OrderOption{}, _q.order...), + inters: append([]Interceptor{}, _q.inters...), + predicates: append([]predicate.ErrorPassthroughRule{}, _q.predicates...), + // clone intermediate query. + sql: _q.sql.Clone(), + path: _q.path, + } +} + +// GroupBy is used to group vertices by one or more fields/columns. +// It is often used with aggregate functions, like: count, max, mean, min, sum. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// Count int `json:"count,omitempty"` +// } +// +// client.ErrorPassthroughRule.Query(). +// GroupBy(errorpassthroughrule.FieldCreatedAt). +// Aggregate(ent.Count()). +// Scan(ctx, &v) +func (_q *ErrorPassthroughRuleQuery) GroupBy(field string, fields ...string) *ErrorPassthroughRuleGroupBy { + _q.ctx.Fields = append([]string{field}, fields...) + grbuild := &ErrorPassthroughRuleGroupBy{build: _q} + grbuild.flds = &_q.ctx.Fields + grbuild.label = errorpassthroughrule.Label + grbuild.scan = grbuild.Scan + return grbuild +} + +// Select allows the selection one or more fields/columns for the given query, +// instead of selecting all fields in the entity. +// +// Example: +// +// var v []struct { +// CreatedAt time.Time `json:"created_at,omitempty"` +// } +// +// client.ErrorPassthroughRule.Query(). +// Select(errorpassthroughrule.FieldCreatedAt). +// Scan(ctx, &v) +func (_q *ErrorPassthroughRuleQuery) Select(fields ...string) *ErrorPassthroughRuleSelect { + _q.ctx.Fields = append(_q.ctx.Fields, fields...) + sbuild := &ErrorPassthroughRuleSelect{ErrorPassthroughRuleQuery: _q} + sbuild.label = errorpassthroughrule.Label + sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan + return sbuild +} + +// Aggregate returns a ErrorPassthroughRuleSelect configured with the given aggregations. +func (_q *ErrorPassthroughRuleQuery) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect { + return _q.Select().Aggregate(fns...) +} + +func (_q *ErrorPassthroughRuleQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !errorpassthroughrule.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *ErrorPassthroughRuleQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*ErrorPassthroughRule, error) { + var ( + nodes = []*ErrorPassthroughRule{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*ErrorPassthroughRule).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &ErrorPassthroughRule{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *ErrorPassthroughRuleQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *ErrorPassthroughRuleQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID) + for i := range fields { + if fields[i] != errorpassthroughrule.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *ErrorPassthroughRuleQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(errorpassthroughrule.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = errorpassthroughrule.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) + } + if _q.ctx.Unique != nil && *_q.ctx.Unique { + selector.Distinct() + } + for _, m := range _q.modifiers { + m(selector) + } + for _, p := range _q.predicates { + p(selector) + } + for _, p := range _q.order { + p(selector) + } + if offset := _q.ctx.Offset; offset != nil { + // limit is mandatory for offset clause. We start + // with default value, and override it below if needed. + selector.Offset(*offset).Limit(math.MaxInt32) + } + if limit := _q.ctx.Limit; limit != nil { + selector.Limit(*limit) + } + return selector +} + +// ForUpdate locks the selected rows against concurrent updates, and prevent them from being +// updated, deleted or "selected ... for update" by other sessions, until the transaction is +// either committed or rolled-back. +func (_q *ErrorPassthroughRuleQuery) ForUpdate(opts ...sql.LockOption) *ErrorPassthroughRuleQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForUpdate(opts...) + }) + return _q +} + +// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock +// on any rows that are read. Other sessions can read the rows, but cannot modify them +// until your transaction commits. +func (_q *ErrorPassthroughRuleQuery) ForShare(opts ...sql.LockOption) *ErrorPassthroughRuleQuery { + if _q.driver.Dialect() == dialect.Postgres { + _q.Unique(false) + } + _q.modifiers = append(_q.modifiers, func(s *sql.Selector) { + s.ForShare(opts...) + }) + return _q +} + +// ErrorPassthroughRuleGroupBy is the group-by builder for ErrorPassthroughRule entities. +type ErrorPassthroughRuleGroupBy struct { + selector + build *ErrorPassthroughRuleQuery +} + +// Aggregate adds the given aggregation functions to the group-by query. +func (_g *ErrorPassthroughRuleGroupBy) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleGroupBy { + _g.fns = append(_g.fns, fns...) + return _g +} + +// Scan applies the selector query and scans the result into the given value. +func (_g *ErrorPassthroughRuleGroupBy) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy) + if err := _g.build.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleGroupBy](ctx, _g.build, _g, _g.build.inters, v) +} + +func (_g *ErrorPassthroughRuleGroupBy) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error { + selector := root.sqlQuery(ctx).Select() + aggregation := make([]string, 0, len(_g.fns)) + for _, fn := range _g.fns { + aggregation = append(aggregation, fn(selector)) + } + if len(selector.SelectedColumns()) == 0 { + columns := make([]string, 0, len(*_g.flds)+len(_g.fns)) + for _, f := range *_g.flds { + columns = append(columns, selector.C(f)) + } + columns = append(columns, aggregation...) + selector.Select(columns...) + } + selector.GroupBy(selector.Columns(*_g.flds...)...) + if err := selector.Err(); err != nil { + return err + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _g.build.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} + +// ErrorPassthroughRuleSelect is the builder for selecting fields of ErrorPassthroughRule entities. +type ErrorPassthroughRuleSelect struct { + *ErrorPassthroughRuleQuery + selector +} + +// Aggregate adds the given aggregation functions to the selector query. +func (_s *ErrorPassthroughRuleSelect) Aggregate(fns ...AggregateFunc) *ErrorPassthroughRuleSelect { + _s.fns = append(_s.fns, fns...) + return _s +} + +// Scan applies the selector query and scans the result into the given value. +func (_s *ErrorPassthroughRuleSelect) Scan(ctx context.Context, v any) error { + ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect) + if err := _s.prepareQuery(ctx); err != nil { + return err + } + return scanWithInterceptors[*ErrorPassthroughRuleQuery, *ErrorPassthroughRuleSelect](ctx, _s.ErrorPassthroughRuleQuery, _s, _s.inters, v) +} + +func (_s *ErrorPassthroughRuleSelect) sqlScan(ctx context.Context, root *ErrorPassthroughRuleQuery, v any) error { + selector := root.sqlQuery(ctx) + aggregation := make([]string, 0, len(_s.fns)) + for _, fn := range _s.fns { + aggregation = append(aggregation, fn(selector)) + } + switch n := len(*_s.selector.flds); { + case n == 0 && len(aggregation) > 0: + selector.Select(aggregation...) + case n != 0 && len(aggregation) > 0: + selector.AppendSelect(aggregation...) + } + rows := &sql.Rows{} + query, args := selector.Query() + if err := _s.driver.Query(ctx, query, args, rows); err != nil { + return err + } + defer rows.Close() + return sql.ScanSlice(rows, v) +} diff --git a/backend/ent/errorpassthroughrule_update.go b/backend/ent/errorpassthroughrule_update.go new file mode 100644 index 00000000..9d52aa49 --- /dev/null +++ b/backend/ent/errorpassthroughrule_update.go @@ -0,0 +1,823 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/dialect/sql/sqljson" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ErrorPassthroughRuleUpdate is the builder for updating ErrorPassthroughRule entities. +type ErrorPassthroughRuleUpdate struct { + config + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder. +func (_u *ErrorPassthroughRuleUpdate) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdate { + _u.mutation.Where(ps...) + return _u +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ErrorPassthroughRuleUpdate) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdate { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ErrorPassthroughRuleUpdate) SetName(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableName(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ErrorPassthroughRuleUpdate) SetEnabled(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetPriority sets the "priority" field. +func (_u *ErrorPassthroughRuleUpdate) SetPriority(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.ResetPriority() + _u.mutation.SetPriority(v) + return _u +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPriority(*v) + } + return _u +} + +// AddPriority adds value to the "priority" field. +func (_u *ErrorPassthroughRuleUpdate) AddPriority(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.AddPriority(v) + return _u +} + +// SetErrorCodes sets the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdate { + _u.mutation.SetErrorCodes(v) + return _u +} + +// AppendErrorCodes appends value to the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendErrorCodes(v) + return _u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdate) ClearErrorCodes() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearErrorCodes() + return _u +} + +// SetKeywords sets the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) SetKeywords(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetKeywords(v) + return _u +} + +// AppendKeywords appends value to the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) AppendKeywords(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendKeywords(v) + return _u +} + +// ClearKeywords clears the value of the "keywords" field. +func (_u *ErrorPassthroughRuleUpdate) ClearKeywords() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearKeywords() + return _u +} + +// SetMatchMode sets the "match_mode" field. +func (_u *ErrorPassthroughRuleUpdate) SetMatchMode(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetMatchMode(v) + return _u +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetMatchMode(*v) + } + return _u +} + +// SetPlatforms sets the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) SetPlatforms(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPlatforms(v) + return _u +} + +// AppendPlatforms appends value to the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdate { + _u.mutation.AppendPlatforms(v) + return _u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (_u *ErrorPassthroughRuleUpdate) ClearPlatforms() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearPlatforms() + return _u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_u *ErrorPassthroughRuleUpdate) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPassthroughCode(v) + return _u +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPassthroughCode(*v) + } + return _u +} + +// SetResponseCode sets the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) SetResponseCode(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.ResetResponseCode() + _u.mutation.SetResponseCode(v) + return _u +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetResponseCode(*v) + } + return _u +} + +// AddResponseCode adds value to the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) AddResponseCode(v int) *ErrorPassthroughRuleUpdate { + _u.mutation.AddResponseCode(v) + return _u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (_u *ErrorPassthroughRuleUpdate) ClearResponseCode() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearResponseCode() + return _u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_u *ErrorPassthroughRuleUpdate) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdate { + _u.mutation.SetPassthroughBody(v) + return _u +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetPassthroughBody(*v) + } + return _u +} + +// SetCustomMessage sets the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdate) SetCustomMessage(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetCustomMessage(v) + return _u +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetCustomMessage(*v) + } + return _u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdate) ClearCustomMessage() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearCustomMessage() + return _u +} + +// SetDescription sets the "description" field. +func (_u *ErrorPassthroughRuleUpdate) SetDescription(v string) *ErrorPassthroughRuleUpdate { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdate) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdate { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ErrorPassthroughRuleUpdate) ClearDescription() *ErrorPassthroughRuleUpdate { + _u.mutation.ClearDescription() + return _u +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_u *ErrorPassthroughRuleUpdate) Mutation() *ErrorPassthroughRuleMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *ErrorPassthroughRuleUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *ErrorPassthroughRuleUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ErrorPassthroughRuleUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ErrorPassthroughRuleUpdate) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if v, ok := _u.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + return nil +} + +func (_u *ErrorPassthroughRuleUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPriority(); ok { + _spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedErrorCodes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value) + }) + } + if _u.mutation.ErrorCodesCleared() { + _spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON) + } + if value, ok := _u.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeywords(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldKeywords, value) + }) + } + if _u.mutation.KeywordsCleared() { + _spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON) + } + if value, ok := _u.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + } + if value, ok := _u.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPlatforms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value) + }) + } + if _u.mutation.PlatformsCleared() { + _spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON) + } + if value, ok := _u.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + } + if value, ok := _u.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseCode(); ok { + _spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if _u.mutation.ResponseCodeCleared() { + _spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt) + } + if value, ok := _u.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + } + if value, ok := _u.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + } + if _u.mutation.CustomMessageCleared() { + _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{errorpassthroughrule.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// ErrorPassthroughRuleUpdateOne is the builder for updating a single ErrorPassthroughRule entity. +type ErrorPassthroughRuleUpdateOne struct { + config + fields []string + hooks []Hook + mutation *ErrorPassthroughRuleMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetUpdatedAt(v time.Time) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetName sets the "name" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetName(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetName(v) + return _u +} + +// SetNillableName sets the "name" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableName(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetName(*v) + } + return _u +} + +// SetEnabled sets the "enabled" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetEnabled(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetEnabled(v) + return _u +} + +// SetNillableEnabled sets the "enabled" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableEnabled(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetEnabled(*v) + } + return _u +} + +// SetPriority sets the "priority" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPriority(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.ResetPriority() + _u.mutation.SetPriority(v) + return _u +} + +// SetNillablePriority sets the "priority" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePriority(v *int) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPriority(*v) + } + return _u +} + +// AddPriority adds value to the "priority" field. +func (_u *ErrorPassthroughRuleUpdateOne) AddPriority(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AddPriority(v) + return _u +} + +// SetErrorCodes sets the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetErrorCodes(v) + return _u +} + +// AppendErrorCodes appends value to the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendErrorCodes(v []int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendErrorCodes(v) + return _u +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearErrorCodes() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearErrorCodes() + return _u +} + +// SetKeywords sets the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetKeywords(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetKeywords(v) + return _u +} + +// AppendKeywords appends value to the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendKeywords(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendKeywords(v) + return _u +} + +// ClearKeywords clears the value of the "keywords" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearKeywords() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearKeywords() + return _u +} + +// SetMatchMode sets the "match_mode" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetMatchMode(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetMatchMode(v) + return _u +} + +// SetNillableMatchMode sets the "match_mode" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableMatchMode(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetMatchMode(*v) + } + return _u +} + +// SetPlatforms sets the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPlatforms(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPlatforms(v) + return _u +} + +// AppendPlatforms appends value to the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) AppendPlatforms(v []string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AppendPlatforms(v) + return _u +} + +// ClearPlatforms clears the value of the "platforms" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearPlatforms() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearPlatforms() + return _u +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughCode(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPassthroughCode(v) + return _u +} + +// SetNillablePassthroughCode sets the "passthrough_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughCode(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPassthroughCode(*v) + } + return _u +} + +// SetResponseCode sets the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetResponseCode(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.ResetResponseCode() + _u.mutation.SetResponseCode(v) + return _u +} + +// SetNillableResponseCode sets the "response_code" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableResponseCode(v *int) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetResponseCode(*v) + } + return _u +} + +// AddResponseCode adds value to the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) AddResponseCode(v int) *ErrorPassthroughRuleUpdateOne { + _u.mutation.AddResponseCode(v) + return _u +} + +// ClearResponseCode clears the value of the "response_code" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearResponseCode() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearResponseCode() + return _u +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetPassthroughBody(v bool) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetPassthroughBody(v) + return _u +} + +// SetNillablePassthroughBody sets the "passthrough_body" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillablePassthroughBody(v *bool) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetPassthroughBody(*v) + } + return _u +} + +// SetCustomMessage sets the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetCustomMessage(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetCustomMessage(v) + return _u +} + +// SetNillableCustomMessage sets the "custom_message" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableCustomMessage(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetCustomMessage(*v) + } + return _u +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearCustomMessage() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearCustomMessage() + return _u +} + +// SetDescription sets the "description" field. +func (_u *ErrorPassthroughRuleUpdateOne) SetDescription(v string) *ErrorPassthroughRuleUpdateOne { + _u.mutation.SetDescription(v) + return _u +} + +// SetNillableDescription sets the "description" field if the given value is not nil. +func (_u *ErrorPassthroughRuleUpdateOne) SetNillableDescription(v *string) *ErrorPassthroughRuleUpdateOne { + if v != nil { + _u.SetDescription(*v) + } + return _u +} + +// ClearDescription clears the value of the "description" field. +func (_u *ErrorPassthroughRuleUpdateOne) ClearDescription() *ErrorPassthroughRuleUpdateOne { + _u.mutation.ClearDescription() + return _u +} + +// Mutation returns the ErrorPassthroughRuleMutation object of the builder. +func (_u *ErrorPassthroughRuleUpdateOne) Mutation() *ErrorPassthroughRuleMutation { + return _u.mutation +} + +// Where appends a list predicates to the ErrorPassthroughRuleUpdate builder. +func (_u *ErrorPassthroughRuleUpdateOne) Where(ps ...predicate.ErrorPassthroughRule) *ErrorPassthroughRuleUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *ErrorPassthroughRuleUpdateOne) Select(field string, fields ...string) *ErrorPassthroughRuleUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated ErrorPassthroughRule entity. +func (_u *ErrorPassthroughRuleUpdateOne) Save(ctx context.Context) (*ErrorPassthroughRule, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdateOne) SaveX(ctx context.Context) *ErrorPassthroughRule { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *ErrorPassthroughRuleUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *ErrorPassthroughRuleUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *ErrorPassthroughRuleUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := errorpassthroughrule.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *ErrorPassthroughRuleUpdateOne) check() error { + if v, ok := _u.mutation.Name(); ok { + if err := errorpassthroughrule.NameValidator(v); err != nil { + return &ValidationError{Name: "name", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.name": %w`, err)} + } + } + if v, ok := _u.mutation.MatchMode(); ok { + if err := errorpassthroughrule.MatchModeValidator(v); err != nil { + return &ValidationError{Name: "match_mode", err: fmt.Errorf(`ent: validator failed for field "ErrorPassthroughRule.match_mode": %w`, err)} + } + } + return nil +} + +func (_u *ErrorPassthroughRuleUpdateOne) sqlSave(ctx context.Context) (_node *ErrorPassthroughRule, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(errorpassthroughrule.Table, errorpassthroughrule.Columns, sqlgraph.NewFieldSpec(errorpassthroughrule.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "ErrorPassthroughRule.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, errorpassthroughrule.FieldID) + for _, f := range fields { + if !errorpassthroughrule.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != errorpassthroughrule.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(errorpassthroughrule.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Name(); ok { + _spec.SetField(errorpassthroughrule.FieldName, field.TypeString, value) + } + if value, ok := _u.mutation.Enabled(); ok { + _spec.SetField(errorpassthroughrule.FieldEnabled, field.TypeBool, value) + } + if value, ok := _u.mutation.Priority(); ok { + _spec.SetField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedPriority(); ok { + _spec.AddField(errorpassthroughrule.FieldPriority, field.TypeInt, value) + } + if value, ok := _u.mutation.ErrorCodes(); ok { + _spec.SetField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedErrorCodes(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldErrorCodes, value) + }) + } + if _u.mutation.ErrorCodesCleared() { + _spec.ClearField(errorpassthroughrule.FieldErrorCodes, field.TypeJSON) + } + if value, ok := _u.mutation.Keywords(); ok { + _spec.SetField(errorpassthroughrule.FieldKeywords, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedKeywords(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldKeywords, value) + }) + } + if _u.mutation.KeywordsCleared() { + _spec.ClearField(errorpassthroughrule.FieldKeywords, field.TypeJSON) + } + if value, ok := _u.mutation.MatchMode(); ok { + _spec.SetField(errorpassthroughrule.FieldMatchMode, field.TypeString, value) + } + if value, ok := _u.mutation.Platforms(); ok { + _spec.SetField(errorpassthroughrule.FieldPlatforms, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedPlatforms(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, errorpassthroughrule.FieldPlatforms, value) + }) + } + if _u.mutation.PlatformsCleared() { + _spec.ClearField(errorpassthroughrule.FieldPlatforms, field.TypeJSON) + } + if value, ok := _u.mutation.PassthroughCode(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughCode, field.TypeBool, value) + } + if value, ok := _u.mutation.ResponseCode(); ok { + _spec.SetField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedResponseCode(); ok { + _spec.AddField(errorpassthroughrule.FieldResponseCode, field.TypeInt, value) + } + if _u.mutation.ResponseCodeCleared() { + _spec.ClearField(errorpassthroughrule.FieldResponseCode, field.TypeInt) + } + if value, ok := _u.mutation.PassthroughBody(); ok { + _spec.SetField(errorpassthroughrule.FieldPassthroughBody, field.TypeBool, value) + } + if value, ok := _u.mutation.CustomMessage(); ok { + _spec.SetField(errorpassthroughrule.FieldCustomMessage, field.TypeString, value) + } + if _u.mutation.CustomMessageCleared() { + _spec.ClearField(errorpassthroughrule.FieldCustomMessage, field.TypeString) + } + if value, ok := _u.mutation.Description(); ok { + _spec.SetField(errorpassthroughrule.FieldDescription, field.TypeString, value) + } + if _u.mutation.DescriptionCleared() { + _spec.ClearField(errorpassthroughrule.FieldDescription, field.TypeString) + } + _node = &ErrorPassthroughRule{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{errorpassthroughrule.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go index 1e653c77..1b15685c 100644 --- a/backend/ent/hook/hook.go +++ b/backend/ent/hook/hook.go @@ -69,6 +69,18 @@ func (f AnnouncementReadFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.V return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.AnnouncementReadMutation", m) } +// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary +// function as ErrorPassthroughRule mutator. +type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleMutation) (ent.Value, error) + +// Mutate calls f(ctx, m). +func (f ErrorPassthroughRuleFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) { + if mv, ok := m.(*ent.ErrorPassthroughRuleMutation); ok { + return f(ctx, mv) + } + return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.ErrorPassthroughRuleMutation", m) +} + // The GroupFunc type is an adapter to allow the use of ordinary // function as Group mutator. type GroupFunc func(context.Context, *ent.GroupMutation) (ent.Value, error) diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go index a37be48f..8ee42db3 100644 --- a/backend/ent/intercept/intercept.go +++ b/backend/ent/intercept/intercept.go @@ -13,6 +13,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" @@ -220,6 +221,33 @@ func (f TraverseAnnouncementRead) Traverse(ctx context.Context, q ent.Query) err return fmt.Errorf("unexpected query type %T. expect *ent.AnnouncementReadQuery", q) } +// The ErrorPassthroughRuleFunc type is an adapter to allow the use of ordinary function as a Querier. +type ErrorPassthroughRuleFunc func(context.Context, *ent.ErrorPassthroughRuleQuery) (ent.Value, error) + +// Query calls f(ctx, q). +func (f ErrorPassthroughRuleFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) { + if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok { + return f(ctx, q) + } + return nil, fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q) +} + +// The TraverseErrorPassthroughRule type is an adapter to allow the use of ordinary function as Traverser. +type TraverseErrorPassthroughRule func(context.Context, *ent.ErrorPassthroughRuleQuery) error + +// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline. +func (f TraverseErrorPassthroughRule) Intercept(next ent.Querier) ent.Querier { + return next +} + +// Traverse calls f(ctx, q). +func (f TraverseErrorPassthroughRule) Traverse(ctx context.Context, q ent.Query) error { + if q, ok := q.(*ent.ErrorPassthroughRuleQuery); ok { + return f(ctx, q) + } + return fmt.Errorf("unexpected query type %T. expect *ent.ErrorPassthroughRuleQuery", q) +} + // The GroupFunc type is an adapter to allow the use of ordinary function as a Querier. type GroupFunc func(context.Context, *ent.GroupQuery) (ent.Value, error) @@ -584,6 +612,8 @@ func NewQuery(q ent.Query) (Query, error) { return &query[*ent.AnnouncementQuery, predicate.Announcement, announcement.OrderOption]{typ: ent.TypeAnnouncement, tq: q}, nil case *ent.AnnouncementReadQuery: return &query[*ent.AnnouncementReadQuery, predicate.AnnouncementRead, announcementread.OrderOption]{typ: ent.TypeAnnouncementRead, tq: q}, nil + case *ent.ErrorPassthroughRuleQuery: + return &query[*ent.ErrorPassthroughRuleQuery, predicate.ErrorPassthroughRule, errorpassthroughrule.OrderOption]{typ: ent.TypeErrorPassthroughRule, tq: q}, nil case *ent.GroupQuery: return &query[*ent.GroupQuery, predicate.Group, group.OrderOption]{typ: ent.TypeGroup, tq: q}, nil case *ent.PromoCodeQuery: diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index dc91f6a5..f9e90d73 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -309,6 +309,42 @@ var ( }, }, } + // ErrorPassthroughRulesColumns holds the columns for the "error_passthrough_rules" table. + ErrorPassthroughRulesColumns = []*schema.Column{ + {Name: "id", Type: field.TypeInt64, Increment: true}, + {Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}}, + {Name: "name", Type: field.TypeString, Size: 100}, + {Name: "enabled", Type: field.TypeBool, Default: true}, + {Name: "priority", Type: field.TypeInt, Default: 0}, + {Name: "error_codes", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "keywords", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "match_mode", Type: field.TypeString, Size: 10, Default: "any"}, + {Name: "platforms", Type: field.TypeJSON, Nullable: true, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "passthrough_code", Type: field.TypeBool, Default: true}, + {Name: "response_code", Type: field.TypeInt, Nullable: true}, + {Name: "passthrough_body", Type: field.TypeBool, Default: true}, + {Name: "custom_message", Type: field.TypeString, Nullable: true, Size: 2147483647}, + {Name: "description", Type: field.TypeString, Nullable: true, Size: 2147483647}, + } + // ErrorPassthroughRulesTable holds the schema information for the "error_passthrough_rules" table. + ErrorPassthroughRulesTable = &schema.Table{ + Name: "error_passthrough_rules", + Columns: ErrorPassthroughRulesColumns, + PrimaryKey: []*schema.Column{ErrorPassthroughRulesColumns[0]}, + Indexes: []*schema.Index{ + { + Name: "errorpassthroughrule_enabled", + Unique: false, + Columns: []*schema.Column{ErrorPassthroughRulesColumns[4]}, + }, + { + Name: "errorpassthroughrule_priority", + Unique: false, + Columns: []*schema.Column{ErrorPassthroughRulesColumns[5]}, + }, + }, + } // GroupsColumns holds the columns for the "groups" table. GroupsColumns = []*schema.Column{ {Name: "id", Type: field.TypeInt64, Increment: true}, @@ -950,6 +986,7 @@ var ( AccountGroupsTable, AnnouncementsTable, AnnouncementReadsTable, + ErrorPassthroughRulesTable, GroupsTable, PromoCodesTable, PromoCodeUsagesTable, @@ -989,6 +1026,9 @@ func init() { AnnouncementReadsTable.Annotation = &entsql.Annotation{ Table: "announcement_reads", } + ErrorPassthroughRulesTable.Annotation = &entsql.Annotation{ + Table: "error_passthrough_rules", + } GroupsTable.Annotation = &entsql.Annotation{ Table: "groups", } diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 77d208e1..5c182dea 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -17,6 +17,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/predicate" "github.com/Wei-Shaw/sub2api/ent/promocode" @@ -48,6 +49,7 @@ const ( TypeAccountGroup = "AccountGroup" TypeAnnouncement = "Announcement" TypeAnnouncementRead = "AnnouncementRead" + TypeErrorPassthroughRule = "ErrorPassthroughRule" TypeGroup = "Group" TypePromoCode = "PromoCode" TypePromoCodeUsage = "PromoCodeUsage" @@ -5750,6 +5752,1272 @@ func (m *AnnouncementReadMutation) ResetEdge(name string) error { return fmt.Errorf("unknown AnnouncementRead edge %s", name) } +// ErrorPassthroughRuleMutation represents an operation that mutates the ErrorPassthroughRule nodes in the graph. +type ErrorPassthroughRuleMutation struct { + config + op Op + typ string + id *int64 + created_at *time.Time + updated_at *time.Time + name *string + enabled *bool + priority *int + addpriority *int + error_codes *[]int + appenderror_codes []int + keywords *[]string + appendkeywords []string + match_mode *string + platforms *[]string + appendplatforms []string + passthrough_code *bool + response_code *int + addresponse_code *int + passthrough_body *bool + custom_message *string + description *string + clearedFields map[string]struct{} + done bool + oldValue func(context.Context) (*ErrorPassthroughRule, error) + predicates []predicate.ErrorPassthroughRule +} + +var _ ent.Mutation = (*ErrorPassthroughRuleMutation)(nil) + +// errorpassthroughruleOption allows management of the mutation configuration using functional options. +type errorpassthroughruleOption func(*ErrorPassthroughRuleMutation) + +// newErrorPassthroughRuleMutation creates new mutation for the ErrorPassthroughRule entity. +func newErrorPassthroughRuleMutation(c config, op Op, opts ...errorpassthroughruleOption) *ErrorPassthroughRuleMutation { + m := &ErrorPassthroughRuleMutation{ + config: c, + op: op, + typ: TypeErrorPassthroughRule, + clearedFields: make(map[string]struct{}), + } + for _, opt := range opts { + opt(m) + } + return m +} + +// withErrorPassthroughRuleID sets the ID field of the mutation. +func withErrorPassthroughRuleID(id int64) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + var ( + err error + once sync.Once + value *ErrorPassthroughRule + ) + m.oldValue = func(ctx context.Context) (*ErrorPassthroughRule, error) { + once.Do(func() { + if m.done { + err = errors.New("querying old values post mutation is not allowed") + } else { + value, err = m.Client().ErrorPassthroughRule.Get(ctx, id) + } + }) + return value, err + } + m.id = &id + } +} + +// withErrorPassthroughRule sets the old ErrorPassthroughRule of the mutation. +func withErrorPassthroughRule(node *ErrorPassthroughRule) errorpassthroughruleOption { + return func(m *ErrorPassthroughRuleMutation) { + m.oldValue = func(context.Context) (*ErrorPassthroughRule, error) { + return node, nil + } + m.id = &node.ID + } +} + +// Client returns a new `ent.Client` from the mutation. If the mutation was +// executed in a transaction (ent.Tx), a transactional client is returned. +func (m ErrorPassthroughRuleMutation) Client() *Client { + client := &Client{config: m.config} + client.init() + return client +} + +// Tx returns an `ent.Tx` for mutations that were executed in transactions; +// it returns an error otherwise. +func (m ErrorPassthroughRuleMutation) Tx() (*Tx, error) { + if _, ok := m.driver.(*txDriver); !ok { + return nil, errors.New("ent: mutation is not running in a transaction") + } + tx := &Tx{config: m.config} + tx.init() + return tx, nil +} + +// ID returns the ID value in the mutation. Note that the ID is only available +// if it was provided to the builder or after it was returned from the database. +func (m *ErrorPassthroughRuleMutation) ID() (id int64, exists bool) { + if m.id == nil { + return + } + return *m.id, true +} + +// IDs queries the database and returns the entity ids that match the mutation's predicate. +// That means, if the mutation is applied within a transaction with an isolation level such +// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated +// or updated by the mutation. +func (m *ErrorPassthroughRuleMutation) IDs(ctx context.Context) ([]int64, error) { + switch { + case m.op.Is(OpUpdateOne | OpDeleteOne): + id, exists := m.ID() + if exists { + return []int64{id}, nil + } + fallthrough + case m.op.Is(OpUpdate | OpDelete): + return m.Client().ErrorPassthroughRule.Query().Where(m.predicates...).IDs(ctx) + default: + return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op) + } +} + +// SetCreatedAt sets the "created_at" field. +func (m *ErrorPassthroughRuleMutation) SetCreatedAt(t time.Time) { + m.created_at = &t +} + +// CreatedAt returns the value of the "created_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CreatedAt() (r time.Time, exists bool) { + v := m.created_at + if v == nil { + return + } + return *v, true +} + +// OldCreatedAt returns the old "created_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCreatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err) + } + return oldValue.CreatedAt, nil +} + +// ResetCreatedAt resets all changes to the "created_at" field. +func (m *ErrorPassthroughRuleMutation) ResetCreatedAt() { + m.created_at = nil +} + +// SetUpdatedAt sets the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) SetUpdatedAt(t time.Time) { + m.updated_at = &t +} + +// UpdatedAt returns the value of the "updated_at" field in the mutation. +func (m *ErrorPassthroughRuleMutation) UpdatedAt() (r time.Time, exists bool) { + v := m.updated_at + if v == nil { + return + } + return *v, true +} + +// OldUpdatedAt returns the old "updated_at" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldUpdatedAt requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err) + } + return oldValue.UpdatedAt, nil +} + +// ResetUpdatedAt resets all changes to the "updated_at" field. +func (m *ErrorPassthroughRuleMutation) ResetUpdatedAt() { + m.updated_at = nil +} + +// SetName sets the "name" field. +func (m *ErrorPassthroughRuleMutation) SetName(s string) { + m.name = &s +} + +// Name returns the value of the "name" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Name() (r string, exists bool) { + v := m.name + if v == nil { + return + } + return *v, true +} + +// OldName returns the old "name" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldName(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldName is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldName requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldName: %w", err) + } + return oldValue.Name, nil +} + +// ResetName resets all changes to the "name" field. +func (m *ErrorPassthroughRuleMutation) ResetName() { + m.name = nil +} + +// SetEnabled sets the "enabled" field. +func (m *ErrorPassthroughRuleMutation) SetEnabled(b bool) { + m.enabled = &b +} + +// Enabled returns the value of the "enabled" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Enabled() (r bool, exists bool) { + v := m.enabled + if v == nil { + return + } + return *v, true +} + +// OldEnabled returns the old "enabled" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldEnabled(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldEnabled is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldEnabled requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldEnabled: %w", err) + } + return oldValue.Enabled, nil +} + +// ResetEnabled resets all changes to the "enabled" field. +func (m *ErrorPassthroughRuleMutation) ResetEnabled() { + m.enabled = nil +} + +// SetPriority sets the "priority" field. +func (m *ErrorPassthroughRuleMutation) SetPriority(i int) { + m.priority = &i + m.addpriority = nil +} + +// Priority returns the value of the "priority" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Priority() (r int, exists bool) { + v := m.priority + if v == nil { + return + } + return *v, true +} + +// OldPriority returns the old "priority" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPriority(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPriority is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPriority requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPriority: %w", err) + } + return oldValue.Priority, nil +} + +// AddPriority adds i to the "priority" field. +func (m *ErrorPassthroughRuleMutation) AddPriority(i int) { + if m.addpriority != nil { + *m.addpriority += i + } else { + m.addpriority = &i + } +} + +// AddedPriority returns the value that was added to the "priority" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedPriority() (r int, exists bool) { + v := m.addpriority + if v == nil { + return + } + return *v, true +} + +// ResetPriority resets all changes to the "priority" field. +func (m *ErrorPassthroughRuleMutation) ResetPriority() { + m.priority = nil + m.addpriority = nil +} + +// SetErrorCodes sets the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) SetErrorCodes(i []int) { + m.error_codes = &i + m.appenderror_codes = nil +} + +// ErrorCodes returns the value of the "error_codes" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodes() (r []int, exists bool) { + v := m.error_codes + if v == nil { + return + } + return *v, true +} + +// OldErrorCodes returns the old "error_codes" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldErrorCodes(ctx context.Context) (v []int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldErrorCodes is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldErrorCodes requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldErrorCodes: %w", err) + } + return oldValue.ErrorCodes, nil +} + +// AppendErrorCodes adds i to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) AppendErrorCodes(i []int) { + m.appenderror_codes = append(m.appenderror_codes, i...) +} + +// AppendedErrorCodes returns the list of values that were appended to the "error_codes" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedErrorCodes() ([]int, bool) { + if len(m.appenderror_codes) == 0 { + return nil, false + } + return m.appenderror_codes, true +} + +// ClearErrorCodes clears the value of the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ClearErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + m.clearedFields[errorpassthroughrule.FieldErrorCodes] = struct{}{} +} + +// ErrorCodesCleared returns if the "error_codes" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ErrorCodesCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldErrorCodes] + return ok +} + +// ResetErrorCodes resets all changes to the "error_codes" field. +func (m *ErrorPassthroughRuleMutation) ResetErrorCodes() { + m.error_codes = nil + m.appenderror_codes = nil + delete(m.clearedFields, errorpassthroughrule.FieldErrorCodes) +} + +// SetKeywords sets the "keywords" field. +func (m *ErrorPassthroughRuleMutation) SetKeywords(s []string) { + m.keywords = &s + m.appendkeywords = nil +} + +// Keywords returns the value of the "keywords" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Keywords() (r []string, exists bool) { + v := m.keywords + if v == nil { + return + } + return *v, true +} + +// OldKeywords returns the old "keywords" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldKeywords(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldKeywords is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldKeywords requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldKeywords: %w", err) + } + return oldValue.Keywords, nil +} + +// AppendKeywords adds s to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) AppendKeywords(s []string) { + m.appendkeywords = append(m.appendkeywords, s...) +} + +// AppendedKeywords returns the list of values that were appended to the "keywords" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedKeywords() ([]string, bool) { + if len(m.appendkeywords) == 0 { + return nil, false + } + return m.appendkeywords, true +} + +// ClearKeywords clears the value of the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ClearKeywords() { + m.keywords = nil + m.appendkeywords = nil + m.clearedFields[errorpassthroughrule.FieldKeywords] = struct{}{} +} + +// KeywordsCleared returns if the "keywords" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) KeywordsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldKeywords] + return ok +} + +// ResetKeywords resets all changes to the "keywords" field. +func (m *ErrorPassthroughRuleMutation) ResetKeywords() { + m.keywords = nil + m.appendkeywords = nil + delete(m.clearedFields, errorpassthroughrule.FieldKeywords) +} + +// SetMatchMode sets the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) SetMatchMode(s string) { + m.match_mode = &s +} + +// MatchMode returns the value of the "match_mode" field in the mutation. +func (m *ErrorPassthroughRuleMutation) MatchMode() (r string, exists bool) { + v := m.match_mode + if v == nil { + return + } + return *v, true +} + +// OldMatchMode returns the old "match_mode" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldMatchMode(ctx context.Context) (v string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldMatchMode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldMatchMode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldMatchMode: %w", err) + } + return oldValue.MatchMode, nil +} + +// ResetMatchMode resets all changes to the "match_mode" field. +func (m *ErrorPassthroughRuleMutation) ResetMatchMode() { + m.match_mode = nil +} + +// SetPlatforms sets the "platforms" field. +func (m *ErrorPassthroughRuleMutation) SetPlatforms(s []string) { + m.platforms = &s + m.appendplatforms = nil +} + +// Platforms returns the value of the "platforms" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Platforms() (r []string, exists bool) { + v := m.platforms + if v == nil { + return + } + return *v, true +} + +// OldPlatforms returns the old "platforms" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPlatforms(ctx context.Context) (v []string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPlatforms is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPlatforms requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPlatforms: %w", err) + } + return oldValue.Platforms, nil +} + +// AppendPlatforms adds s to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) AppendPlatforms(s []string) { + m.appendplatforms = append(m.appendplatforms, s...) +} + +// AppendedPlatforms returns the list of values that were appended to the "platforms" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AppendedPlatforms() ([]string, bool) { + if len(m.appendplatforms) == 0 { + return nil, false + } + return m.appendplatforms, true +} + +// ClearPlatforms clears the value of the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ClearPlatforms() { + m.platforms = nil + m.appendplatforms = nil + m.clearedFields[errorpassthroughrule.FieldPlatforms] = struct{}{} +} + +// PlatformsCleared returns if the "platforms" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) PlatformsCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldPlatforms] + return ok +} + +// ResetPlatforms resets all changes to the "platforms" field. +func (m *ErrorPassthroughRuleMutation) ResetPlatforms() { + m.platforms = nil + m.appendplatforms = nil + delete(m.clearedFields, errorpassthroughrule.FieldPlatforms) +} + +// SetPassthroughCode sets the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughCode(b bool) { + m.passthrough_code = &b +} + +// PassthroughCode returns the value of the "passthrough_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughCode() (r bool, exists bool) { + v := m.passthrough_code + if v == nil { + return + } + return *v, true +} + +// OldPassthroughCode returns the old "passthrough_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPassthroughCode(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassthroughCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassthroughCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassthroughCode: %w", err) + } + return oldValue.PassthroughCode, nil +} + +// ResetPassthroughCode resets all changes to the "passthrough_code" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughCode() { + m.passthrough_code = nil +} + +// SetResponseCode sets the "response_code" field. +func (m *ErrorPassthroughRuleMutation) SetResponseCode(i int) { + m.response_code = &i + m.addresponse_code = nil +} + +// ResponseCode returns the value of the "response_code" field in the mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCode() (r int, exists bool) { + v := m.response_code + if v == nil { + return + } + return *v, true +} + +// OldResponseCode returns the old "response_code" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldResponseCode(ctx context.Context) (v *int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldResponseCode is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldResponseCode requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldResponseCode: %w", err) + } + return oldValue.ResponseCode, nil +} + +// AddResponseCode adds i to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) AddResponseCode(i int) { + if m.addresponse_code != nil { + *m.addresponse_code += i + } else { + m.addresponse_code = &i + } +} + +// AddedResponseCode returns the value that was added to the "response_code" field in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedResponseCode() (r int, exists bool) { + v := m.addresponse_code + if v == nil { + return + } + return *v, true +} + +// ClearResponseCode clears the value of the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ClearResponseCode() { + m.response_code = nil + m.addresponse_code = nil + m.clearedFields[errorpassthroughrule.FieldResponseCode] = struct{}{} +} + +// ResponseCodeCleared returns if the "response_code" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ResponseCodeCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldResponseCode] + return ok +} + +// ResetResponseCode resets all changes to the "response_code" field. +func (m *ErrorPassthroughRuleMutation) ResetResponseCode() { + m.response_code = nil + m.addresponse_code = nil + delete(m.clearedFields, errorpassthroughrule.FieldResponseCode) +} + +// SetPassthroughBody sets the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) SetPassthroughBody(b bool) { + m.passthrough_body = &b +} + +// PassthroughBody returns the value of the "passthrough_body" field in the mutation. +func (m *ErrorPassthroughRuleMutation) PassthroughBody() (r bool, exists bool) { + v := m.passthrough_body + if v == nil { + return + } + return *v, true +} + +// OldPassthroughBody returns the old "passthrough_body" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldPassthroughBody(ctx context.Context) (v bool, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldPassthroughBody is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldPassthroughBody requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldPassthroughBody: %w", err) + } + return oldValue.PassthroughBody, nil +} + +// ResetPassthroughBody resets all changes to the "passthrough_body" field. +func (m *ErrorPassthroughRuleMutation) ResetPassthroughBody() { + m.passthrough_body = nil +} + +// SetCustomMessage sets the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) SetCustomMessage(s string) { + m.custom_message = &s +} + +// CustomMessage returns the value of the "custom_message" field in the mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessage() (r string, exists bool) { + v := m.custom_message + if v == nil { + return + } + return *v, true +} + +// OldCustomMessage returns the old "custom_message" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldCustomMessage(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldCustomMessage is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldCustomMessage requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldCustomMessage: %w", err) + } + return oldValue.CustomMessage, nil +} + +// ClearCustomMessage clears the value of the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ClearCustomMessage() { + m.custom_message = nil + m.clearedFields[errorpassthroughrule.FieldCustomMessage] = struct{}{} +} + +// CustomMessageCleared returns if the "custom_message" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) CustomMessageCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldCustomMessage] + return ok +} + +// ResetCustomMessage resets all changes to the "custom_message" field. +func (m *ErrorPassthroughRuleMutation) ResetCustomMessage() { + m.custom_message = nil + delete(m.clearedFields, errorpassthroughrule.FieldCustomMessage) +} + +// SetDescription sets the "description" field. +func (m *ErrorPassthroughRuleMutation) SetDescription(s string) { + m.description = &s +} + +// Description returns the value of the "description" field in the mutation. +func (m *ErrorPassthroughRuleMutation) Description() (r string, exists bool) { + v := m.description + if v == nil { + return + } + return *v, true +} + +// OldDescription returns the old "description" field's value of the ErrorPassthroughRule entity. +// If the ErrorPassthroughRule object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *ErrorPassthroughRuleMutation) OldDescription(ctx context.Context) (v *string, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldDescription is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldDescription requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldDescription: %w", err) + } + return oldValue.Description, nil +} + +// ClearDescription clears the value of the "description" field. +func (m *ErrorPassthroughRuleMutation) ClearDescription() { + m.description = nil + m.clearedFields[errorpassthroughrule.FieldDescription] = struct{}{} +} + +// DescriptionCleared returns if the "description" field was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) DescriptionCleared() bool { + _, ok := m.clearedFields[errorpassthroughrule.FieldDescription] + return ok +} + +// ResetDescription resets all changes to the "description" field. +func (m *ErrorPassthroughRuleMutation) ResetDescription() { + m.description = nil + delete(m.clearedFields, errorpassthroughrule.FieldDescription) +} + +// Where appends a list predicates to the ErrorPassthroughRuleMutation builder. +func (m *ErrorPassthroughRuleMutation) Where(ps ...predicate.ErrorPassthroughRule) { + m.predicates = append(m.predicates, ps...) +} + +// WhereP appends storage-level predicates to the ErrorPassthroughRuleMutation builder. Using this method, +// users can use type-assertion to append predicates that do not depend on any generated package. +func (m *ErrorPassthroughRuleMutation) WhereP(ps ...func(*sql.Selector)) { + p := make([]predicate.ErrorPassthroughRule, len(ps)) + for i := range ps { + p[i] = ps[i] + } + m.Where(p...) +} + +// Op returns the operation name. +func (m *ErrorPassthroughRuleMutation) Op() Op { + return m.op +} + +// SetOp allows setting the mutation operation. +func (m *ErrorPassthroughRuleMutation) SetOp(op Op) { + m.op = op +} + +// Type returns the node type of this mutation (ErrorPassthroughRule). +func (m *ErrorPassthroughRuleMutation) Type() string { + return m.typ +} + +// Fields returns all fields that were changed during this mutation. Note that in +// order to get all numeric fields that were incremented/decremented, call +// AddedFields(). +func (m *ErrorPassthroughRuleMutation) Fields() []string { + fields := make([]string, 0, 14) + if m.created_at != nil { + fields = append(fields, errorpassthroughrule.FieldCreatedAt) + } + if m.updated_at != nil { + fields = append(fields, errorpassthroughrule.FieldUpdatedAt) + } + if m.name != nil { + fields = append(fields, errorpassthroughrule.FieldName) + } + if m.enabled != nil { + fields = append(fields, errorpassthroughrule.FieldEnabled) + } + if m.priority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.error_codes != nil { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.keywords != nil { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.match_mode != nil { + fields = append(fields, errorpassthroughrule.FieldMatchMode) + } + if m.platforms != nil { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.passthrough_code != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughCode) + } + if m.response_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.passthrough_body != nil { + fields = append(fields, errorpassthroughrule.FieldPassthroughBody) + } + if m.custom_message != nil { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.description != nil { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// Field returns the value of a field with the given name. The second boolean +// return value indicates that this field was not set, or was not defined in the +// schema. +func (m *ErrorPassthroughRuleMutation) Field(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.CreatedAt() + case errorpassthroughrule.FieldUpdatedAt: + return m.UpdatedAt() + case errorpassthroughrule.FieldName: + return m.Name() + case errorpassthroughrule.FieldEnabled: + return m.Enabled() + case errorpassthroughrule.FieldPriority: + return m.Priority() + case errorpassthroughrule.FieldErrorCodes: + return m.ErrorCodes() + case errorpassthroughrule.FieldKeywords: + return m.Keywords() + case errorpassthroughrule.FieldMatchMode: + return m.MatchMode() + case errorpassthroughrule.FieldPlatforms: + return m.Platforms() + case errorpassthroughrule.FieldPassthroughCode: + return m.PassthroughCode() + case errorpassthroughrule.FieldResponseCode: + return m.ResponseCode() + case errorpassthroughrule.FieldPassthroughBody: + return m.PassthroughBody() + case errorpassthroughrule.FieldCustomMessage: + return m.CustomMessage() + case errorpassthroughrule.FieldDescription: + return m.Description() + } + return nil, false +} + +// OldField returns the old value of the field from the database. An error is +// returned if the mutation operation is not UpdateOne, or the query to the +// database failed. +func (m *ErrorPassthroughRuleMutation) OldField(ctx context.Context, name string) (ent.Value, error) { + switch name { + case errorpassthroughrule.FieldCreatedAt: + return m.OldCreatedAt(ctx) + case errorpassthroughrule.FieldUpdatedAt: + return m.OldUpdatedAt(ctx) + case errorpassthroughrule.FieldName: + return m.OldName(ctx) + case errorpassthroughrule.FieldEnabled: + return m.OldEnabled(ctx) + case errorpassthroughrule.FieldPriority: + return m.OldPriority(ctx) + case errorpassthroughrule.FieldErrorCodes: + return m.OldErrorCodes(ctx) + case errorpassthroughrule.FieldKeywords: + return m.OldKeywords(ctx) + case errorpassthroughrule.FieldMatchMode: + return m.OldMatchMode(ctx) + case errorpassthroughrule.FieldPlatforms: + return m.OldPlatforms(ctx) + case errorpassthroughrule.FieldPassthroughCode: + return m.OldPassthroughCode(ctx) + case errorpassthroughrule.FieldResponseCode: + return m.OldResponseCode(ctx) + case errorpassthroughrule.FieldPassthroughBody: + return m.OldPassthroughBody(ctx) + case errorpassthroughrule.FieldCustomMessage: + return m.OldCustomMessage(ctx) + case errorpassthroughrule.FieldDescription: + return m.OldDescription(ctx) + } + return nil, fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// SetField sets the value of a field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) SetField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCreatedAt(v) + return nil + case errorpassthroughrule.FieldUpdatedAt: + v, ok := value.(time.Time) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetUpdatedAt(v) + return nil + case errorpassthroughrule.FieldName: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetName(v) + return nil + case errorpassthroughrule.FieldEnabled: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetEnabled(v) + return nil + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPriority(v) + return nil + case errorpassthroughrule.FieldErrorCodes: + v, ok := value.([]int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetErrorCodes(v) + return nil + case errorpassthroughrule.FieldKeywords: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetKeywords(v) + return nil + case errorpassthroughrule.FieldMatchMode: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetMatchMode(v) + return nil + case errorpassthroughrule.FieldPlatforms: + v, ok := value.([]string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPlatforms(v) + return nil + case errorpassthroughrule.FieldPassthroughCode: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughCode(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetResponseCode(v) + return nil + case errorpassthroughrule.FieldPassthroughBody: + v, ok := value.(bool) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetPassthroughBody(v) + return nil + case errorpassthroughrule.FieldCustomMessage: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetCustomMessage(v) + return nil + case errorpassthroughrule.FieldDescription: + v, ok := value.(string) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetDescription(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedFields returns all numeric fields that were incremented/decremented during +// this mutation. +func (m *ErrorPassthroughRuleMutation) AddedFields() []string { + var fields []string + if m.addpriority != nil { + fields = append(fields, errorpassthroughrule.FieldPriority) + } + if m.addresponse_code != nil { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + return fields +} + +// AddedField returns the numeric value that was incremented/decremented on a field +// with the given name. The second boolean return value indicates that this field +// was not set, or was not defined in the schema. +func (m *ErrorPassthroughRuleMutation) AddedField(name string) (ent.Value, bool) { + switch name { + case errorpassthroughrule.FieldPriority: + return m.AddedPriority() + case errorpassthroughrule.FieldResponseCode: + return m.AddedResponseCode() + } + return nil, false +} + +// AddField adds the value to the field with the given name. It returns an error if +// the field is not defined in the schema, or if the type mismatched the field +// type. +func (m *ErrorPassthroughRuleMutation) AddField(name string, value ent.Value) error { + switch name { + case errorpassthroughrule.FieldPriority: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddPriority(v) + return nil + case errorpassthroughrule.FieldResponseCode: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddResponseCode(v) + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *ErrorPassthroughRuleMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(errorpassthroughrule.FieldErrorCodes) { + fields = append(fields, errorpassthroughrule.FieldErrorCodes) + } + if m.FieldCleared(errorpassthroughrule.FieldKeywords) { + fields = append(fields, errorpassthroughrule.FieldKeywords) + } + if m.FieldCleared(errorpassthroughrule.FieldPlatforms) { + fields = append(fields, errorpassthroughrule.FieldPlatforms) + } + if m.FieldCleared(errorpassthroughrule.FieldResponseCode) { + fields = append(fields, errorpassthroughrule.FieldResponseCode) + } + if m.FieldCleared(errorpassthroughrule.FieldCustomMessage) { + fields = append(fields, errorpassthroughrule.FieldCustomMessage) + } + if m.FieldCleared(errorpassthroughrule.FieldDescription) { + fields = append(fields, errorpassthroughrule.FieldDescription) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearField(name string) error { + switch name { + case errorpassthroughrule.FieldErrorCodes: + m.ClearErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ClearKeywords() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ClearPlatforms() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ClearResponseCode() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ClearCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ClearDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetField(name string) error { + switch name { + case errorpassthroughrule.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case errorpassthroughrule.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case errorpassthroughrule.FieldName: + m.ResetName() + return nil + case errorpassthroughrule.FieldEnabled: + m.ResetEnabled() + return nil + case errorpassthroughrule.FieldPriority: + m.ResetPriority() + return nil + case errorpassthroughrule.FieldErrorCodes: + m.ResetErrorCodes() + return nil + case errorpassthroughrule.FieldKeywords: + m.ResetKeywords() + return nil + case errorpassthroughrule.FieldMatchMode: + m.ResetMatchMode() + return nil + case errorpassthroughrule.FieldPlatforms: + m.ResetPlatforms() + return nil + case errorpassthroughrule.FieldPassthroughCode: + m.ResetPassthroughCode() + return nil + case errorpassthroughrule.FieldResponseCode: + m.ResetResponseCode() + return nil + case errorpassthroughrule.FieldPassthroughBody: + m.ResetPassthroughBody() + return nil + case errorpassthroughrule.FieldCustomMessage: + m.ResetCustomMessage() + return nil + case errorpassthroughrule.FieldDescription: + m.ResetDescription() + return nil + } + return fmt.Errorf("unknown ErrorPassthroughRule field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *ErrorPassthroughRuleMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *ErrorPassthroughRuleMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *ErrorPassthroughRuleMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *ErrorPassthroughRuleMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown ErrorPassthroughRule edge %s", name) +} + // GroupMutation represents an operation that mutates the Group nodes in the graph. type GroupMutation struct { config diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 613c5913..c12955ef 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -21,6 +21,9 @@ type Announcement func(*sql.Selector) // AnnouncementRead is the predicate function for announcementread builders. type AnnouncementRead func(*sql.Selector) +// ErrorPassthroughRule is the predicate function for errorpassthroughrule builders. +type ErrorPassthroughRule func(*sql.Selector) + // Group is the predicate function for group builders. type Group func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index f1fea8cc..4b3c1a4f 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -10,6 +10,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/announcement" "github.com/Wei-Shaw/sub2api/ent/announcementread" "github.com/Wei-Shaw/sub2api/ent/apikey" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" "github.com/Wei-Shaw/sub2api/ent/group" "github.com/Wei-Shaw/sub2api/ent/promocode" "github.com/Wei-Shaw/sub2api/ent/promocodeusage" @@ -270,6 +271,61 @@ func init() { announcementreadDescCreatedAt := announcementreadFields[3].Descriptor() // announcementread.DefaultCreatedAt holds the default value on creation for the created_at field. announcementread.DefaultCreatedAt = announcementreadDescCreatedAt.Default.(func() time.Time) + errorpassthroughruleMixin := schema.ErrorPassthroughRule{}.Mixin() + errorpassthroughruleMixinFields0 := errorpassthroughruleMixin[0].Fields() + _ = errorpassthroughruleMixinFields0 + errorpassthroughruleFields := schema.ErrorPassthroughRule{}.Fields() + _ = errorpassthroughruleFields + // errorpassthroughruleDescCreatedAt is the schema descriptor for created_at field. + errorpassthroughruleDescCreatedAt := errorpassthroughruleMixinFields0[0].Descriptor() + // errorpassthroughrule.DefaultCreatedAt holds the default value on creation for the created_at field. + errorpassthroughrule.DefaultCreatedAt = errorpassthroughruleDescCreatedAt.Default.(func() time.Time) + // errorpassthroughruleDescUpdatedAt is the schema descriptor for updated_at field. + errorpassthroughruleDescUpdatedAt := errorpassthroughruleMixinFields0[1].Descriptor() + // errorpassthroughrule.DefaultUpdatedAt holds the default value on creation for the updated_at field. + errorpassthroughrule.DefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.Default.(func() time.Time) + // errorpassthroughrule.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + errorpassthroughrule.UpdateDefaultUpdatedAt = errorpassthroughruleDescUpdatedAt.UpdateDefault.(func() time.Time) + // errorpassthroughruleDescName is the schema descriptor for name field. + errorpassthroughruleDescName := errorpassthroughruleFields[0].Descriptor() + // errorpassthroughrule.NameValidator is a validator for the "name" field. It is called by the builders before save. + errorpassthroughrule.NameValidator = func() func(string) error { + validators := errorpassthroughruleDescName.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(name string) error { + for _, fn := range fns { + if err := fn(name); err != nil { + return err + } + } + return nil + } + }() + // errorpassthroughruleDescEnabled is the schema descriptor for enabled field. + errorpassthroughruleDescEnabled := errorpassthroughruleFields[1].Descriptor() + // errorpassthroughrule.DefaultEnabled holds the default value on creation for the enabled field. + errorpassthroughrule.DefaultEnabled = errorpassthroughruleDescEnabled.Default.(bool) + // errorpassthroughruleDescPriority is the schema descriptor for priority field. + errorpassthroughruleDescPriority := errorpassthroughruleFields[2].Descriptor() + // errorpassthroughrule.DefaultPriority holds the default value on creation for the priority field. + errorpassthroughrule.DefaultPriority = errorpassthroughruleDescPriority.Default.(int) + // errorpassthroughruleDescMatchMode is the schema descriptor for match_mode field. + errorpassthroughruleDescMatchMode := errorpassthroughruleFields[5].Descriptor() + // errorpassthroughrule.DefaultMatchMode holds the default value on creation for the match_mode field. + errorpassthroughrule.DefaultMatchMode = errorpassthroughruleDescMatchMode.Default.(string) + // errorpassthroughrule.MatchModeValidator is a validator for the "match_mode" field. It is called by the builders before save. + errorpassthroughrule.MatchModeValidator = errorpassthroughruleDescMatchMode.Validators[0].(func(string) error) + // errorpassthroughruleDescPassthroughCode is the schema descriptor for passthrough_code field. + errorpassthroughruleDescPassthroughCode := errorpassthroughruleFields[7].Descriptor() + // errorpassthroughrule.DefaultPassthroughCode holds the default value on creation for the passthrough_code field. + errorpassthroughrule.DefaultPassthroughCode = errorpassthroughruleDescPassthroughCode.Default.(bool) + // errorpassthroughruleDescPassthroughBody is the schema descriptor for passthrough_body field. + errorpassthroughruleDescPassthroughBody := errorpassthroughruleFields[9].Descriptor() + // errorpassthroughrule.DefaultPassthroughBody holds the default value on creation for the passthrough_body field. + errorpassthroughrule.DefaultPassthroughBody = errorpassthroughruleDescPassthroughBody.Default.(bool) groupMixin := schema.Group{}.Mixin() groupMixinHooks1 := groupMixin[1].Hooks() group.Hooks[0] = groupMixinHooks1[0] diff --git a/backend/ent/schema/error_passthrough_rule.go b/backend/ent/schema/error_passthrough_rule.go new file mode 100644 index 00000000..4a861f38 --- /dev/null +++ b/backend/ent/schema/error_passthrough_rule.go @@ -0,0 +1,121 @@ +// Package schema 定义 Ent ORM 的数据库 schema。 +package schema + +import ( + "github.com/Wei-Shaw/sub2api/ent/schema/mixins" + + "entgo.io/ent" + "entgo.io/ent/dialect" + "entgo.io/ent/dialect/entsql" + "entgo.io/ent/schema" + "entgo.io/ent/schema/field" + "entgo.io/ent/schema/index" +) + +// ErrorPassthroughRule 定义全局错误透传规则的 schema。 +// +// 错误透传规则用于控制上游错误如何返回给客户端: +// - 匹配条件:错误码 + 关键词组合 +// - 响应行为:透传原始信息 或 自定义错误信息 +// - 响应状态码:可指定返回给客户端的状态码 +// - 平台范围:规则适用的平台(Anthropic、OpenAI、Gemini、Antigravity) +type ErrorPassthroughRule struct { + ent.Schema +} + +// Annotations 返回 schema 的注解配置。 +func (ErrorPassthroughRule) Annotations() []schema.Annotation { + return []schema.Annotation{ + entsql.Annotation{Table: "error_passthrough_rules"}, + } +} + +// Mixin 返回该 schema 使用的混入组件。 +func (ErrorPassthroughRule) Mixin() []ent.Mixin { + return []ent.Mixin{ + mixins.TimeMixin{}, + } +} + +// Fields 定义错误透传规则实体的所有字段。 +func (ErrorPassthroughRule) Fields() []ent.Field { + return []ent.Field{ + // name: 规则名称,用于在界面中标识规则 + field.String("name"). + MaxLen(100). + NotEmpty(), + + // enabled: 是否启用该规则 + field.Bool("enabled"). + Default(true), + + // priority: 规则优先级,数值越小优先级越高 + // 匹配时按优先级顺序检查,命中第一个匹配的规则 + field.Int("priority"). + Default(0), + + // error_codes: 匹配的错误码列表(OR关系) + // 例如:[422, 400] 表示匹配 422 或 400 错误码 + field.JSON("error_codes", []int{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // keywords: 匹配的关键词列表(OR关系) + // 例如:["context limit", "model not supported"] + // 关键词匹配不区分大小写 + field.JSON("keywords", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // match_mode: 匹配模式 + // - "any": 错误码匹配 OR 关键词匹配(任一条件满足即可) + // - "all": 错误码匹配 AND 关键词匹配(所有条件都必须满足) + field.String("match_mode"). + MaxLen(10). + Default("any"), + + // platforms: 适用平台列表 + // 例如:["anthropic", "openai", "gemini", "antigravity"] + // 空列表表示适用于所有平台 + field.JSON("platforms", []string{}). + Optional(). + SchemaType(map[string]string{dialect.Postgres: "jsonb"}), + + // passthrough_code: 是否透传上游原始状态码 + // true: 使用上游返回的状态码 + // false: 使用 response_code 指定的状态码 + field.Bool("passthrough_code"). + Default(true), + + // response_code: 自定义响应状态码 + // 当 passthrough_code=false 时使用此状态码 + field.Int("response_code"). + Optional(). + Nillable(), + + // passthrough_body: 是否透传上游原始错误信息 + // true: 使用上游返回的错误信息 + // false: 使用 custom_message 指定的错误信息 + field.Bool("passthrough_body"). + Default(true), + + // custom_message: 自定义错误信息 + // 当 passthrough_body=false 时使用此错误信息 + field.Text("custom_message"). + Optional(). + Nillable(), + + // description: 规则描述,用于说明规则的用途 + field.Text("description"). + Optional(). + Nillable(), + } +} + +// Indexes 定义数据库索引,优化查询性能。 +func (ErrorPassthroughRule) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("enabled"), // 筛选启用的规则 + index.Fields("priority"), // 按优先级排序 + } +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 702bdf90..45d83428 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -24,6 +24,8 @@ type Tx struct { Announcement *AnnouncementClient // AnnouncementRead is the client for interacting with the AnnouncementRead builders. AnnouncementRead *AnnouncementReadClient + // ErrorPassthroughRule is the client for interacting with the ErrorPassthroughRule builders. + ErrorPassthroughRule *ErrorPassthroughRuleClient // Group is the client for interacting with the Group builders. Group *GroupClient // PromoCode is the client for interacting with the PromoCode builders. @@ -186,6 +188,7 @@ func (tx *Tx) init() { tx.AccountGroup = NewAccountGroupClient(tx.config) tx.Announcement = NewAnnouncementClient(tx.config) tx.AnnouncementRead = NewAnnouncementReadClient(tx.config) + tx.ErrorPassthroughRule = NewErrorPassthroughRuleClient(tx.config) tx.Group = NewGroupClient(tx.config) tx.PromoCode = NewPromoCodeClient(tx.config) tx.PromoCodeUsage = NewPromoCodeUsageClient(tx.config) diff --git a/backend/internal/handler/admin/error_passthrough_handler.go b/backend/internal/handler/admin/error_passthrough_handler.go new file mode 100644 index 00000000..c32db561 --- /dev/null +++ b/backend/internal/handler/admin/error_passthrough_handler.go @@ -0,0 +1,273 @@ +package admin + +import ( + "strconv" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" +) + +// ErrorPassthroughHandler 处理错误透传规则的 HTTP 请求 +type ErrorPassthroughHandler struct { + service *service.ErrorPassthroughService +} + +// NewErrorPassthroughHandler 创建错误透传规则处理器 +func NewErrorPassthroughHandler(service *service.ErrorPassthroughService) *ErrorPassthroughHandler { + return &ErrorPassthroughHandler{service: service} +} + +// CreateErrorPassthroughRuleRequest 创建规则请求 +type CreateErrorPassthroughRuleRequest struct { + Name string `json:"name" binding:"required"` + Enabled *bool `json:"enabled"` + Priority int `json:"priority"` + ErrorCodes []int `json:"error_codes"` + Keywords []string `json:"keywords"` + MatchMode string `json:"match_mode"` + Platforms []string `json:"platforms"` + PassthroughCode *bool `json:"passthrough_code"` + ResponseCode *int `json:"response_code"` + PassthroughBody *bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` + Description *string `json:"description"` +} + +// UpdateErrorPassthroughRuleRequest 更新规则请求(部分更新,所有字段可选) +type UpdateErrorPassthroughRuleRequest struct { + Name *string `json:"name"` + Enabled *bool `json:"enabled"` + Priority *int `json:"priority"` + ErrorCodes []int `json:"error_codes"` + Keywords []string `json:"keywords"` + MatchMode *string `json:"match_mode"` + Platforms []string `json:"platforms"` + PassthroughCode *bool `json:"passthrough_code"` + ResponseCode *int `json:"response_code"` + PassthroughBody *bool `json:"passthrough_body"` + CustomMessage *string `json:"custom_message"` + Description *string `json:"description"` +} + +// List 获取所有规则 +// GET /api/v1/admin/error-passthrough-rules +func (h *ErrorPassthroughHandler) List(c *gin.Context) { + rules, err := h.service.List(c.Request.Context()) + if err != nil { + response.ErrorFrom(c, err) + return + } + response.Success(c, rules) +} + +// GetByID 根据 ID 获取规则 +// GET /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) GetByID(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + rule, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if rule == nil { + response.NotFound(c, "Rule not found") + return + } + + response.Success(c, rule) +} + +// Create 创建规则 +// POST /api/v1/admin/error-passthrough-rules +func (h *ErrorPassthroughHandler) Create(c *gin.Context) { + var req CreateErrorPassthroughRuleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + rule := &model.ErrorPassthroughRule{ + Name: req.Name, + Priority: req.Priority, + ErrorCodes: req.ErrorCodes, + Keywords: req.Keywords, + Platforms: req.Platforms, + } + + // 设置默认值 + if req.Enabled != nil { + rule.Enabled = *req.Enabled + } else { + rule.Enabled = true + } + if req.MatchMode != "" { + rule.MatchMode = req.MatchMode + } else { + rule.MatchMode = model.MatchModeAny + } + if req.PassthroughCode != nil { + rule.PassthroughCode = *req.PassthroughCode + } else { + rule.PassthroughCode = true + } + if req.PassthroughBody != nil { + rule.PassthroughBody = *req.PassthroughBody + } else { + rule.PassthroughBody = true + } + rule.ResponseCode = req.ResponseCode + rule.CustomMessage = req.CustomMessage + rule.Description = req.Description + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + created, err := h.service.Create(c.Request.Context(), rule) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, created) +} + +// Update 更新规则(支持部分更新) +// PUT /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) Update(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + var req UpdateErrorPassthroughRuleRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + // 先获取现有规则 + existing, err := h.service.GetByID(c.Request.Context(), id) + if err != nil { + response.ErrorFrom(c, err) + return + } + if existing == nil { + response.NotFound(c, "Rule not found") + return + } + + // 部分更新:只更新请求中提供的字段 + rule := &model.ErrorPassthroughRule{ + ID: id, + Name: existing.Name, + Enabled: existing.Enabled, + Priority: existing.Priority, + ErrorCodes: existing.ErrorCodes, + Keywords: existing.Keywords, + MatchMode: existing.MatchMode, + Platforms: existing.Platforms, + PassthroughCode: existing.PassthroughCode, + ResponseCode: existing.ResponseCode, + PassthroughBody: existing.PassthroughBody, + CustomMessage: existing.CustomMessage, + Description: existing.Description, + } + + // 应用请求中提供的更新 + if req.Name != nil { + rule.Name = *req.Name + } + if req.Enabled != nil { + rule.Enabled = *req.Enabled + } + if req.Priority != nil { + rule.Priority = *req.Priority + } + if req.ErrorCodes != nil { + rule.ErrorCodes = req.ErrorCodes + } + if req.Keywords != nil { + rule.Keywords = req.Keywords + } + if req.MatchMode != nil { + rule.MatchMode = *req.MatchMode + } + if req.Platforms != nil { + rule.Platforms = req.Platforms + } + if req.PassthroughCode != nil { + rule.PassthroughCode = *req.PassthroughCode + } + if req.ResponseCode != nil { + rule.ResponseCode = req.ResponseCode + } + if req.PassthroughBody != nil { + rule.PassthroughBody = *req.PassthroughBody + } + if req.CustomMessage != nil { + rule.CustomMessage = req.CustomMessage + } + if req.Description != nil { + rule.Description = req.Description + } + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + updated, err := h.service.Update(c.Request.Context(), rule) + if err != nil { + if _, ok := err.(*model.ValidationError); ok { + response.BadRequest(c, err.Error()) + return + } + response.ErrorFrom(c, err) + return + } + + response.Success(c, updated) +} + +// Delete 删除规则 +// DELETE /api/v1/admin/error-passthrough-rules/:id +func (h *ErrorPassthroughHandler) Delete(c *gin.Context) { + id, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid rule ID") + return + } + + if err := h.service.Delete(c.Request.Context(), id); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "Rule deleted successfully"}) +} diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 9aa6b72c..beaddbca 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -33,6 +33,7 @@ type GatewayHandler struct { billingCacheService *service.BillingCacheService usageService *service.UsageService apiKeyService *service.APIKeyService + errorPassthroughService *service.ErrorPassthroughService concurrencyHelper *ConcurrencyHelper maxAccountSwitches int maxAccountSwitchesGemini int @@ -48,6 +49,7 @@ func NewGatewayHandler( billingCacheService *service.BillingCacheService, usageService *service.UsageService, apiKeyService *service.APIKeyService, + errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *GatewayHandler { pingInterval := time.Duration(0) @@ -70,6 +72,7 @@ func NewGatewayHandler( billingCacheService: billingCacheService, usageService: usageService, apiKeyService: apiKeyService, + errorPassthroughService: errorPassthroughService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval), maxAccountSwitches: maxAccountSwitches, maxAccountSwitchesGemini: maxAccountSwitchesGemini, @@ -201,7 +204,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, reqModel, failedAccountIDs, "") // Gemini 不使用会话限制 @@ -210,7 +213,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, service.PlatformGemini, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -301,9 +308,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, service.PlatformGemini, streamStarted) return } switchCount++ @@ -352,7 +359,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError retryWithFallback := false for { @@ -363,7 +370,11 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, platform, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -487,9 +498,9 @@ func (h *GatewayHandler) Messages(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, account.Platform, streamStarted) return } switchCount++ @@ -755,7 +766,37 @@ func (h *GatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotT fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { +func (h *GatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, platform string, streamStarted bool) { + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule(platform, statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) + return + } + } + + // 使用默认的错误映射 + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 +func (h *GatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index 787e3760..be634c0c 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -253,7 +253,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionKey, modelName, failedAccountIDs, "") // Gemini 不使用会话限制 @@ -262,7 +262,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { googleError(c, http.StatusServiceUnavailable, "No available Gemini accounts: "+err.Error()) return } - handleGeminiFailoverExhausted(c, lastFailoverStatus) + h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } account := selection.Account @@ -353,11 +353,11 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - handleGeminiFailoverExhausted(c, lastFailoverStatus) + lastFailoverErr = failoverErr + h.handleGeminiFailoverExhausted(c, lastFailoverErr) return } - lastFailoverStatus = failoverErr.StatusCode + lastFailoverErr = failoverErr switchCount++ log.Printf("Gemini account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) continue @@ -414,7 +414,36 @@ func parseGeminiModelAction(rest string) (model string, action string, err error return "", "", &pathParseError{"invalid model action path"} } -func handleGeminiFailoverExhausted(c *gin.Context, statusCode int) { +func (h *GatewayHandler) handleGeminiFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError) { + if failoverErr == nil { + googleError(c, http.StatusBadGateway, "Upstream request failed") + return + } + + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule(service.PlatformGemini, statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + googleError(c, respCode, msg) + return + } + } + + // 使用默认的错误映射 status, message := mapGeminiUpstreamError(statusCode) googleError(c, status, message) } diff --git a/backend/internal/handler/handler.go b/backend/internal/handler/handler.go index b8f7d417..b2b12c0d 100644 --- a/backend/internal/handler/handler.go +++ b/backend/internal/handler/handler.go @@ -24,6 +24,7 @@ type AdminHandlers struct { Subscription *admin.SubscriptionHandler Usage *admin.UsageHandler UserAttribute *admin.UserAttributeHandler + ErrorPassthrough *admin.ErrorPassthroughHandler } // Handlers contains all HTTP handlers diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index a84679ae..1dcb163b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -22,11 +22,12 @@ import ( // OpenAIGatewayHandler handles OpenAI API gateway requests type OpenAIGatewayHandler struct { - gatewayService *service.OpenAIGatewayService - billingCacheService *service.BillingCacheService - apiKeyService *service.APIKeyService - concurrencyHelper *ConcurrencyHelper - maxAccountSwitches int + gatewayService *service.OpenAIGatewayService + billingCacheService *service.BillingCacheService + apiKeyService *service.APIKeyService + errorPassthroughService *service.ErrorPassthroughService + concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -35,6 +36,7 @@ func NewOpenAIGatewayHandler( concurrencyService *service.ConcurrencyService, billingCacheService *service.BillingCacheService, apiKeyService *service.APIKeyService, + errorPassthroughService *service.ErrorPassthroughService, cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) @@ -46,11 +48,12 @@ func NewOpenAIGatewayHandler( } } return &OpenAIGatewayHandler{ - gatewayService: gatewayService, - billingCacheService: billingCacheService, - apiKeyService: apiKeyService, - concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), - maxAccountSwitches: maxAccountSwitches, + gatewayService: gatewayService, + billingCacheService: billingCacheService, + apiKeyService: apiKeyService, + errorPassthroughService: errorPassthroughService, + concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, } } @@ -201,7 +204,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) - lastFailoverStatus := 0 + var lastFailoverErr *service.UpstreamFailoverError for { // Select account supporting the requested model @@ -213,7 +216,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts: "+err.Error(), streamStarted) return } - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + if lastFailoverErr != nil { + h.handleFailoverExhausted(c, lastFailoverErr, streamStarted) + } else { + h.handleFailoverExhaustedSimple(c, 502, streamStarted) + } return } account := selection.Account @@ -278,12 +285,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { failedAccountIDs[account.ID] = struct{}{} + lastFailoverErr = failoverErr if switchCount >= maxAccountSwitches { - lastFailoverStatus = failoverErr.StatusCode - h.handleFailoverExhausted(c, lastFailoverStatus, streamStarted) + h.handleFailoverExhausted(c, failoverErr, streamStarted) return } - lastFailoverStatus = failoverErr.StatusCode switchCount++ log.Printf("Account %d: upstream error %d, switching account %d/%d", account.ID, failoverErr.StatusCode, switchCount, maxAccountSwitches) continue @@ -324,7 +330,37 @@ func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted) } -func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, statusCode int, streamStarted bool) { +func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) { + statusCode := failoverErr.StatusCode + responseBody := failoverErr.ResponseBody + + // 先检查透传规则 + if h.errorPassthroughService != nil && len(responseBody) > 0 { + if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil { + // 确定响应状态码 + respCode := statusCode + if !rule.PassthroughCode && rule.ResponseCode != nil { + respCode = *rule.ResponseCode + } + + // 确定响应消息 + msg := service.ExtractUpstreamErrorMessage(responseBody) + if !rule.PassthroughBody && rule.CustomMessage != nil { + msg = *rule.CustomMessage + } + + h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted) + return + } + } + + // 使用默认的错误映射 + status, errType, errMsg := h.mapUpstreamError(statusCode) + h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) +} + +// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况 +func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) { status, errType, errMsg := h.mapUpstreamError(statusCode) h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted) } diff --git a/backend/internal/handler/wire.go b/backend/internal/handler/wire.go index 48a3794b..7b62149c 100644 --- a/backend/internal/handler/wire.go +++ b/backend/internal/handler/wire.go @@ -27,6 +27,7 @@ func ProvideAdminHandlers( subscriptionHandler *admin.SubscriptionHandler, usageHandler *admin.UsageHandler, userAttributeHandler *admin.UserAttributeHandler, + errorPassthroughHandler *admin.ErrorPassthroughHandler, ) *AdminHandlers { return &AdminHandlers{ Dashboard: dashboardHandler, @@ -47,6 +48,7 @@ func ProvideAdminHandlers( Subscription: subscriptionHandler, Usage: usageHandler, UserAttribute: userAttributeHandler, + ErrorPassthrough: errorPassthroughHandler, } } @@ -125,6 +127,7 @@ var ProviderSet = wire.NewSet( admin.NewSubscriptionHandler, admin.NewUsageHandler, admin.NewUserAttributeHandler, + admin.NewErrorPassthroughHandler, // AdminHandlers and Handlers constructors ProvideAdminHandlers, diff --git a/backend/internal/model/error_passthrough_rule.go b/backend/internal/model/error_passthrough_rule.go new file mode 100644 index 00000000..d4fc16e3 --- /dev/null +++ b/backend/internal/model/error_passthrough_rule.go @@ -0,0 +1,74 @@ +// Package model 定义服务层使用的数据模型。 +package model + +import "time" + +// ErrorPassthroughRule 全局错误透传规则 +// 用于控制上游错误如何返回给客户端 +type ErrorPassthroughRule struct { + ID int64 `json:"id"` + Name string `json:"name"` // 规则名称 + Enabled bool `json:"enabled"` // 是否启用 + Priority int `json:"priority"` // 优先级(数字越小优先级越高) + ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系) + Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系) + MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件) + Platforms []string `json:"platforms"` // 适用平台列表 + PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码 + ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用) + PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息 + CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用) + Description *string `json:"description"` // 规则描述 + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +// MatchModeAny 表示任一条件匹配即可 +const MatchModeAny = "any" + +// MatchModeAll 表示所有条件都必须匹配 +const MatchModeAll = "all" + +// 支持的平台常量 +const ( + PlatformAnthropic = "anthropic" + PlatformOpenAI = "openai" + PlatformGemini = "gemini" + PlatformAntigravity = "antigravity" +) + +// AllPlatforms 返回所有支持的平台列表 +func AllPlatforms() []string { + return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity} +} + +// Validate 验证规则配置的有效性 +func (r *ErrorPassthroughRule) Validate() error { + if r.Name == "" { + return &ValidationError{Field: "name", Message: "name is required"} + } + if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll { + return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"} + } + // 至少需要配置一个匹配条件(错误码或关键词) + if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 { + return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"} + } + if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) { + return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"} + } + if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") { + return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"} + } + return nil +} + +// ValidationError 表示验证错误 +type ValidationError struct { + Field string + Message string +} + +func (e *ValidationError) Error() string { + return e.Field + ": " + e.Message +} diff --git a/backend/internal/repository/error_passthrough_cache.go b/backend/internal/repository/error_passthrough_cache.go new file mode 100644 index 00000000..5584ffc8 --- /dev/null +++ b/backend/internal/repository/error_passthrough_cache.go @@ -0,0 +1,128 @@ +package repository + +import ( + "context" + "encoding/json" + "log" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const ( + errorPassthroughCacheKey = "error_passthrough_rules" + errorPassthroughPubSubKey = "error_passthrough_rules_updated" + errorPassthroughCacheTTL = 24 * time.Hour +) + +type errorPassthroughCache struct { + rdb *redis.Client + localCache []*model.ErrorPassthroughRule + localMu sync.RWMutex +} + +// NewErrorPassthroughCache 创建错误透传规则缓存 +func NewErrorPassthroughCache(rdb *redis.Client) service.ErrorPassthroughCache { + return &errorPassthroughCache{ + rdb: rdb, + } +} + +// Get 从缓存获取规则列表 +func (c *errorPassthroughCache) Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) { + // 先检查本地缓存 + c.localMu.RLock() + if c.localCache != nil { + rules := c.localCache + c.localMu.RUnlock() + return rules, true + } + c.localMu.RUnlock() + + // 从 Redis 获取 + data, err := c.rdb.Get(ctx, errorPassthroughCacheKey).Bytes() + if err != nil { + if err != redis.Nil { + log.Printf("[ErrorPassthroughCache] Failed to get from Redis: %v", err) + } + return nil, false + } + + var rules []*model.ErrorPassthroughRule + if err := json.Unmarshal(data, &rules); err != nil { + log.Printf("[ErrorPassthroughCache] Failed to unmarshal rules: %v", err) + return nil, false + } + + // 更新本地缓存 + c.localMu.Lock() + c.localCache = rules + c.localMu.Unlock() + + return rules, true +} + +// Set 设置缓存 +func (c *errorPassthroughCache) Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error { + data, err := json.Marshal(rules) + if err != nil { + return err + } + + if err := c.rdb.Set(ctx, errorPassthroughCacheKey, data, errorPassthroughCacheTTL).Err(); err != nil { + return err + } + + // 更新本地缓存 + c.localMu.Lock() + c.localCache = rules + c.localMu.Unlock() + + return nil +} + +// Invalidate 使缓存失效 +func (c *errorPassthroughCache) Invalidate(ctx context.Context) error { + // 清除本地缓存 + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + // 清除 Redis 缓存 + return c.rdb.Del(ctx, errorPassthroughCacheKey).Err() +} + +// NotifyUpdate 通知其他实例刷新缓存 +func (c *errorPassthroughCache) NotifyUpdate(ctx context.Context) error { + return c.rdb.Publish(ctx, errorPassthroughPubSubKey, "refresh").Err() +} + +// SubscribeUpdates 订阅缓存更新通知 +func (c *errorPassthroughCache) SubscribeUpdates(ctx context.Context, handler func()) { + go func() { + sub := c.rdb.Subscribe(ctx, errorPassthroughPubSubKey) + defer func() { _ = sub.Close() }() + + ch := sub.Channel() + for { + select { + case <-ctx.Done(): + return + case msg := <-ch: + if msg == nil { + return + } + // 清除本地缓存,下次访问时会从 Redis 或数据库重新加载 + c.localMu.Lock() + c.localCache = nil + c.localMu.Unlock() + + // 调用处理函数 + handler() + } + } + }() +} diff --git a/backend/internal/repository/error_passthrough_repo.go b/backend/internal/repository/error_passthrough_repo.go new file mode 100644 index 00000000..a58ab60f --- /dev/null +++ b/backend/internal/repository/error_passthrough_repo.go @@ -0,0 +1,178 @@ +package repository + +import ( + "context" + + "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/errorpassthroughrule" + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type errorPassthroughRepository struct { + client *ent.Client +} + +// NewErrorPassthroughRepository 创建错误透传规则仓库 +func NewErrorPassthroughRepository(client *ent.Client) service.ErrorPassthroughRepository { + return &errorPassthroughRepository{client: client} +} + +// List 获取所有规则 +func (r *errorPassthroughRepository) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + rules, err := r.client.ErrorPassthroughRule.Query(). + Order(ent.Asc(errorpassthroughrule.FieldPriority)). + All(ctx) + if err != nil { + return nil, err + } + + result := make([]*model.ErrorPassthroughRule, len(rules)) + for i, rule := range rules { + result[i] = r.toModel(rule) + } + return result, nil +} + +// GetByID 根据 ID 获取规则 +func (r *errorPassthroughRepository) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + rule, err := r.client.ErrorPassthroughRule.Get(ctx, id) + if err != nil { + if ent.IsNotFound(err) { + return nil, nil + } + return nil, err + } + return r.toModel(rule), nil +} + +// Create 创建规则 +func (r *errorPassthroughRepository) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + builder := r.client.ErrorPassthroughRule.Create(). + SetName(rule.Name). + SetEnabled(rule.Enabled). + SetPriority(rule.Priority). + SetMatchMode(rule.MatchMode). + SetPassthroughCode(rule.PassthroughCode). + SetPassthroughBody(rule.PassthroughBody) + + if len(rule.ErrorCodes) > 0 { + builder.SetErrorCodes(rule.ErrorCodes) + } + if len(rule.Keywords) > 0 { + builder.SetKeywords(rule.Keywords) + } + if len(rule.Platforms) > 0 { + builder.SetPlatforms(rule.Platforms) + } + if rule.ResponseCode != nil { + builder.SetResponseCode(*rule.ResponseCode) + } + if rule.CustomMessage != nil { + builder.SetCustomMessage(*rule.CustomMessage) + } + if rule.Description != nil { + builder.SetDescription(*rule.Description) + } + + created, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(created), nil +} + +// Update 更新规则 +func (r *errorPassthroughRepository) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + builder := r.client.ErrorPassthroughRule.UpdateOneID(rule.ID). + SetName(rule.Name). + SetEnabled(rule.Enabled). + SetPriority(rule.Priority). + SetMatchMode(rule.MatchMode). + SetPassthroughCode(rule.PassthroughCode). + SetPassthroughBody(rule.PassthroughBody) + + // 处理可选字段 + if len(rule.ErrorCodes) > 0 { + builder.SetErrorCodes(rule.ErrorCodes) + } else { + builder.ClearErrorCodes() + } + if len(rule.Keywords) > 0 { + builder.SetKeywords(rule.Keywords) + } else { + builder.ClearKeywords() + } + if len(rule.Platforms) > 0 { + builder.SetPlatforms(rule.Platforms) + } else { + builder.ClearPlatforms() + } + if rule.ResponseCode != nil { + builder.SetResponseCode(*rule.ResponseCode) + } else { + builder.ClearResponseCode() + } + if rule.CustomMessage != nil { + builder.SetCustomMessage(*rule.CustomMessage) + } else { + builder.ClearCustomMessage() + } + if rule.Description != nil { + builder.SetDescription(*rule.Description) + } else { + builder.ClearDescription() + } + + updated, err := builder.Save(ctx) + if err != nil { + return nil, err + } + return r.toModel(updated), nil +} + +// Delete 删除规则 +func (r *errorPassthroughRepository) Delete(ctx context.Context, id int64) error { + return r.client.ErrorPassthroughRule.DeleteOneID(id).Exec(ctx) +} + +// toModel 将 Ent 实体转换为服务模型 +func (r *errorPassthroughRepository) toModel(e *ent.ErrorPassthroughRule) *model.ErrorPassthroughRule { + rule := &model.ErrorPassthroughRule{ + ID: int64(e.ID), + Name: e.Name, + Enabled: e.Enabled, + Priority: e.Priority, + ErrorCodes: e.ErrorCodes, + Keywords: e.Keywords, + MatchMode: e.MatchMode, + Platforms: e.Platforms, + PassthroughCode: e.PassthroughCode, + PassthroughBody: e.PassthroughBody, + CreatedAt: e.CreatedAt, + UpdatedAt: e.UpdatedAt, + } + + if e.ResponseCode != nil { + rule.ResponseCode = e.ResponseCode + } + if e.CustomMessage != nil { + rule.CustomMessage = e.CustomMessage + } + if e.Description != nil { + rule.Description = e.Description + } + + // 确保切片不为 nil + if rule.ErrorCodes == nil { + rule.ErrorCodes = []int{} + } + if rule.Keywords == nil { + rule.Keywords = []string{} + } + if rule.Platforms == nil { + rule.Platforms = []string{} + } + + return rule +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 5437de35..3aed9d9c 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -67,6 +67,7 @@ var ProviderSet = wire.NewSet( NewUserAttributeDefinitionRepository, NewUserAttributeValueRepository, NewUserGroupRateRepository, + NewErrorPassthroughRepository, // Cache implementations NewGatewayCache, @@ -87,6 +88,7 @@ var ProviderSet = wire.NewSet( NewProxyLatencyCache, NewTotpCache, NewRefreshTokenCache, + NewErrorPassthroughCache, // Encryptors NewAESEncryptor, diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index ca9d627e..a1c27b00 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -67,6 +67,9 @@ func RegisterAdminRoutes( // 用户属性管理 registerUserAttributeRoutes(admin, h) + + // 错误透传规则管理 + registerErrorPassthroughRoutes(admin, h) } } @@ -387,3 +390,14 @@ func registerUserAttributeRoutes(admin *gin.RouterGroup, h *handler.Handlers) { attrs.DELETE("/:id", h.Admin.UserAttribute.DeleteDefinition) } } + +func registerErrorPassthroughRoutes(admin *gin.RouterGroup, h *handler.Handlers) { + rules := admin.Group("/error-passthrough-rules") + { + rules.GET("", h.Admin.ErrorPassthrough.List) + rules.GET("/:id", h.Admin.ErrorPassthrough.GetByID) + rules.POST("", h.Admin.ErrorPassthrough.Create) + rules.PUT("/:id", h.Admin.ErrorPassthrough.Update) + rules.DELETE("/:id", h.Admin.ErrorPassthrough.Delete) + } +} diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index cf7e35fc..4ca32829 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -1106,7 +1106,7 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody) @@ -1779,6 +1779,7 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 处理错误响应 if resp.StatusCode >= 400 { respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + contentType := resp.Header.Get("Content-Type") // 尽早关闭原始响应体,释放连接;后续逻辑仍可能需要读取 body,因此用内存副本重新包装。 _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) @@ -1849,10 +1850,8 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: unwrappedForOps} } - - contentType := resp.Header.Get("Content-Type") if contentType == "" { contentType = "application/json" } diff --git a/backend/internal/service/error_passthrough_service.go b/backend/internal/service/error_passthrough_service.go new file mode 100644 index 00000000..99dc70e3 --- /dev/null +++ b/backend/internal/service/error_passthrough_service.go @@ -0,0 +1,300 @@ +package service + +import ( + "context" + "log" + "sort" + "strings" + "sync" + + "github.com/Wei-Shaw/sub2api/internal/model" +) + +// ErrorPassthroughRepository 定义错误透传规则的数据访问接口 +type ErrorPassthroughRepository interface { + // List 获取所有规则 + List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) + // GetByID 根据 ID 获取规则 + GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) + // Create 创建规则 + Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) + // Update 更新规则 + Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) + // Delete 删除规则 + Delete(ctx context.Context, id int64) error +} + +// ErrorPassthroughCache 定义错误透传规则的缓存接口 +type ErrorPassthroughCache interface { + // Get 从缓存获取规则列表 + Get(ctx context.Context) ([]*model.ErrorPassthroughRule, bool) + // Set 设置缓存 + Set(ctx context.Context, rules []*model.ErrorPassthroughRule) error + // Invalidate 使缓存失效 + Invalidate(ctx context.Context) error + // NotifyUpdate 通知其他实例刷新缓存 + NotifyUpdate(ctx context.Context) error + // SubscribeUpdates 订阅缓存更新通知 + SubscribeUpdates(ctx context.Context, handler func()) +} + +// ErrorPassthroughService 错误透传规则服务 +type ErrorPassthroughService struct { + repo ErrorPassthroughRepository + cache ErrorPassthroughCache + + // 本地内存缓存,用于快速匹配 + localCache []*model.ErrorPassthroughRule + localCacheMu sync.RWMutex +} + +// NewErrorPassthroughService 创建错误透传规则服务 +func NewErrorPassthroughService( + repo ErrorPassthroughRepository, + cache ErrorPassthroughCache, +) *ErrorPassthroughService { + svc := &ErrorPassthroughService{ + repo: repo, + cache: cache, + } + + // 启动时加载规则到本地缓存 + ctx := context.Background() + if err := svc.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to load rules on startup: %v", err) + } + + // 订阅缓存更新通知 + if cache != nil { + cache.SubscribeUpdates(ctx, func() { + if err := svc.refreshLocalCache(context.Background()); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh cache on notification: %v", err) + } + }) + } + + return svc +} + +// List 获取所有规则 +func (s *ErrorPassthroughService) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + return s.repo.List(ctx) +} + +// GetByID 根据 ID 获取规则 +func (s *ErrorPassthroughService) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + return s.repo.GetByID(ctx, id) +} + +// Create 创建规则 +func (s *ErrorPassthroughService) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if err := rule.Validate(); err != nil { + return nil, err + } + + created, err := s.repo.Create(ctx, rule) + if err != nil { + return nil, err + } + + // 刷新缓存 + s.invalidateAndNotify(ctx) + + return created, nil +} + +// Update 更新规则 +func (s *ErrorPassthroughService) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + if err := rule.Validate(); err != nil { + return nil, err + } + + updated, err := s.repo.Update(ctx, rule) + if err != nil { + return nil, err + } + + // 刷新缓存 + s.invalidateAndNotify(ctx) + + return updated, nil +} + +// Delete 删除规则 +func (s *ErrorPassthroughService) Delete(ctx context.Context, id int64) error { + if err := s.repo.Delete(ctx, id); err != nil { + return err + } + + // 刷新缓存 + s.invalidateAndNotify(ctx) + + return nil +} + +// MatchRule 匹配透传规则 +// 返回第一个匹配的规则,如果没有匹配则返回 nil +func (s *ErrorPassthroughService) MatchRule(platform string, statusCode int, body []byte) *model.ErrorPassthroughRule { + rules := s.getCachedRules() + if len(rules) == 0 { + return nil + } + + bodyStr := strings.ToLower(string(body)) + + for _, rule := range rules { + if !rule.Enabled { + continue + } + if !s.platformMatches(rule, platform) { + continue + } + if s.ruleMatches(rule, statusCode, bodyStr) { + return rule + } + } + + return nil +} + +// getCachedRules 获取缓存的规则列表(按优先级排序) +func (s *ErrorPassthroughService) getCachedRules() []*model.ErrorPassthroughRule { + s.localCacheMu.RLock() + rules := s.localCache + s.localCacheMu.RUnlock() + + if rules != nil { + return rules + } + + // 如果本地缓存为空,尝试刷新 + ctx := context.Background() + if err := s.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh cache: %v", err) + return nil + } + + s.localCacheMu.RLock() + defer s.localCacheMu.RUnlock() + return s.localCache +} + +// refreshLocalCache 刷新本地缓存 +func (s *ErrorPassthroughService) refreshLocalCache(ctx context.Context) error { + // 先尝试从 Redis 缓存获取 + if s.cache != nil { + if rules, ok := s.cache.Get(ctx); ok { + s.setLocalCache(rules) + return nil + } + } + + // 从数据库加载(repo.List 已按 priority 排序) + rules, err := s.repo.List(ctx) + if err != nil { + return err + } + + // 更新 Redis 缓存 + if s.cache != nil { + if err := s.cache.Set(ctx, rules); err != nil { + log.Printf("[ErrorPassthroughService] Failed to set cache: %v", err) + } + } + + // 更新本地缓存(setLocalCache 内部会确保排序) + s.setLocalCache(rules) + + return nil +} + +// setLocalCache 设置本地缓存 +func (s *ErrorPassthroughService) setLocalCache(rules []*model.ErrorPassthroughRule) { + // 按优先级排序 + sorted := make([]*model.ErrorPassthroughRule, len(rules)) + copy(sorted, rules) + sort.Slice(sorted, func(i, j int) bool { + return sorted[i].Priority < sorted[j].Priority + }) + + s.localCacheMu.Lock() + s.localCache = sorted + s.localCacheMu.Unlock() +} + +// invalidateAndNotify 使缓存失效并通知其他实例 +func (s *ErrorPassthroughService) invalidateAndNotify(ctx context.Context) { + // 刷新本地缓存 + if err := s.refreshLocalCache(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to refresh local cache: %v", err) + } + + // 通知其他实例 + if s.cache != nil { + if err := s.cache.NotifyUpdate(ctx); err != nil { + log.Printf("[ErrorPassthroughService] Failed to notify cache update: %v", err) + } + } +} + +// platformMatches 检查平台是否匹配 +func (s *ErrorPassthroughService) platformMatches(rule *model.ErrorPassthroughRule, platform string) bool { + // 如果没有配置平台限制,则匹配所有平台 + if len(rule.Platforms) == 0 { + return true + } + + platform = strings.ToLower(platform) + for _, p := range rule.Platforms { + if strings.ToLower(p) == platform { + return true + } + } + + return false +} + +// ruleMatches 检查规则是否匹配 +func (s *ErrorPassthroughService) ruleMatches(rule *model.ErrorPassthroughRule, statusCode int, bodyLower string) bool { + hasErrorCodes := len(rule.ErrorCodes) > 0 + hasKeywords := len(rule.Keywords) > 0 + + // 如果没有配置任何条件,不匹配 + if !hasErrorCodes && !hasKeywords { + return false + } + + codeMatch := !hasErrorCodes || s.containsInt(rule.ErrorCodes, statusCode) + keywordMatch := !hasKeywords || s.containsAnyKeyword(bodyLower, rule.Keywords) + + if rule.MatchMode == model.MatchModeAll { + // "all" 模式:所有配置的条件都必须满足 + return codeMatch && keywordMatch + } + + // "any" 模式:任一条件满足即可 + if hasErrorCodes && hasKeywords { + return codeMatch || keywordMatch + } + return codeMatch && keywordMatch +} + +// containsInt 检查切片是否包含指定整数 +func (s *ErrorPassthroughService) containsInt(slice []int, val int) bool { + for _, v := range slice { + if v == val { + return true + } + } + return false +} + +// containsAnyKeyword 检查字符串是否包含任一关键词(不区分大小写) +func (s *ErrorPassthroughService) containsAnyKeyword(bodyLower string, keywords []string) bool { + for _, kw := range keywords { + if strings.Contains(bodyLower, strings.ToLower(kw)) { + return true + } + } + return false +} diff --git a/backend/internal/service/error_passthrough_service_test.go b/backend/internal/service/error_passthrough_service_test.go new file mode 100644 index 00000000..205b4ec4 --- /dev/null +++ b/backend/internal/service/error_passthrough_service_test.go @@ -0,0 +1,755 @@ +//go:build unit + +package service + +import ( + "context" + "strings" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockErrorPassthroughRepo 用于测试的 mock repository +type mockErrorPassthroughRepo struct { + rules []*model.ErrorPassthroughRule +} + +func (m *mockErrorPassthroughRepo) List(ctx context.Context) ([]*model.ErrorPassthroughRule, error) { + return m.rules, nil +} + +func (m *mockErrorPassthroughRepo) GetByID(ctx context.Context, id int64) (*model.ErrorPassthroughRule, error) { + for _, r := range m.rules { + if r.ID == id { + return r, nil + } + } + return nil, nil +} + +func (m *mockErrorPassthroughRepo) Create(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + rule.ID = int64(len(m.rules) + 1) + m.rules = append(m.rules, rule) + return rule, nil +} + +func (m *mockErrorPassthroughRepo) Update(ctx context.Context, rule *model.ErrorPassthroughRule) (*model.ErrorPassthroughRule, error) { + for i, r := range m.rules { + if r.ID == rule.ID { + m.rules[i] = rule + return rule, nil + } + } + return rule, nil +} + +func (m *mockErrorPassthroughRepo) Delete(ctx context.Context, id int64) error { + for i, r := range m.rules { + if r.ID == id { + m.rules = append(m.rules[:i], m.rules[i+1:]...) + return nil + } + } + return nil +} + +// newTestService 创建测试用的服务实例 +func newTestService(rules []*model.ErrorPassthroughRule) *ErrorPassthroughService { + repo := &mockErrorPassthroughRepo{rules: rules} + svc := &ErrorPassthroughService{ + repo: repo, + cache: nil, // 不使用缓存 + } + // 直接设置本地缓存,避免调用 refreshLocalCache + svc.setLocalCache(rules) + return svc +} + +// ============================================================================= +// 测试 ruleMatches 核心匹配逻辑 +// ============================================================================= + +func TestRuleMatches_NoConditions(t *testing.T) { + // 没有配置任何条件时,不应该匹配 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{}, + Keywords: []string{}, + MatchMode: model.MatchModeAny, + } + + assert.False(t, svc.ruleMatches(rule, 422, "some error message"), + "没有配置条件时不应该匹配") +} + +func TestRuleMatches_OnlyErrorCodes_AnyMode(t *testing.T) { + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + }{ + {"状态码匹配 422", 422, "any message", true}, + {"状态码匹配 400", 400, "any message", true}, + {"状态码不匹配 500", 500, "any message", false}, + {"状态码不匹配 429", 429, "any message", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRuleMatches_OnlyKeywords_AnyMode(t *testing.T) { + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{}, + Keywords: []string{"context limit", "model not supported"}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + }{ + {"关键词匹配 context limit", 500, "error: context limit reached", true}, + {"关键词匹配 model not supported", 400, "the model not supported here", true}, + {"关键词不匹配", 422, "some other error", false}, + // 注意:ruleMatches 接收的 body 参数应该是已经转换为小写的 + // 实际使用时,MatchRule 会先将 body 转换为小写再传给 ruleMatches + {"关键词大小写 - 输入已小写", 500, "context limit exceeded", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // 模拟 MatchRule 的行为:先转换为小写 + bodyLower := strings.ToLower(tt.body) + result := svc.ruleMatches(rule, tt.statusCode, bodyLower) + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestRuleMatches_BothConditions_AnyMode(t *testing.T) { + // any 模式:错误码 OR 关键词 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAny, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + reason string + }{ + { + name: "状态码和关键词都匹配", + statusCode: 422, + body: "context limit reached", + expected: true, + reason: "both match", + }, + { + name: "只有状态码匹配", + statusCode: 422, + body: "some other error", + expected: true, + reason: "code matches, keyword doesn't - OR mode should match", + }, + { + name: "只有关键词匹配", + statusCode: 500, + body: "context limit exceeded", + expected: true, + reason: "keyword matches, code doesn't - OR mode should match", + }, + { + name: "都不匹配", + statusCode: 500, + body: "some other error", + expected: false, + reason: "neither matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +func TestRuleMatches_BothConditions_AllMode(t *testing.T) { + // all 模式:错误码 AND 关键词 + svc := newTestService(nil) + rule := &model.ErrorPassthroughRule{ + Enabled: true, + ErrorCodes: []int{422, 400}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAll, + } + + tests := []struct { + name string + statusCode int + body string + expected bool + reason string + }{ + { + name: "状态码和关键词都匹配", + statusCode: 422, + body: "context limit reached", + expected: true, + reason: "both match - AND mode should match", + }, + { + name: "只有状态码匹配", + statusCode: 422, + body: "some other error", + expected: false, + reason: "code matches but keyword doesn't - AND mode should NOT match", + }, + { + name: "只有关键词匹配", + statusCode: 500, + body: "context limit exceeded", + expected: false, + reason: "keyword matches but code doesn't - AND mode should NOT match", + }, + { + name: "都不匹配", + statusCode: 500, + body: "some other error", + expected: false, + reason: "neither matches", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := svc.ruleMatches(rule, tt.statusCode, tt.body) + assert.Equal(t, tt.expected, result, tt.reason) + }) + } +} + +// ============================================================================= +// 测试 platformMatches 平台匹配逻辑 +// ============================================================================= + +func TestPlatformMatches(t *testing.T) { + svc := newTestService(nil) + + tests := []struct { + name string + rulePlatforms []string + requestPlatform string + expected bool + }{ + { + name: "空平台列表匹配所有", + rulePlatforms: []string{}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "nil平台列表匹配所有", + rulePlatforms: nil, + requestPlatform: "openai", + expected: true, + }, + { + name: "精确匹配 anthropic", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "精确匹配 openai", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "openai", + expected: true, + }, + { + name: "不匹配 gemini", + rulePlatforms: []string{"anthropic", "openai"}, + requestPlatform: "gemini", + expected: false, + }, + { + name: "大小写不敏感", + rulePlatforms: []string{"Anthropic", "OpenAI"}, + requestPlatform: "anthropic", + expected: true, + }, + { + name: "匹配 antigravity", + rulePlatforms: []string{"antigravity"}, + requestPlatform: "antigravity", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rule := &model.ErrorPassthroughRule{ + Platforms: tt.rulePlatforms, + } + result := svc.platformMatches(rule, tt.requestPlatform) + assert.Equal(t, tt.expected, result) + }) + } +} + +// ============================================================================= +// 测试 MatchRule 完整匹配流程 +// ============================================================================= + +func TestMatchRule_Priority(t *testing.T) { + // 测试规则按优先级排序,优先级小的先匹配 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Low Priority", + Enabled: true, + Priority: 10, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "High Priority", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID, "应该匹配优先级更高(数值更小)的规则") + assert.Equal(t, "High Priority", matched.Name) +} + +func TestMatchRule_DisabledRule(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Disabled Rule", + Enabled: false, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "Enabled Rule", + Enabled: true, + Priority: 10, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID, "应该跳过禁用的规则") +} + +func TestMatchRule_PlatformFilter(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Anthropic Only", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + Platforms: []string{"anthropic"}, + MatchMode: model.MatchModeAny, + }, + { + ID: 2, + Name: "OpenAI Only", + Enabled: true, + Priority: 2, + ErrorCodes: []int{422}, + Platforms: []string{"openai"}, + MatchMode: model.MatchModeAny, + }, + { + ID: 3, + Name: "All Platforms", + Enabled: true, + Priority: 3, + ErrorCodes: []int{422}, + Platforms: []string{}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + + t.Run("Anthropic 请求匹配 Anthropic 规则", func(t *testing.T) { + matched := svc.MatchRule("anthropic", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(1), matched.ID) + }) + + t.Run("OpenAI 请求匹配 OpenAI 规则", func(t *testing.T) { + matched := svc.MatchRule("openai", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(2), matched.ID) + }) + + t.Run("Gemini 请求匹配全平台规则", func(t *testing.T) { + matched := svc.MatchRule("gemini", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(3), matched.ID) + }) + + t.Run("Antigravity 请求匹配全平台规则", func(t *testing.T) { + matched := svc.MatchRule("antigravity", 422, []byte("error")) + require.NotNil(t, matched) + assert.Equal(t, int64(3), matched.ID) + }) +} + +func TestMatchRule_NoMatch(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Rule for 422", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + matched := svc.MatchRule("anthropic", 500, []byte("error")) + + assert.Nil(t, matched, "不匹配任何规则时应返回 nil") +} + +func TestMatchRule_EmptyRules(t *testing.T) { + svc := newTestService([]*model.ErrorPassthroughRule{}) + matched := svc.MatchRule("anthropic", 422, []byte("error")) + + assert.Nil(t, matched, "没有规则时应返回 nil") +} + +func TestMatchRule_CaseInsensitiveKeyword(t *testing.T) { + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Context Limit", + Enabled: true, + Priority: 1, + Keywords: []string{"Context Limit"}, + MatchMode: model.MatchModeAny, + }, + } + + svc := newTestService(rules) + + tests := []struct { + name string + body string + expected bool + }{ + {"完全匹配", "Context Limit reached", true}, + {"小写匹配", "context limit reached", true}, + {"大写匹配", "CONTEXT LIMIT REACHED", true}, + {"混合大小写", "ConTeXt LiMiT error", true}, + {"不匹配", "some other error", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matched := svc.MatchRule("anthropic", 500, []byte(tt.body)) + if tt.expected { + assert.NotNil(t, matched) + } else { + assert.Nil(t, matched) + } + }) + } +} + +// ============================================================================= +// 测试真实场景 +// ============================================================================= + +func TestMatchRule_RealWorldScenario_ContextLimitPassthrough(t *testing.T) { + // 场景:上游返回 422 + "context limit has been reached",需要透传给客户端 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Context Limit Passthrough", + Enabled: true, + Priority: 1, + ErrorCodes: []int{422}, + Keywords: []string{"context limit"}, + MatchMode: model.MatchModeAll, // 必须同时满足 + Platforms: []string{"anthropic", "antigravity"}, + PassthroughCode: true, + PassthroughBody: true, + }, + } + + svc := newTestService(rules) + + // 测试 Anthropic 平台 + t.Run("Anthropic 422 with context limit", func(t *testing.T) { + body := []byte(`{"type":"error","error":{"type":"invalid_request","message":"The context limit has been reached"}}`) + matched := svc.MatchRule("anthropic", 422, body) + require.NotNil(t, matched) + assert.True(t, matched.PassthroughCode) + assert.True(t, matched.PassthroughBody) + }) + + // 测试 Antigravity 平台 + t.Run("Antigravity 422 with context limit", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("antigravity", 422, body) + require.NotNil(t, matched) + }) + + // 测试 OpenAI 平台(不在规则的平台列表中) + t.Run("OpenAI should not match", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("openai", 422, body) + assert.Nil(t, matched, "OpenAI 不在规则的平台列表中") + }) + + // 测试状态码不匹配 + t.Run("Wrong status code", func(t *testing.T) { + body := []byte(`{"error":"context limit exceeded"}`) + matched := svc.MatchRule("anthropic", 400, body) + assert.Nil(t, matched, "状态码不匹配") + }) + + // 测试关键词不匹配 + t.Run("Wrong keyword", func(t *testing.T) { + body := []byte(`{"error":"rate limit exceeded"}`) + matched := svc.MatchRule("anthropic", 422, body) + assert.Nil(t, matched, "关键词不匹配") + }) +} + +func TestMatchRule_RealWorldScenario_CustomErrorMessage(t *testing.T) { + // 场景:某些错误需要返回自定义消息,隐藏上游详细信息 + customMsg := "Service temporarily unavailable, please try again later" + responseCode := 503 + rules := []*model.ErrorPassthroughRule{ + { + ID: 1, + Name: "Hide Internal Errors", + Enabled: true, + Priority: 1, + ErrorCodes: []int{500, 502, 503}, + MatchMode: model.MatchModeAny, + PassthroughCode: false, + ResponseCode: &responseCode, + PassthroughBody: false, + CustomMessage: &customMsg, + }, + } + + svc := newTestService(rules) + + matched := svc.MatchRule("anthropic", 500, []byte("internal server error")) + require.NotNil(t, matched) + assert.False(t, matched.PassthroughCode) + assert.Equal(t, 503, *matched.ResponseCode) + assert.False(t, matched.PassthroughBody) + assert.Equal(t, customMsg, *matched.CustomMessage) +} + +// ============================================================================= +// 测试 model.Validate +// ============================================================================= + +func TestErrorPassthroughRule_Validate(t *testing.T) { + tests := []struct { + name string + rule *model.ErrorPassthroughRule + expectError bool + errorField string + }{ + { + name: "有效规则 - 透传模式(含错误码)", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: false, + }, + { + name: "有效规则 - 透传模式(含关键词)", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAny, + Keywords: []string{"context limit"}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: false, + }, + { + name: "有效规则 - 自定义响应", + rule: &model.ErrorPassthroughRule{ + Name: "Valid Rule", + MatchMode: model.MatchModeAll, + ErrorCodes: []int{500}, + Keywords: []string{"internal error"}, + PassthroughCode: false, + ResponseCode: testIntPtr(503), + PassthroughBody: false, + CustomMessage: testStrPtr("Custom error"), + }, + expectError: false, + }, + { + name: "缺少名称", + rule: &model.ErrorPassthroughRule{ + Name: "", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "name", + }, + { + name: "无效的匹配模式", + rule: &model.ErrorPassthroughRule{ + Name: "Invalid Mode", + MatchMode: "invalid", + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "match_mode", + }, + { + name: "缺少匹配条件(错误码和关键词都为空)", + rule: &model.ErrorPassthroughRule{ + Name: "No Conditions", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{}, + Keywords: []string{}, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "conditions", + }, + { + name: "缺少匹配条件(nil切片)", + rule: &model.ErrorPassthroughRule{ + Name: "Nil Conditions", + MatchMode: model.MatchModeAny, + ErrorCodes: nil, + Keywords: nil, + PassthroughCode: true, + PassthroughBody: true, + }, + expectError: true, + errorField: "conditions", + }, + { + name: "自定义状态码但未提供值", + rule: &model.ErrorPassthroughRule{ + Name: "Missing Code", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: false, + ResponseCode: nil, + PassthroughBody: true, + }, + expectError: true, + errorField: "response_code", + }, + { + name: "自定义消息但未提供值", + rule: &model.ErrorPassthroughRule{ + Name: "Missing Message", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: false, + CustomMessage: nil, + }, + expectError: true, + errorField: "custom_message", + }, + { + name: "自定义消息为空字符串", + rule: &model.ErrorPassthroughRule{ + Name: "Empty Message", + MatchMode: model.MatchModeAny, + ErrorCodes: []int{422}, + PassthroughCode: true, + PassthroughBody: false, + CustomMessage: testStrPtr(""), + }, + expectError: true, + errorField: "custom_message", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.rule.Validate() + if tt.expectError { + require.Error(t, err) + validationErr, ok := err.(*model.ValidationError) + require.True(t, ok, "应该返回 ValidationError") + assert.Equal(t, tt.errorField, validationErr.Field) + } else { + assert.NoError(t, err) + } + }) + } +} + +// Helper functions +func testIntPtr(i int) *int { return &i } +func testStrPtr(s string) *string { return &s } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 9036955a..9aecce22 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -370,7 +370,8 @@ type ForwardResult struct { // UpstreamFailoverError indicates an upstream error that should trigger account failover. type UpstreamFailoverError struct { - StatusCode int + StatusCode int + ResponseBody []byte // 上游响应体,用于错误透传规则匹配 } func (e *UpstreamFailoverError) Error() string { @@ -3284,7 +3285,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleRetryExhaustedError(ctx, resp, c, account) } @@ -3314,10 +3315,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A return "" }(), }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } - - // 处理错误响应(不可重试的错误) if resp.StatusCode >= 400 { // 可选:对部分 400 触发 failover(默认关闭以保持语义) if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 { @@ -3361,7 +3360,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A log.Printf("Account %d: 400 error, attempting failover", account.ID) } s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } } return s.handleErrorResponse(ctx, resp, c, account) @@ -3758,6 +3757,12 @@ func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool { return false } +// ExtractUpstreamErrorMessage 从上游响应体中提取错误消息 +// 支持 Claude 风格的错误格式:{"type":"error","error":{"type":"...","message":"..."}} +func ExtractUpstreamErrorMessage(body []byte) string { + return extractUpstreamErrorMessage(body) +} + func extractUpstreamErrorMessage(body []byte) string { // Claude 风格:{"type":"error","error":{"type":"...","message":"..."}} if m := gjson.GetBytes(body, "error.message").String(); strings.TrimSpace(m) != "" { @@ -3825,7 +3830,7 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body) } if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // 记录上游错误响应体摘要便于排障(可选:由配置控制;不回显到客户端) diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index bd322991..eecb88f6 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -864,7 +864,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { upstreamReqID := resp.Header.Get(requestIDHeader) @@ -891,7 +891,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } upstreamReqID := resp.Header.Get(requestIDHeader) if upstreamReqID == "" { @@ -1301,7 +1301,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) { evBody := unwrapIfNeeded(isOAuth, respBody) @@ -1325,7 +1325,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin. Message: upstreamMsg, Detail: upstreamDetail, }) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: evBody} } respBody = unwrapIfNeeded(isOAuth, respBody) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 4658c694..564ffa4d 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -940,7 +940,7 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco }) s.handleFailoverSideEffects(ctx, resp, account) - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody} } return s.handleErrorResponse(ctx, resp, c, account) } @@ -1131,7 +1131,7 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht Detail: upstreamDetail, }) if shouldDisable { - return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode} + return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: body} } // Return appropriate error response diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 4b721bb6..05371022 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -274,4 +274,5 @@ var ProviderSet = wire.NewSet( NewUserAttributeService, NewUsageCache, NewTotpService, + NewErrorPassthroughService, ) diff --git a/backend/migrations/048_add_error_passthrough_rules.sql b/backend/migrations/048_add_error_passthrough_rules.sql new file mode 100644 index 00000000..bf2a9117 --- /dev/null +++ b/backend/migrations/048_add_error_passthrough_rules.sql @@ -0,0 +1,24 @@ +-- Error Passthrough Rules table +-- Allows administrators to configure how upstream errors are passed through to clients + +CREATE TABLE IF NOT EXISTS error_passthrough_rules ( + id BIGSERIAL PRIMARY KEY, + name VARCHAR(100) NOT NULL, + enabled BOOLEAN NOT NULL DEFAULT true, + priority INTEGER NOT NULL DEFAULT 0, + error_codes JSONB DEFAULT '[]', + keywords JSONB DEFAULT '[]', + match_mode VARCHAR(10) NOT NULL DEFAULT 'any', + platforms JSONB DEFAULT '[]', + passthrough_code BOOLEAN NOT NULL DEFAULT true, + response_code INTEGER, + passthrough_body BOOLEAN NOT NULL DEFAULT true, + custom_message TEXT, + description TEXT, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +-- Indexes for efficient queries +CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_enabled ON error_passthrough_rules (enabled); +CREATE INDEX IF NOT EXISTS idx_error_passthrough_rules_priority ON error_passthrough_rules (priority); diff --git a/frontend/src/api/admin/errorPassthrough.ts b/frontend/src/api/admin/errorPassthrough.ts new file mode 100644 index 00000000..4c545ad5 --- /dev/null +++ b/frontend/src/api/admin/errorPassthrough.ts @@ -0,0 +1,134 @@ +/** + * Admin Error Passthrough Rules API endpoints + * Handles error passthrough rule management for administrators + */ + +import { apiClient } from '../client' + +/** + * Error passthrough rule interface + */ +export interface ErrorPassthroughRule { + id: number + name: string + enabled: boolean + priority: number + error_codes: number[] + keywords: string[] + match_mode: 'any' | 'all' + platforms: string[] + passthrough_code: boolean + response_code: number | null + passthrough_body: boolean + custom_message: string | null + description: string | null + created_at: string + updated_at: string +} + +/** + * Create rule request + */ +export interface CreateRuleRequest { + name: string + enabled?: boolean + priority?: number + error_codes?: number[] + keywords?: string[] + match_mode?: 'any' | 'all' + platforms?: string[] + passthrough_code?: boolean + response_code?: number | null + passthrough_body?: boolean + custom_message?: string | null + description?: string | null +} + +/** + * Update rule request + */ +export interface UpdateRuleRequest { + name?: string + enabled?: boolean + priority?: number + error_codes?: number[] + keywords?: string[] + match_mode?: 'any' | 'all' + platforms?: string[] + passthrough_code?: boolean + response_code?: number | null + passthrough_body?: boolean + custom_message?: string | null + description?: string | null +} + +/** + * List all error passthrough rules + * @returns List of all rules sorted by priority + */ +export async function list(): Promise { + const { data } = await apiClient.get('/admin/error-passthrough-rules') + return data +} + +/** + * Get rule by ID + * @param id - Rule ID + * @returns Rule details + */ +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/error-passthrough-rules/${id}`) + return data +} + +/** + * Create new rule + * @param ruleData - Rule data + * @returns Created rule + */ +export async function create(ruleData: CreateRuleRequest): Promise { + const { data } = await apiClient.post('/admin/error-passthrough-rules', ruleData) + return data +} + +/** + * Update rule + * @param id - Rule ID + * @param updates - Fields to update + * @returns Updated rule + */ +export async function update(id: number, updates: UpdateRuleRequest): Promise { + const { data } = await apiClient.put(`/admin/error-passthrough-rules/${id}`, updates) + return data +} + +/** + * Delete rule + * @param id - Rule ID + * @returns Success confirmation + */ +export async function deleteRule(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/error-passthrough-rules/${id}`) + return data +} + +/** + * Toggle rule enabled status + * @param id - Rule ID + * @param enabled - New enabled status + * @returns Updated rule + */ +export async function toggleEnabled(id: number, enabled: boolean): Promise { + return update(id, { enabled }) +} + +export const errorPassthroughAPI = { + list, + getById, + create, + update, + delete: deleteRule, + toggleEnabled +} + +export default errorPassthroughAPI diff --git a/frontend/src/api/admin/index.ts b/frontend/src/api/admin/index.ts index 9a8a4195..ffb9b179 100644 --- a/frontend/src/api/admin/index.ts +++ b/frontend/src/api/admin/index.ts @@ -19,6 +19,7 @@ import geminiAPI from './gemini' import antigravityAPI from './antigravity' import userAttributesAPI from './userAttributes' import opsAPI from './ops' +import errorPassthroughAPI from './errorPassthrough' /** * Unified admin API object for convenient access @@ -39,7 +40,8 @@ export const adminAPI = { gemini: geminiAPI, antigravity: antigravityAPI, userAttributes: userAttributesAPI, - ops: opsAPI + ops: opsAPI, + errorPassthrough: errorPassthroughAPI } export { @@ -58,10 +60,12 @@ export { geminiAPI, antigravityAPI, userAttributesAPI, - opsAPI + opsAPI, + errorPassthroughAPI } export default adminAPI // Re-export types used by components export type { BalanceHistoryItem } from './users' +export type { ErrorPassthroughRule, CreateRuleRequest, UpdateRuleRequest } from './errorPassthrough' diff --git a/frontend/src/components/admin/ErrorPassthroughRulesModal.vue b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue new file mode 100644 index 00000000..b93319c5 --- /dev/null +++ b/frontend/src/components/admin/ErrorPassthroughRulesModal.vue @@ -0,0 +1,623 @@ +