From dc5d42addc6e59c89badd14557fe9cc406681f2c Mon Sep 17 00:00:00 2001 From: james-6-23 <1163476949@qq.com> Date: Thu, 23 Apr 2026 03:33:52 +0800 Subject: [PATCH] =?UTF-8?q?feat(rpm):=20RPM=20=E9=99=90=E6=B5=81=E6=A8=A1?= =?UTF-8?q?=E5=9D=97=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0: - rpm_override 嵌入 Auth Cache Snapshot,消除每请求 DB 查询 (snapshot v6→v7) - 429 RPM 响应返回 Retry-After 头(当前分钟剩余秒数) P1: - ClearAll 按钮直连 DELETE API,带 loading 防重复 - 新增 GET /admin/users/:id/rpm-status 管理员 RPM 用量查询端点 优化: - checkRPM 从级联互斥改为并行取最严,user.rpm_limit 作为全局硬上限始终生效 - Override/Group 变更后自动失效 auth cache - fail-open 语义不变,Redis 故障不阻塞业务 --- backend/cmd/server/wire_gen.go | 27 +- backend/cmd/server/wire_gen_test.go | 2 +- backend/ent/group.go | 13 +- backend/ent/group/group.go | 10 + backend/ent/group/where.go | 45 ++ backend/ent/group_create.go | 85 ++++ backend/ent/group_update.go | 54 +++ backend/ent/migrate/schema.go | 4 +- backend/ent/mutation.go | 178 ++++++- backend/ent/runtime/runtime.go | 8 + backend/ent/schema/group.go | 5 + backend/ent/schema/user.go | 4 + backend/ent/user.go | 13 +- backend/ent/user/user.go | 10 + backend/ent/user/where.go | 45 ++ backend/ent/user_create.go | 85 ++++ backend/ent/user_update.go | 54 +++ backend/go.mod | 1 + backend/go.sum | 10 + .../handler/admin/admin_service_stub_test.go | 19 + .../internal/handler/admin/group_handler.go | 51 ++ .../internal/handler/admin/setting_handler.go | 4 + .../internal/handler/admin/user_handler.go | 22 + backend/internal/handler/dto/mappers.go | 2 + backend/internal/handler/dto/settings.go | 1 + backend/internal/handler/dto/types.go | 6 + backend/internal/handler/gateway_handler.go | 34 +- .../gateway_handler_billing_error_test.go | 54 +++ .../gateway_handler_chat_completions.go | 6 +- .../handler/gateway_handler_responses.go | 6 +- ...eway_handler_warmup_intercept_unit_test.go | 2 +- .../internal/handler/gemini_v1beta_handler.go | 6 +- 
.../handler/openai_chat_completions.go | 6 +- .../handler/openai_gateway_handler.go | 10 +- backend/internal/handler/openai_images.go | 6 +- backend/internal/repository/api_key_repo.go | 4 + backend/internal/repository/group_repo.go | 6 +- .../repository/user_group_rate_repo.go | 225 +++++++-- backend/internal/repository/user_repo.go | 4 +- backend/internal/repository/user_rpm_cache.go | 108 +++++ backend/internal/repository/wire.go | 1 + backend/internal/server/api_contract_test.go | 6 +- backend/internal/server/routes/admin.go | 3 + backend/internal/service/admin_service.go | 160 ++++++- .../service/admin_service_group_rate_test.go | 57 ++- .../service/admin_service_group_test.go | 25 + .../service/admin_service_list_users_test.go | 12 + .../service/admin_service_rpm_status_test.go | 112 +++++ .../admin_service_update_user_rpm_test.go | 69 +++ .../internal/service/api_key_auth_cache.go | 10 + .../service/api_key_auth_cache_impl.go | 20 +- .../service/api_key_service_cache_test.go | 2 +- backend/internal/service/auth_service.go | 17 + .../internal/service/billing_cache_service.go | 106 ++++- .../service/billing_cache_service_rpm_test.go | 253 ++++++++++ ...billing_cache_service_singleflight_test.go | 2 +- .../service/billing_cache_service_test.go | 4 +- backend/internal/service/domain_constants.go | 7 +- backend/internal/service/group.go | 4 + backend/internal/service/setting_service.go | 18 + backend/internal/service/settings_view.go | 1 + backend/internal/service/user.go | 9 + backend/internal/service/user_group_rate.go | 54 ++- backend/internal/service/user_rpm_cache.go | 25 + backend/internal/service/wire.go | 22 +- .../migrations/125_add_group_rpm_limit.sql | 7 + backend/migrations/126_add_user_rpm_limit.sql | 7 + .../127_add_user_group_rpm_override.sql | 16 + frontend/src/api/admin/groups.ts | 64 ++- frontend/src/api/admin/settings.ts | 2 + .../admin/group/GroupRPMOverridesModal.vue | 434 ++++++++++++++++++ .../admin/group/GroupRateMultipliersModal.vue | 51 
+- .../components/admin/user/UserCreateModal.vue | 16 +- .../components/admin/user/UserEditModal.vue | 18 +- frontend/src/i18n/locales/en.ts | 24 +- frontend/src/i18n/locales/zh.ts | 22 +- frontend/src/types/index.ts | 2 + frontend/src/views/admin/GroupsView.vue | 54 +++ frontend/src/views/admin/SettingsView.vue | 20 + 79 files changed, 2831 insertions(+), 140 deletions(-) create mode 100644 backend/internal/handler/gateway_handler_billing_error_test.go create mode 100644 backend/internal/repository/user_rpm_cache.go create mode 100644 backend/internal/service/admin_service_rpm_status_test.go create mode 100644 backend/internal/service/admin_service_update_user_rpm_test.go create mode 100644 backend/internal/service/billing_cache_service_rpm_test.go create mode 100644 backend/internal/service/user_rpm_cache.go create mode 100644 backend/migrations/125_add_group_rpm_limit.sql create mode 100644 backend/migrations/126_add_user_rpm_limit.sql create mode 100644 backend/migrations/127_add_user_group_rpm_override.sql create mode 100644 frontend/src/components/admin/group/GroupRPMOverridesModal.vue diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 724f01f2..407f3026 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -61,8 +61,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client, db) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) + userRPMCache := repository.NewUserRPMCache(redisClient) userGroupRateRepository := repository.NewUserGroupRateRepository(db) + billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, 
userRPMCache, userGroupRateRepository, configConfig) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) @@ -104,7 +105,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -137,7 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := 
service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) + oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) @@ -184,6 +185,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) + encryptionKey, err := payment.ProvideEncryptionKey(configConfig) + if err != nil { + return nil, err + } + paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) + registry := payment.ProvideRegistry() + defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := 
repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -211,16 +221,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { scheduledTestService := service.ProvideScheduledTestService(scheduledTestPlanRepository, scheduledTestResultRepository) scheduledTestHandler := admin.NewScheduledTestHandler(scheduledTestService) channelHandler := admin.NewChannelHandler(channelService, billingService) - registry := payment.ProvideRegistry() - encryptionKey, err := payment.ProvideEncryptionKey(configConfig) - if err != nil { - return nil, err - } - defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) - paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, paymentHandler) usageRecordWorkerPool := service.NewUsageRecordWorkerPool(configConfig) @@ -249,6 +249,7 @@ func initializeApplication(buildInfo handler.BuildInfo) 
(*Application, error) { accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) + paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService) application := &Application{ Server: httpServer, diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index a6e0551a..cb07862d 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) diff --git a/backend/ent/group.go 
b/backend/ent/group.go index f10b50c3..5d9ae2ed 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -79,6 +79,8 @@ type Group struct { DefaultMappedModel string `json:"default_mapped_model,omitempty"` // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型 MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流 + RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. Edges GroupEdges `json:"edges"` @@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) @@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) } } + case group.FieldRpmLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) + } else if value.Valid { + _m.RpmLimit = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -599,6 +607,9 @@ func (_m *Group) String() string { 
builder.WriteString(", ") builder.WriteString("messages_dispatch_model_config=") builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) + builder.WriteString(", ") + builder.WriteString("rpm_limit=") + builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index b1371630..24bd9c13 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -76,6 +76,8 @@ const ( FieldDefaultMappedModel = "default_mapped_model" // FieldMessagesDispatchModelConfig holds the string denoting the messages_dispatch_model_config field in the database. FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" + // FieldRpmLimit holds the string denoting the rpm_limit field in the database. + FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -181,6 +183,7 @@ var Columns = []string{ FieldRequirePrivacySet, FieldDefaultMappedModel, FieldMessagesDispatchModelConfig, + FieldRpmLimit, } var ( @@ -258,6 +261,8 @@ var ( DefaultMappedModelValidator func(string) error // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig + // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. + DefaultRpmLimit int ) // OrderOption defines the ordering options for the Group queries. @@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() } +// ByRpmLimit orders the results by the rpm_limit field. 
+func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRpmLimit, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index cba2ce5f..2814d130 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) } +// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ. +func RpmLimit(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRpmLimit, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group { return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v)) } +// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field. +func RpmLimitEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRpmLimit, v)) +} + +// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field. +func RpmLimitNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v)) +} + +// RpmLimitIn applies the In predicate on the "rpm_limit" field. +func RpmLimitIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...)) +} + +// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field. +func RpmLimitNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...)) +} + +// RpmLimitGT applies the GT predicate on the "rpm_limit" field. 
+func RpmLimitGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldRpmLimit, v)) +} + +// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field. +func RpmLimitGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldRpmLimit, v)) +} + +// RpmLimitLT applies the LT predicate on the "rpm_limit" field. +func RpmLimitLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldRpmLimit, v)) +} + +// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field. +func RpmLimitLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldRpmLimit, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index f412fa40..20ea0a0f 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _c } +// SetRpmLimit sets the "rpm_limit" field. +func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate { + _c.mutation.SetRpmLimit(v) + return _c +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate { + if v != nil { + _c.SetRpmLimit(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) 
@@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultMessagesDispatchModelConfig _c.mutation.SetMessagesDispatchModelConfig(v) } + if _, ok := _c.mutation.RpmLimit(); !ok { + v := group.DefaultRpmLimit + _c.mutation.SetRpmLimit(v) + } return nil } @@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} } + if _, ok := _c.mutation.RpmLimit(); !ok { + return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)} + } return nil } @@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) _node.MessagesDispatchModelConfig = value } + if value, ok := _c.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + _node.RpmLimit = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { return u } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert { + u.Set(group.FieldRpmLimit, v) + return u +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert { + u.SetExcluded(group.FieldRpmLimit) + return u +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert { + u.Add(group.FieldRpmLimit, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. 
// Using this option is equivalent to using: // @@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 7b6d6193..cc14f897 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _u } +// SetRpmLimit sets the "rpm_limit" field. 
+func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) @@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(group.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. 
func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) @@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(group.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 40b326a9..c3e43eed 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -491,6 +491,7 @@ var ( {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. 
GroupsTable = &schema.Table{ @@ -1276,7 +1277,7 @@ var ( {Name: "totp_secret_encrypted", Type: field.TypeString, Nullable: true, SchemaType: map[string]string{"postgres": "text"}}, {Name: "totp_enabled", Type: field.TypeBool, Default: false}, {Name: "totp_enabled_at", Type: field.TypeTime, Nullable: true}, - {Name: "signup_source", Type: field.TypeString, Size: 20, Default: "email"}, + {Name: "signup_source", Type: field.TypeString, Default: "email"}, {Name: "last_login_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "last_active_at", Type: field.TypeTime, Nullable: true, SchemaType: map[string]string{"postgres": "timestamptz"}}, {Name: "balance_notify_enabled", Type: field.TypeBool, Default: true}, @@ -1284,6 +1285,7 @@ var ( {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}}, {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // UsersTable holds the schema information for the "users" table. UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index ec4a4070..80c845c3 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -10102,6 +10102,8 @@ type GroupMutation struct { require_privacy_set *bool default_mapped_model *string messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + rpm_limit *int + addrpm_limit *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -11690,6 +11692,62 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() { m.messages_dispatch_model_config = nil } +// SetRpmLimit sets the "rpm_limit" field. 
+func (m *GroupMutation) SetRpmLimit(i int) { + m.rpm_limit = &i + m.addrpm_limit = nil +} + +// RpmLimit returns the value of the "rpm_limit" field in the mutation. +func (m *GroupMutation) RpmLimit() (r int, exists bool) { + v := m.rpm_limit + if v == nil { + return + } + return *v, true +} + +// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRpmLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err) + } + return oldValue.RpmLimit, nil +} + +// AddRpmLimit adds i to the "rpm_limit" field. +func (m *GroupMutation) AddRpmLimit(i int) { + if m.addrpm_limit != nil { + *m.addrpm_limit += i + } else { + m.addrpm_limit = &i + } +} + +// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation. +func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) { + v := m.addrpm_limit + if v == nil { + return + } + return *v, true +} + +// ResetRpmLimit resets all changes to the "rpm_limit" field. +func (m *GroupMutation) ResetRpmLimit() { + m.rpm_limit = nil + m.addrpm_limit = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -12048,7 +12106,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 31) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -12139,6 +12197,9 @@ func (m *GroupMutation) Fields() []string { if m.messages_dispatch_model_config != nil { fields = append(fields, group.FieldMessagesDispatchModelConfig) } + if m.rpm_limit != nil { + fields = append(fields, group.FieldRpmLimit) + } return fields } @@ -12207,6 +12268,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.DefaultMappedModel() case group.FieldMessagesDispatchModelConfig: return m.MessagesDispatchModelConfig() + case group.FieldRpmLimit: + return m.RpmLimit() } return nil, false } @@ -12276,6 +12339,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldDefaultMappedModel(ctx) case group.FieldMessagesDispatchModelConfig: return m.OldMessagesDispatchModelConfig(ctx) + case group.FieldRpmLimit: + return m.OldRpmLimit(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -12495,6 +12560,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetMessagesDispatchModelConfig(v) return nil + case group.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRpmLimit(v) + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -12536,6 +12608,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addsort_order != nil { fields = append(fields, group.FieldSortOrder) } + if m.addrpm_limit != nil { + fields = append(fields, group.FieldRpmLimit) + } return fields } @@ -12566,6 +12641,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedFallbackGroupIDOnInvalidRequest() case group.FieldSortOrder: return m.AddedSortOrder() + case group.FieldRpmLimit: + return m.AddedRpmLimit() } return nil, false } @@ -12652,6 +12729,13 @@ func (m 
*GroupMutation) AddField(name string, value ent.Value) error { } m.AddSortOrder(v) return nil + case group.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRpmLimit(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -12838,6 +12922,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldMessagesDispatchModelConfig: m.ResetMessagesDispatchModelConfig() return nil + case group.FieldRpmLimit: + m.ResetRpmLimit() + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -32681,6 +32768,8 @@ type UserMutation struct { balance_notify_extra_emails *string total_recharged *float64 addtotal_recharged *float64 + rpm_limit *int + addrpm_limit *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -33772,6 +33861,62 @@ func (m *UserMutation) ResetTotalRecharged() { m.addtotal_recharged = nil } +// SetRpmLimit sets the "rpm_limit" field. +func (m *UserMutation) SetRpmLimit(i int) { + m.rpm_limit = &i + m.addrpm_limit = nil +} + +// RpmLimit returns the value of the "rpm_limit" field in the mutation. +func (m *UserMutation) RpmLimit() (r int, exists bool) { + v := m.rpm_limit + if v == nil { + return + } + return *v, true +} + +// OldRpmLimit returns the old "rpm_limit" field's value of the User entity. +// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. 
+func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRpmLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err) + } + return oldValue.RpmLimit, nil +} + +// AddRpmLimit adds i to the "rpm_limit" field. +func (m *UserMutation) AddRpmLimit(i int) { + if m.addrpm_limit != nil { + *m.addrpm_limit += i + } else { + m.addrpm_limit = &i + } +} + +// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation. +func (m *UserMutation) AddedRpmLimit() (r int, exists bool) { + v := m.addrpm_limit + if v == nil { + return + } + return *v, true +} + +// ResetRpmLimit resets all changes to the "rpm_limit" field. +func (m *UserMutation) ResetRpmLimit() { + m.rpm_limit = nil + m.addrpm_limit = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -34454,7 +34599,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 22) + fields := make([]string, 0, 23) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -34521,6 +34666,9 @@ func (m *UserMutation) Fields() []string { if m.total_recharged != nil { fields = append(fields, user.FieldTotalRecharged) } + if m.rpm_limit != nil { + fields = append(fields, user.FieldRpmLimit) + } return fields } @@ -34573,6 +34721,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.BalanceNotifyExtraEmails() case user.FieldTotalRecharged: return m.TotalRecharged() + case user.FieldRpmLimit: + return m.RpmLimit() } return nil, false } @@ -34626,6 +34776,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldBalanceNotifyExtraEmails(ctx) case user.FieldTotalRecharged: return m.OldTotalRecharged(ctx) + case user.FieldRpmLimit: + return m.OldRpmLimit(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -34789,6 +34941,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotalRecharged(v) return nil + case user.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRpmLimit(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -34809,6 +34968,9 @@ func (m *UserMutation) AddedFields() []string { if m.addtotal_recharged != nil { fields = append(fields, user.FieldTotalRecharged) } + if m.addrpm_limit != nil { + fields = append(fields, user.FieldRpmLimit) + } return fields } @@ -34825,6 +34987,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalanceNotifyThreshold() case user.FieldTotalRecharged: return m.AddedTotalRecharged() + case user.FieldRpmLimit: + return m.AddedRpmLimit() } return nil, false } @@ -34862,6 +35026,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddTotalRecharged(v) return nil + 
case user.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRpmLimit(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -34994,6 +35165,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotalRecharged: m.ResetTotalRecharged() return nil + case user.FieldRpmLimit: + m.ResetRpmLimit() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index bdb7f7a9..eecb2377 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -595,6 +595,10 @@ func init() { groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) + // groupDescRpmLimit is the schema descriptor for rpm_limit field. + groupDescRpmLimit := groupFields[27].Descriptor() + // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. + group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 @@ -1575,6 +1579,10 @@ func init() { userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) + // userDescRpmLimit is the schema descriptor for rpm_limit field. + userDescRpmLimit := userFields[19].Descriptor() + // user.DefaultRpmLimit holds the default value on creation for the rpm_limit field. 
+ user.DefaultRpmLimit = userDescRpmLimit.Default.(int) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index d78a6898..11f38d66 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -145,6 +145,11 @@ func (Group) Fields() []ent.Field { Default(domain.OpenAIMessagesDispatchModelConfig{}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), + + // Per-group requests-per-minute cap (0 = unlimited). checkRPM evaluates it alongside the other RPM limits and enforces the strictest one. + field.Int("rpm_limit"). + Default(0). + Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index c0f0bdc1..83da5c32 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -108,6 +108,10 @@ func (User) Fields() []ent.Field { field.Float("total_recharged"). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). Default(0), + + // Per-user requests-per-minute cap (0 = unlimited). Acts as a global hard ceiling that always applies, regardless of the group's rpm_limit. + field.Int("rpm_limit"). + Default(0), } } diff --git a/backend/ent/user.go b/backend/ent/user.go index 66f33623..06670444 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -61,6 +61,8 @@ type User struct { BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"` // TotalRecharged holds the value of the "total_recharged" field. TotalRecharged float64 `json:"total_recharged,omitempty"` + // RpmLimit holds the value of the "rpm_limit" field. + RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set.
Edges UserEdges `json:"edges"` @@ -226,7 +228,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) @@ -391,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error { } else if value.Valid { _m.TotalRecharged = value.Float64 } + case user.FieldRpmLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) + } else if value.Valid { + _m.RpmLimit = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -569,6 +577,9 @@ func (_m *User) String() string { builder.WriteString(", ") builder.WriteString("total_recharged=") builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged)) + builder.WriteString(", ") + builder.WriteString("rpm_limit=") + builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 567e3b14..e11a8a32 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -59,6 +59,8 @@ const ( FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails" // FieldTotalRecharged holds the string denoting the total_recharged field in the database. FieldTotalRecharged = "total_recharged" + // FieldRpmLimit holds the string denoting the rpm_limit field in the database. 
+ FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -203,6 +205,7 @@ var Columns = []string{ FieldBalanceNotifyThreshold, FieldBalanceNotifyExtraEmails, FieldTotalRecharged, + FieldRpmLimit, } var ( @@ -271,6 +274,8 @@ var ( DefaultBalanceNotifyExtraEmails string // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field. DefaultTotalRecharged float64 + // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. + DefaultRpmLimit int ) // OrderOption defines the ordering options for the User queries. @@ -391,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc() } +// ByRpmLimit orders the results by the rpm_limit field. +func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRpmLimit, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index cbcfcc26..05d3b35b 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -165,6 +165,11 @@ func TotalRecharged(v float64) predicate.User { return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) } +// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ. +func RpmLimit(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldRpmLimit, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. 
func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -1295,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User { return predicate.User(sql.FieldLTE(FieldTotalRecharged, v)) } +// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field. +func RpmLimitEQ(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldRpmLimit, v)) +} + +// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field. +func RpmLimitNEQ(v int) predicate.User { + return predicate.User(sql.FieldNEQ(FieldRpmLimit, v)) +} + +// RpmLimitIn applies the In predicate on the "rpm_limit" field. +func RpmLimitIn(vs ...int) predicate.User { + return predicate.User(sql.FieldIn(FieldRpmLimit, vs...)) +} + +// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field. +func RpmLimitNotIn(vs ...int) predicate.User { + return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...)) +} + +// RpmLimitGT applies the GT predicate on the "rpm_limit" field. +func RpmLimitGT(v int) predicate.User { + return predicate.User(sql.FieldGT(FieldRpmLimit, v)) +} + +// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field. +func RpmLimitGTE(v int) predicate.User { + return predicate.User(sql.FieldGTE(FieldRpmLimit, v)) +} + +// RpmLimitLT applies the LT predicate on the "rpm_limit" field. +func RpmLimitLT(v int) predicate.User { + return predicate.User(sql.FieldLT(FieldRpmLimit, v)) +} + +// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field. +func RpmLimitLTE(v int) predicate.User { + return predicate.User(sql.FieldLTE(FieldRpmLimit, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. 
func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index db95e813..b4161128 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -325,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate { return _c } +// SetRpmLimit sets the "rpm_limit" field. +func (_c *UserCreate) SetRpmLimit(v int) *UserCreate { + _c.mutation.SetRpmLimit(v) + return _c +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate { + if v != nil { + _c.SetRpmLimit(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -604,6 +618,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotalRecharged _c.mutation.SetTotalRecharged(v) } + if _, ok := _c.mutation.RpmLimit(); !ok { + v := user.DefaultRpmLimit + _c.mutation.SetRpmLimit(v) + } return nil } @@ -687,6 +705,9 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotalRecharged(); !ok { return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)} } + if _, ok := _c.mutation.RpmLimit(); !ok { + return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)} + } return nil } @@ -802,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) _node.TotalRecharged = value } + if value, ok := _c.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + _node.RpmLimit = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1362,6 +1387,24 @@ func (u *UserUpsert) 
AddTotalRecharged(v float64) *UserUpsert { return u } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert { + u.Set(user.FieldRpmLimit, v) + return u +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsert) UpdateRpmLimit() *UserUpsert { + u.SetExcluded(user.FieldRpmLimit) + return u +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert { + u.Add(user.FieldRpmLimit, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1771,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2346,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. 
+func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 677eeb6b..f1d759ce 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -389,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate { return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -1008,6 +1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedTotalRecharged(); ok { _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(user.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1930,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne { return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -2579,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if value, ok := _u.mutation.AddedTotalRecharged(); ok { _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(user.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index 627851bf..982bf91b 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -104,6 +104,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect diff --git a/backend/go.sum b/backend/go.sum index f1c864f5..af6dc81a 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -162,6 +162,8 @@ github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17 github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs= github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= +github.com/google/subcommands v1.2.0 h1:vWQspBTo2nEqTUFita5/KeEWlUL8kQObDFbub/EN9oE= +github.com/google/subcommands v1.2.0/go.mod h1:ZjhPrFU+Olkh9WazFPsl27BQ4UPiG37m3yTrtFlrHVk= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/wire v0.7.0 h1:JxUKI6+CVBgCO2WToKy/nQk0sS+amI9z9EjVmdaocj4= @@ -216,6 +218,8 @@ 
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovk github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= +github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mdelapenya/tlscert v0.2.0 h1:7H81W6Z/4weDvZBNOfQte5GpIMo0lGYEeWbkGp5LJHI= @@ -249,6 +253,8 @@ github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= +github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= @@ -278,6 +284,8 @@ github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEv github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft 
v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= @@ -310,6 +318,8 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= +github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I= +github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 3a395342..2fe29fa3 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -183,6 +183,17 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } +func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) { + user, err := s.GetUser(ctx, userID) + if err != nil { + return nil, err + } + return &service.UserRPMStatus{ + UserRPMUsed: 0, + UserRPMLimit: user.RPMLimit, + }, nil +} + func 
(s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { s.boundAuthIdentityFor = userID copied := input @@ -276,6 +287,14 @@ func (s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } +func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + return nil +} + +func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error { + return nil +} + func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) { s.lastListAccounts.platform = platform s.lastListAccounts.accountType = accountType diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index cb2bd201..65e5ec78 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -110,6 +110,8 @@ type CreateGroupRequest struct { RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + // 分组 RPM 上限(0 = 不限制) + RPMLimit int `json:"rpm_limit"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -145,6 +147,8 @@ type UpdateGroupRequest struct { RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动 + RPMLimit *int `json:"rpm_limit"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 
`json:"copy_accounts_from_group_ids"` } @@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) { response.Success(c, gin.H{"message": "Rate multipliers updated successfully"}) } +// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request +type BatchSetGroupRPMOverridesRequest struct { + Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"` +} + +// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group +// PUT /api/v1/admin/groups/:id/rpm-overrides +func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + var req BatchSetGroupRPMOverridesRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "RPM overrides updated successfully"}) +} + +// ClearGroupRPMOverrides handles clearing all rpm_override for a group +// DELETE /api/v1/admin/groups/:id/rpm-overrides +func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) { + groupID, err 
:= strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "RPM overrides cleared successfully"}) +} + // UpdateSortOrderRequest represents the request to update group sort orders type UpdateSortOrderRequest struct { Updates []struct { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index a882d1a1..11e7a652 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, FallbackModelAnthropic: settings.FallbackModelAnthropic, @@ -332,6 +333,7 @@ type UpdateSettingsRequest struct { // 默认配置 DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` @@ -1105,6 +1107,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, 
FallbackModelAnthropic: req.FallbackModelAnthropic, @@ -1400,6 +1403,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, + DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index b2ed9d18..3d80107f 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -40,6 +40,7 @@ type CreateUserRequest struct { Notes string `json:"notes"` Balance float64 `json:"balance"` Concurrency int `json:"concurrency"` + RPMLimit int `json:"rpm_limit"` AllowedGroups []int64 `json:"allowed_groups"` } @@ -52,6 +53,7 @@ type UpdateUserRequest struct { Notes *string `json:"notes"` Balance *float64 `json:"balance"` Concurrency *int `json:"concurrency"` + RPMLimit *int `json:"rpm_limit"` Status string `json:"status" binding:"omitempty,oneof=active disabled"` AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 @@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) { Notes: req.Notes, Balance: req.Balance, Concurrency: req.Concurrency, + RPMLimit: req.RPMLimit, AllowedGroups: req.AllowedGroups, }) if err != nil { @@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) { Notes: req.Notes, Balance: req.Balance, Concurrency: req.Concurrency, + RPMLimit: req.RPMLimit, Status: req.Status, AllowedGroups: req.AllowedGroups, GroupRates: req.GroupRates, @@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) { "migrated_keys": result.MigratedKeys, }) } + +// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量 +// GET 
/api/v1/admin/users/:id/rpm-status +func (h *UserHandler) GetUserRPMStatus(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + status, err := h.adminService.GetUserRPMStatus(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, status) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 9780ff79..f7503c2e 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User { BalanceNotifyThreshold: u.BalanceNotifyThreshold, BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails), TotalRecharged: u.TotalRecharged, + RPMLimit: u.RPMLimit, } } @@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group { AllowMessagesDispatch: g.AllowMessagesDispatch, RequireOAuthOnly: g.RequireOAuthOnly, RequirePrivacySet: g.RequirePrivacySet, + RPMLimit: g.RPMLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index fc6a3f9e..a9933e63 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -108,6 +108,7 @@ type SystemSettings struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index c0bce40b..5cc2f8e4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,6 +26,9 @@ type User struct { BalanceNotifyExtraEmails []NotifyEmailEntry 
`json:"balance_notify_extra_emails"` TotalRecharged float64 `json:"total_recharged"` + // RPMLimit is the per-user requests-per-minute cap (0 = unlimited); it is a global hard cap that always applies, and the strictest of the user/group limits wins. + RPMLimit int `json:"rpm_limit"` + APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions []UserSubscription `json:"subscriptions,omitempty"` } @@ -108,6 +111,9 @@ type Group struct { RequireOAuthOnly bool `json:"require_oauth_only"` RequirePrivacySet bool `json:"require_privacy_set"` + // RPMLimit is the per-group requests-per-minute cap (0 = unlimited); it is enforced alongside the user-level rpm_limit (strictest wins) and does not replace it. + RPMLimit int `json:"rpm_limit"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index f5eff8c9..94a22935 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -735,7 +738,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -1441,7 
+1447,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.errorResponse(c, status, code, message) return } @@ -1684,25 +1693,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter c.JSON(http.StatusOK, response) } -func billingErrorDetails(err error) (status int, code, message string) { +func billingErrorDetails(err error) (status int, code, message string, retryAfter int) { if errors.Is(err, service.ErrBillingServiceUnavailable) { msg := pkgerrors.Message(err) if msg == "" { msg = "Billing service temporarily unavailable. Please retry later." } - return http.StatusServiceUnavailable, "billing_service_error", msg + return http.StatusServiceUnavailable, "billing_service_error", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 + } + // 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。 + // 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。 + if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) { 
+ msg := pkgerrors.Message(err) + retrySeconds := 60 - int(time.Now().Unix()%60) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds } msg := pkgerrors.Message(err) if msg == "" { @@ -1712,7 +1728,7 @@ func billingErrorDetails(err error) (status int, code, message string) { ).Warn("gateway.billing_error_missing_message") msg = "Billing error" } - return http.StatusForbidden, "billing_error", msg + return http.StatusForbidden, "billing_error", msg, 0 } func (h *GatewayHandler) metadataBridgeEnabled() bool { diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go new file mode 100644 index 00000000..e8a88802 --- /dev/null +++ b/backend/internal/handler/gateway_handler_billing_error_test.go @@ -0,0 +1,54 @@ +package handler + +import ( + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) { + status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded) + require.Equal(t, http.StatusTooManyRequests, status) + require.Equal(t, "rate_limit_exceeded", code) + require.NotEmpty(t, msg) + require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After") + require.LessOrEqual(t, retryAfter, 60) +} + +func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) { + status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded) + require.Equal(t, http.StatusTooManyRequests, status) + require.Equal(t, "rate_limit_exceeded", code) + require.NotEmpty(t, msg) + require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After") + require.LessOrEqual(t, retryAfter, 60) +} + +func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) { + // 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。 + for _, err := range []error{ + 
service.ErrAPIKeyRateLimit5hExceeded, + service.ErrAPIKeyRateLimit1dExceeded, + service.ErrAPIKeyRateLimit7dExceeded, + } { + status, code, _, _ := billingErrorDetails(err) + require.Equal(t, http.StatusTooManyRequests, status, "status for %v", err) + require.Equal(t, "rate_limit_exceeded", code) + } +} + +func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) { + status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable) + require.Equal(t, http.StatusServiceUnavailable, status) + require.Equal(t, "billing_service_error", code) + require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After") +} + +func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) { + status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance) + require.Equal(t, http.StatusForbidden, status) + require.Equal(t, "billing_error", code) + require.NotEmpty(t, msg) +} diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index be267332..4290e54b 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { // 2. 
Re-check billing if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.chatCompletionsErrorResponse(c, status, code, message) return } diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e908eb9e..683cf2b7 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) { // 2. Re-check billing if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.responsesErrorResponse(c, status, code, message) return } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 1fdc46ba..71030140 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 cfg := &config.Config{RunMode: config.RunModeSimple} - 
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d200c17c..2a34e3f0 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -9,6 +9,7 @@ import ( "errors" "net/http" "regexp" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 2) billing eligibility check (after wait) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) - status, _, message := billingErrorDetails(err) + status, _, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } googleError(c, status, message) return } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 991cbb91..3c4e6251 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) - status, 
code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 43999a01..1c975573 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 2. Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.anthropicStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 8dbf8935..403b41ef 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + 
"strconv" "strings" "time" @@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 36d80309..3a527405 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldSignupSource, user.FieldLastLoginAt, user.FieldLastActiveAt, + user.FieldRpmLimit, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, group.FieldMessagesDispatchModelConfig, + group.FieldRpmLimit, ) }). 
Only(ctx) @@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User { BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, BalanceNotifyThreshold: u.BalanceNotifyThreshold, TotalRecharged: u.TotalRecharged, + RPMLimit: u.RpmLimit, CreatedAt: u.CreatedAt, UpdatedAt: u.UpdatedAt, } @@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + RPMLimit: g.RpmLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index c17e3365..5e16475a 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). - SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetRpmLimit(groupIn.RPMLimit) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). - SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). 
+ SetRpmLimit(groupIn.RPMLimit) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eca5313f..74d25cb0 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -13,14 +13,14 @@ type userGroupRateRepository struct { sql sqlExecutor } -// NewUserGroupRateRepository 创建用户专属分组倍率仓储 +// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储 func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { return &userGroupRateRepository{sql: sqlDB} } -// GetByUserID 获取用户的所有专属分组倍率 +// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目) func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { - query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1` + query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL` rows, err := r.sql.QueryContext(ctx, query, userID) if err != nil { return nil, err @@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) return result, nil } -// GetByUserIDs 批量获取多个用户的专属分组倍率。 -// 返回结构:map[userID]map[groupID]rate +// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目) func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { result := make(map[int64]map[int64]float64, len(userIDs)) if len(userIDs) == 0 { @@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in rows, err := r.sql.QueryContext(ctx, ` SELECT user_id, group_id, rate_multiplier FROM user_group_rate_multipliers - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL `, pq.Array(uniqueIDs)) if err != nil { return nil, err @@ -95,10 
+94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in return result, nil } -// GetByGroupID 获取指定分组下所有用户的专属倍率 +// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回) func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { query := ` - SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier + SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override FROM user_group_rate_multipliers ugr JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL WHERE ugr.group_id = $1 @@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 var result []service.UserGroupRateEntry for rows.Next() { var entry service.UserGroupRateEntry - if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil { + var rate sql.NullFloat64 + var rpm sql.NullInt32 + if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil { return nil, err } + if rate.Valid { + v := rate.Float64 + entry.RateMultiplier = &v + } + if rpm.Valid { + v := int(rpm.Int32) + entry.RPMOverride = &v + } result = append(result, entry) } if err := rows.Err(); err != nil { @@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 return result, nil } -// GetByUserAndGroup 获取用户在特定分组的专属倍率 +// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil) func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` - var rate float64 + var rate sql.NullFloat64 err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) if 
err == sql.ErrNoRows { return nil, nil @@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, if err != nil { return nil, err } - return &rate, nil + if !rate.Valid { + return nil, nil + } + v := rate.Float64 + return &v, nil } -// SyncUserGroupRates 同步用户的分组专属倍率 +// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil) +func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) { + query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` + var rpm sql.NullInt32 + err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + if !rpm.Valid { + return nil, nil + } + v := int(rpm.Int32) + return &v, nil +} + +// SyncUserGroupRates 同步用户的分组专属 rate_multiplier。 +// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。 +// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。 +// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。 func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { if len(rates) == 0 { - // 如果传入空 map,删除该用户的所有专属倍率 - _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE user_id = $1 + `, userID); err != nil { + return err + } + _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`, + userID) return err } - // 分离需要删除和需要 upsert 的记录 - var toDelete []int64 + var clearGroupIDs []int64 upsertGroupIDs := make([]int64, 0, len(rates)) upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { - toDelete = 
append(toDelete, groupID) + clearGroupIDs = append(clearGroupIDs, groupID) } else { upsertGroupIDs = append(upsertGroupIDs, groupID) upsertRates = append(upsertRates, *rate) } } - // 删除指定的记录 - if len(toDelete) > 0 { + if len(clearGroupIDs) > 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE user_id = $1 AND group_id = ANY($2) + `, userID, pq.Array(clearGroupIDs)); err != nil { + return err + } if _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, - userID, pq.Array(toDelete)); err != nil { + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`, + userID, pq.Array(clearGroupIDs)); err != nil { return err } } - // Upsert 记录 - now := time.Now() if len(upsertGroupIDs) > 0 { + now := time.Now() _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) SELECT @@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID return nil } -// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插) +// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。 +// 语义: +// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。 +// - 出现的用户行:upsert rate_multiplier。 func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error { - if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil { + keepUserIDs := make([]int64, 0, len(entries)) + for _, e := range entries { + keepUserIDs = append(keepUserIDs, e.UserID) + } + + // 未在 entries 列表中的行:清空 rate_multiplier。 + if len(keepUserIDs) == 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET 
rate_multiplier = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + } else { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id <> ALL($2) + `, groupID, pq.Array(keepUserIDs)); err != nil { + return err + } + } + + // 清空后若整行 NULL 则删除。 + if _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID); err != nil { return err } + if len(entries) == 0 { return nil } + userIDs := make([]int64, len(entries)) rates := make([]float64, len(entries)) for i, e := range entries { @@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, return err } -// DeleteByGroupID 删除指定分组的所有用户专属倍率 +// SyncGroupRPMOverrides 同步分组的 rpm_override 部分(不触动 rate_multiplier)。 +// 语义: +// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。 +// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。 +func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error { + keepUserIDs := make([]int64, 0, len(entries)) + var clearUserIDs []int64 + upsertUserIDs := make([]int64, 0, len(entries)) + upsertValues := make([]int32, 0, len(entries)) + for _, e := range entries { + keepUserIDs = append(keepUserIDs, e.UserID) + if e.RPMOverride == nil { + clearUserIDs = append(clearUserIDs, e.UserID) + } else { + upsertUserIDs = append(upsertUserIDs, e.UserID) + upsertValues = append(upsertValues, int32(*e.RPMOverride)) + } + } + + // 未在 entries 列表中的行:清空 rpm_override。 + if len(keepUserIDs) == 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + } else { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE 
user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id <> ALL($2) + `, groupID, pq.Array(keepUserIDs)); err != nil { + return err + } + } + + // 显式 clear 的行。 + if len(clearUserIDs) > 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id = ANY($2) + `, groupID, pq.Array(clearUserIDs)); err != nil { + return err + } + } + + // 清空后若整行 NULL 则删除。 + if _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID); err != nil { + return err + } + + if len(upsertUserIDs) > 0 { + now := time.Now() + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at) + SELECT data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz + FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override) + ON CONFLICT (user_id, group_id) + DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at + `, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues)) + if err != nil { + return err + } + } + + return nil +} + +// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。 +func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID) + return err +} + +// DeleteByGroupID 删除指定分组的所有用户专属条目 func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { _, err := r.sql.ExecContext(ctx, 
`DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) return err } -// DeleteByUserID 删除指定用户的所有专属倍率 +// DeleteByUserID 删除指定用户的所有专属条目 func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) return err diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index c5db3dc4..d1f10cbd 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastActiveAt(userIn.LastActiveAt). + SetRpmLimit(userIn.RPMLimit). Save(txCtx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). - SetTotalRecharged(userIn.TotalRecharged) + SetTotalRecharged(userIn.TotalRecharged). 
+ SetRpmLimit(userIn.RPMLimit) if userIn.SignupSource != "" { updateOp = updateOp.SetSignupSource(userIn.SignupSource) } diff --git a/backend/internal/repository/user_rpm_cache.go b/backend/internal/repository/user_rpm_cache.go new file mode 100644 index 00000000..42bf9332 --- /dev/null +++ b/backend/internal/repository/user_rpm_cache.go @@ -0,0 +1,108 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// 用户/分组级 RPM 计数器 Redis 实现。 +// +// 设计说明: +// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute} +// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。 +// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。 +// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。 +// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。 +const ( + userGroupRPMKeyPrefix = "rpm:ug:" + userRPMKeyPrefix = "rpm:u:" + + userRPMKeyTTL = 120 * time.Second +) + +type userRPMCacheImpl struct { + rdb *redis.Client +} + +// NewUserRPMCache 创建用户/分组级 RPM 计数器。 +func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache { + return &userRPMCacheImpl{rdb: rdb} +} + +// minuteTS 获取当前 Redis 服务端分钟时间戳。 +func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("redis TIME: %w", err) + } + return t.Unix() / 60, nil +} + +// atomicIncr 原子 INCR+EXPIRE。 +func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, error) { + pipe := c.rdb.TxPipeline() + incr := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, userRPMKeyTTL) + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("user rpm increment: %w", err) + } + return int(incr.Val()), nil +} + +// IncrementUserGroupRPM 递增 (user, group) 分钟计数。 +func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := 
fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute) + return c.atomicIncr(ctx, key) +} + +// IncrementUserRPM 递增用户分钟计数。 +func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute) + return c.atomicIncr(ctx, key) +} + +// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。 +func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute) + val, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("user group rpm get: %w", err) + } + return val, nil +} + +// GetUserRPM 获取用户当前分钟已用 RPM(只读)。 +func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute) + val, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("user rpm get: %w", err) + } + return val, nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index b10175c3..d96ab5f2 100644 --- a/backend/internal/repository/wire.go +++ b/backend/internal/repository/wire.go @@ -101,6 +101,7 @@ var ProviderSet = wire.NewSet( ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, + NewUserRPMCache, NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index d2b108f5..856846ae 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -55,6 +55,7 @@ func 
TestAPIContracts(t *testing.T) { "role": "user", "balance": 12.5, "concurrency": 5, + "rpm_limit": 0, "status": "active", "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", @@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) { "fallback_group_id_on_invalid_request": null, "require_oauth_only": false, "require_privacy_set": false, + "rpm_limit": 0, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) { "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", @@ -889,6 +892,7 @@ func TestAPIContracts(t *testing.T) { "custom_endpoints": [], "default_concurrency": 0, "default_balance": 0, + "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", @@ -1084,7 +1088,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 84c963ec..07618e31 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -221,6 +221,7 @@ func 
registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) + users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -244,6 +245,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers) groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers) + groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides) + groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) } } diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 4ae66613..434f1f38 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -8,6 +8,7 @@ import ( "io" "log/slog" "net/http" + "sort" "strings" "time" @@ -32,6 +33,7 @@ type AdminService interface { UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) + GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // codeType is optional - pass empty string to return all types. // Also returns totalRecharged (sum of all positive balance top-ups). 
@@ -50,6 +52,8 @@ type AdminService interface { GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error + ClearGroupRPMOverrides(ctx context.Context, groupID int64) error + BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // API Key management (admin) @@ -114,6 +118,7 @@ type CreateUserInput struct { Notes string Balance float64 Concurrency int + RPMLimit int AllowedGroups []int64 } @@ -124,6 +129,7 @@ type UpdateUserInput struct { Notes *string Balance *float64 // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0" + RPMLimit *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 @@ -199,6 +205,8 @@ type CreateGroupInput struct { RequireOAuthOnly bool RequirePrivacySet bool MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制) + RPMLimit int // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -234,6 +242,8 @@ type UpdateGroupInput struct { RequireOAuthOnly *bool RequirePrivacySet *bool MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 + RPMLimit *int // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct { MigratedKeys int64 // 迁移的 Key 数量 } +// UserRPMStatus describes a user's current per-minute RPM usage. +type UserRPMStatus struct { + UserRPMUsed int `json:"user_rpm_used"` + UserRPMLimit int `json:"user_rpm_limit"` + PerGroup []UserGroupRPMStatus `json:"per_group"` +} + +// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair. 
+type UserGroupRPMStatus struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Used int `json:"used"` + Limit int `json:"limit"` + Source string `json:"source"` // "group" | "override" +} + // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { Success int `json:"success"` @@ -463,6 +489,8 @@ const ( proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" ) +var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available") + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository @@ -472,6 +500,7 @@ type adminServiceImpl struct { apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository userGroupRateRepo UserGroupRateRepository + userRPMCache UserRPMCache billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache @@ -496,6 +525,7 @@ func NewAdminService( apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, userGroupRateRepo UserGroupRateRepository, + userRPMCache UserRPMCache, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, @@ -514,6 +544,7 @@ func NewAdminService( apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, userGroupRateRepo: userGroupRateRepo, + userRPMCache: userRPMCache, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, @@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu Role: RoleUser, // Always create as regular user, never admin Balance: input.Balance, Concurrency: input.Concurrency, + RPMLimit: input.RPMLimit, Status: StatusActive, AllowedGroups: input.AllowedGroups, } @@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx 
context.Context, id int64, input *Upda oldConcurrency := user.Concurrency oldStatus := user.Status oldRole := user.Role + oldRPMLimit := user.RPMLimit if input.Email != "" { user.Email = input.Email @@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.Concurrency = *input.Concurrency } + if input.RPMLimit != nil { + user.RPMLimit = *input.RPMLimit + } + if input.AllowedGroups != nil { user.AllowedGroups = *input.AllowedGroups } @@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } if s.authCacheInvalidator != nil { - if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联, + // 不失效缓存会让修改在一个 L2 TTL 内失去效果。 + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) } } @@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag return keys, result.Total, nil } +func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) { + if s.userRPMCache == nil { + return nil, ErrRPMStatusUnavailable + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + userRPMUsed, err := s.userRPMCache.GetUserRPM(ctx, userID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err) + } + + keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "") + if err != nil { + return nil, err + } + + groupIDSet := make(map[int64]struct{}) + for _, key := range keys { + if key.GroupID != nil && *key.GroupID > 0 { + groupIDSet[*key.GroupID] = struct{}{} + } + } + + groupIDs := make([]int64, 0, len(groupIDSet)) + for groupID := range groupIDSet { + groupIDs = append(groupIDs, groupID) + } + 
sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] }) + + var perGroup []UserGroupRPMStatus + for _, groupID := range groupIDs { + used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID) + if getErr != nil { + logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr) + } + + entry := UserGroupRPMStatus{ + GroupID: groupID, + Used: used, + } + + if s.groupRepo != nil { + if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil { + entry.GroupName = group.Name + entry.Limit = group.RPMLimit + entry.Source = "group" + } else if groupErr != nil { + logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr) + } + } + + if s.userGroupRateRepo != nil { + override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID) + if overrideErr != nil { + logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr) + } else if override != nil { + entry.Limit = *override + entry.Source = "override" + } + } + + perGroup = append(perGroup, entry) + } + + return &UserRPMStatus{ + UserRPMUsed: userRPMUsed, + UserRPMLimit: user.RPMLimit, + PerGroup: perGroup, + }, nil +} + func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { // Return mock data for now return map[string]any{ @@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), + RPMLimit: input.RPMLimit, } sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Create(ctx, group); err != nil { @@ -1548,12 +1663,19 @@ func (s 
*adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.MessagesDispatchModelConfig != nil { group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) } + if input.RPMLimit != nil { + group.RPMLimit = *input.RPMLimit + } sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } + // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) if len(input.CopyAccountsFromGroupIDs) > 0 { // 去重源分组 IDs @@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd } } - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) - } return group, nil } @@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } +func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { + if s.userGroupRateRepo == nil { + return nil + } + if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil { + return err + } + // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } + return nil +} + +func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + if s.userGroupRateRepo == nil { + return nil + } + for _, e := range entries { + if e.RPMOverride != nil && *e.RPMOverride < 0 { + return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID)) + } + } + if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil { + return err + } + // RPM override 已嵌入 auth cache 
snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } + return nil +} + func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { return s.groupRepo.UpdateSortOrders(ctx, updates) } diff --git a/backend/internal/service/admin_service_group_rate_test.go b/backend/internal/service/admin_service_group_rate_test.go index 77635247..d2efb644 100644 --- a/backend/internal/service/admin_service_group_rate_test.go +++ b/backend/internal/service/admin_service_group_rate_test.go @@ -5,8 +5,10 @@ package service import ( "context" "errors" + "net/http" "testing" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct { syncedGroupID int64 syncedEntries []GroupRateMultiplierInput syncGroupErr error + + rpmSyncedGroupID int64 + rpmSyncedEntries []GroupRPMOverrideInput + rpmSyncErr error } func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) { @@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, panic("unexpected GetByUserAndGroup call") } +func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + panic("unexpected GetRPMOverrideByUserAndGroup call") +} + func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { if s.getByGroupIDErr != nil { return nil, s.getByGroupIDErr @@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C return s.syncGroupErr } +func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + s.rpmSyncedGroupID = groupID + s.rpmSyncedEntries = entries + return 
s.rpmSyncErr +} + +func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + panic("unexpected ClearGroupRPMOverrides call") +} + func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error { s.deletedGroupIDs = append(s.deletedGroupIDs, groupID) return s.deleteByGroupErr @@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { repo := &userGroupRateRepoStubForGroupRate{ getByGroupIDData: map[int64][]UserGroupRateEntry{ 10: { - {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5}, - {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8}, + {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)}, + {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)}, }, }, } @@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { require.Len(t, entries, 2) require.Equal(t, int64(1), entries[0].UserID) require.Equal(t, "alice", entries[0].UserName) - require.Equal(t, 1.5, entries[0].RateMultiplier) + require.NotNil(t, entries[0].RateMultiplier) + require.Equal(t, 1.5, *entries[0].RateMultiplier) require.Equal(t, int64(2), entries[1].UserID) - require.Equal(t, 0.8, entries[1].RateMultiplier) + require.NotNil(t, entries[1].RateMultiplier) + require.Equal(t, 0.8, *entries[1].RateMultiplier) }) t.Run("returns nil when repo is nil", func(t *testing.T) { @@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) { require.Contains(t, err.Error(), "sync failed") }) } + +func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) { + t.Run("syncs entries to repo", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + override := 20 + entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}} + + err := 
svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries) + require.NoError(t, err) + require.Equal(t, int64(10), repo.rpmSyncedGroupID) + require.Equal(t, entries, repo.rpmSyncedEntries) + }) + + t.Run("rejects negative override as bad request", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + negative := -1 + + err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{ + {UserID: 2, RPMOverride: &negative}, + }) + require.Error(t, err) + require.Equal(t, http.StatusBadRequest, infraerrors.Code(err)) + require.Zero(t, repo.rpmSyncedGroupID) + }) +} diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 41d2c26a..eef02240 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformAnthropic, + Status: StatusActive, + RPMLimit: 10, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + groupRepo: repo, + authCacheInvalidator: invalidator, + } + + rpmLimit := 60 + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + RPMLimit: &rpmLimit, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.Equal(t, 60, repo.updated.RPMLimit) + require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存") +} + func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { repo := &groupRepoStubForAdmin{} svc := &adminServiceImpl{groupRepo: repo} diff --git 
a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index 657616c4..ff3f65a8 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, panic("unexpected GetByUserAndGroup call") } +func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + panic("unexpected GetRPMOverrideByUserAndGroup call") +} + func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { panic("unexpected SyncUserGroupRates call") } @@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C panic("unexpected SyncGroupRateMultipliers call") } +func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error { + panic("unexpected SyncGroupRPMOverrides call") +} + +func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + panic("unexpected ClearGroupRPMOverrides call") +} + func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error { panic("unexpected DeleteByGroupID call") } diff --git a/backend/internal/service/admin_service_rpm_status_test.go b/backend/internal/service/admin_service_rpm_status_test.go new file mode 100644 index 00000000..c298f69b --- /dev/null +++ b/backend/internal/service/admin_service_rpm_status_test.go @@ -0,0 +1,112 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type rpmStatusUserRepoStub struct { + UserRepository + user *User +} + +func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, 
_ int64) (*User, error) { + return s.user, nil +} + +type rpmStatusAPIKeyRepoStub struct { + APIKeyRepository + keys []APIKey +} + +func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil +} + +type rpmStatusGroupRepoStub struct { + GroupRepository + groups map[int64]*Group +} + +func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) { + return s.groups[id], nil +} + +type rpmStatusRateRepoStub struct { + UserGroupRateRepository + overrides map[int64]*int +} + +func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) { + return s.overrides[groupID], nil +} + +type rpmStatusCacheStub struct { + UserRPMCache + userUsed int + groupUsed map[int64]int +} + +func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) { + return 0, nil +} + +func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) { + return s.groupUsed[groupID], nil +} + +func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) { + return s.userUsed, nil +} + +func TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) { + groupOneID := int64(1) + groupTwoID := int64(2) + override := 7 + svc := &adminServiceImpl{ + userRepo: &rpmStatusUserRepoStub{user: &User{ + ID: 42, + RPMLimit: 20, + }}, + apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{ + {ID: 100, UserID: 42, GroupID: &groupTwoID}, + {ID: 101, UserID: 42, GroupID: &groupOneID}, + {ID: 102, UserID: 42, GroupID: &groupTwoID}, + {ID: 103, UserID: 42}, + }}, + groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{ + groupOneID: {ID: 
groupOneID, Name: "group-one", RPMLimit: 10}, + groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60}, + }}, + userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{ + groupTwoID: &override, + }}, + userRPMCache: &rpmStatusCacheStub{ + userUsed: 5, + groupUsed: map[int64]int{ + groupOneID: 3, + groupTwoID: 4, + }, + }, + } + + status, err := svc.GetUserRPMStatus(context.Background(), 42) + require.NoError(t, err) + require.Equal(t, &UserRPMStatus{ + UserRPMUsed: 5, + UserRPMLimit: 20, + PerGroup: []UserGroupRPMStatus{ + {GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"}, + {GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"}, + }, + }, status) +} diff --git a/backend/internal/service/admin_service_update_user_rpm_test.go b/backend/internal/service/admin_service_update_user_rpm_test.go new file mode 100644 index 00000000..cb4c3986 --- /dev/null +++ b/backend/internal/service/admin_service_update_user_rpm_test.go @@ -0,0 +1,69 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构, +// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。 +type rpmUserRepoStub struct { + *userRepoStub + lastUpdated *User +} + +func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error { + if user == nil { + return nil + } + clone := *user + s.lastUpdated = &clone + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { + base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}} + repo := &rpmUserRepoStub{userRepoStub: base} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: &redeemRepoStub{}, + authCacheInvalidator: invalidator, + } + + newRPM := 60 + updated, err := 
svc.UpdateUser(context.Background(), 42, &UpdateUserInput{ + RPMLimit: &newRPM, + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 60, updated.RPMLimit) + require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存") +} + +func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) { + base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}} + repo := &rpmUserRepoStub{userRepoStub: base} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: &redeemRepoStub{}, + authCacheInvalidator: invalidator, + } + + newName := "new" + sameRPM := 10 + _, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{ + Username: &newName, + RPMLimit: &sameRPM, + }) + require.NoError(t, err) + require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效") +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index b1660ea7..1a1c78b8 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct { BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` TotalRecharged float64 `json:"total_recharged"` + + // RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。 + RPMLimit int `json:"rpm_limit"` + + // UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。 + // nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。 + UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"` } // APIKeyAuthGroupSnapshot 分组快照 @@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct { AllowMessagesDispatch bool `json:"allow_messages_dispatch"` DefaultMappedModel string `json:"default_mapped_model,omitempty"` MessagesDispatchModelConfig 
OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + + // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。 + RPMLimit int `json:"rpm_limit"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 2bd9a091..974ea66e 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold +const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot type apiKeyAuthCacheConfig struct { l1Size int @@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key - snapshot := s.snapshotFromAPIKey(apiKey) + snapshot := s.snapshotFromAPIKey(ctx, apiKey) if snapshot == nil { return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) } @@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn return s.snapshotToAPIKey(key, entry.Snapshot), true, nil } -func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { +func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey *APIKey) *APIKeyAuthSnapshot { if apiKey == nil || apiKey.User == nil { return nil } @@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, TotalRecharged: apiKey.User.TotalRecharged, + RPMLimit: apiKey.User.RPMLimit, }, } + + // 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。 + if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil { + 
override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID) + if err == nil && override != nil { + snapshot.User.UserGroupRPMOverride = override + } + // 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询 + } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ ID: apiKey.Group.ID, @@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, + RPMLimit: apiKey.Group.RPMLimit, } } return snapshot @@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, TotalRecharged: snapshot.User.TotalRecharged, + RPMLimit: snapshot.User.RPMLimit, + UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride, }, } if snapshot.Group != nil { @@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, + RPMLimit: snapshot.Group.RPMLimit, } } s.compileAPIKeyIPRules(apiKey) diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 3c2f7dbb..8cb1b8c4 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t }, } - snapshot := svc.snapshotFromAPIKey(apiKey) + snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey) roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) 
require.NotNil(t, roundTrip) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 3bf9da3d..e45d8d66 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw grantPlan := s.resolveSignupGrantPlan(ctx, "email") + // 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。 + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } + // 创建用户 user := &User{ Email: email, @@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw Role: RoleUser, Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, } @@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username signupSource := inferLegacySignupSource(email) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } newUser := &User{ Email: email, @@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username Role: RoleUser, Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, SignupSource: signupSource, } @@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema signupSource := inferLegacySignupSource(email) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } newUser := &User{ Email: email, @@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Role: RoleUser, Balance: grantPlan.Balance, Concurrency: 
grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, SignupSource: signupSource, } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index f2ad0a3d..4e695eb9 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -20,6 +20,9 @@ import ( var ( ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.") + // RPM 超限错误。gateway_handler 负责映射为 HTTP 429。 + ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded") + ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) @@ -87,6 +90,8 @@ type BillingCacheService struct { userRepo UserRepository subRepo UserSubscriptionRepository apiKeyRateLimitLoader apiKeyRateLimitLoader + userRPMCache UserRPMCache + userGroupRateRepo UserGroupRateRepository cfg *config.Config circuitBreaker *billingCircuitBreaker @@ -104,12 +109,22 @@ type BillingCacheService struct { } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { +func NewBillingCacheService( + cache BillingCache, + userRepo UserRepository, + subRepo UserSubscriptionRepository, + apiKeyRepo APIKeyRepository, + userRPMCache UserRPMCache, + userGroupRateRepo UserGroupRateRepository, + cfg *config.Config, +) *BillingCacheService { svc := &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, apiKeyRateLimitLoader: apiKeyRepo, + userRPMCache: userRPMCache, + userGroupRateRepo: userGroupRateRepo, cfg: 
cfg, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) @@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user } } + // RPM limiting: the (user, group) check (override, else group limit) and the user-level hard cap are both enforced; runs last so requests already rejected above do not consume RPM quota. + if err := s.checkRPM(ctx, user, group); err != nil { + return err + } + + return nil +} + +// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝: +// +// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。 +// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。 +// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。 +// 3. user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。 +// +// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。 +// Redis 故障一律 fail-open(打 warning,不阻塞业务)。 +func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error { + if s == nil || s.userRPMCache == nil || user == nil { + return nil + } + + // ── 第一层:分组级检查(override 或 group.rpm_limit) ── + if group != nil { + // 解析 override:优先从 auth cache snapshot,nil 时回退 DB。 + var override *int + if user.UserGroupRPMOverride != nil { + override = user.UserGroupRPMOverride + } else if s.userGroupRateRepo != nil { + dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm override lookup failed for user=%d group=%d: %v", + user.ID, group.ID, err, + ) + } else { + override = dbOverride + } + } + + if override != nil { + // override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。 + if *override > 0 { + count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) + if incErr != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (override) failed for user=%d group=%d: %v", + user.ID, group.ID, incErr, + ) + // fail-open + } else if count > *override { + return ErrGroupRPMExceeded + } + } + // override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。 + } else if
group.RPMLimit > 0 { + // 无 override,检查 group.rpm_limit。 + count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (group) failed for user=%d group=%d: %v", + user.ID, group.ID, err, + ) + // fail-open + } else if count > group.RPMLimit { + return ErrGroupRPMExceeded + } + } + } + + // ── 第二层:用户级全局硬上限(始终生效) ── + if user.RPMLimit > 0 { + count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (user) failed for user=%d: %v", + user.ID, err, + ) + return nil // fail-open + } + if count > user.RPMLimit { + return ErrUserRPMExceeded + } + } + return nil } diff --git a/backend/internal/service/billing_cache_service_rpm_test.go b/backend/internal/service/billing_cache_service_rpm_test.go new file mode 100644 index 00000000..de66136f --- /dev/null +++ b/backend/internal/service/billing_cache_service_rpm_test.go @@ -0,0 +1,253 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。 +type userRPMCacheStub struct { + userGroupCalls int32 + userCalls int32 + + userGroupCounts []int // 依次返回的计数值 + userGroupErr error + userCounts []int + userErr error +} + +func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) { + idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1 + if s.userGroupErr != nil { + return 0, s.userGroupErr + } + if idx < len(s.userGroupCounts) { + return s.userGroupCounts[idx], nil + } + return 1, nil +} + +func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) { + idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1 + if s.userErr != nil { + return 0, s.userErr + } + if idx < len(s.userCounts) { + return 
s.userCounts[idx], nil + } + return 1, nil +} + +func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) { + return 0, nil +} + +func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) { + return 0, nil +} + +// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。 +type rpmOverrideRepoStub struct { + UserGroupRateRepository + + override *int + err error + calls int32 +} + +func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + atomic.AddInt32(&s.calls, 1) + if s.err != nil { + return nil, s.err + } + return s.override, nil +} + +func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService { + t.Helper() + // 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。 + // 我们只直接测 checkRPM。 + svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}) + t.Cleanup(svc.Stop) + return svc +} + +func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) { + override := 2 + // user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰) + cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: &override} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试 + group := &Group{ID: 10, RPMLimit: 100} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) + + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数") + // 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user + require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用") + require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls)) +} + +func 
TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) { + override := 100 // override 很高 + // user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3 + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: &override} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100 + group := &Group{ID: 10, RPMLimit: 100} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override") +} + +func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) { + zero := 0 + // user 计数: 依次返回 1..6 + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}} + repo := &rpmOverrideRepoStub{override: &zero} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 5} + group := &Group{ID: 10, RPMLimit: 100} + + // override=0 跳过分组计数,但 user.RPMLimit=5 仍生效 + for i := 0; i < 5; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1) + } + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, + "override=0 跳过分组但 user 全局上限仍应生效") + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器") + require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用") +} + +func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) { + zero := 0 + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{override: &zero} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} // user 也不限 + group := &Group{ID: 10, RPMLimit: 100} + + for i := 0; i < 50; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + } + 
require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数") + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数") +} + +func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) { + // user-group 计数: 5, 6;user 计数: 默认 1(不干扰) + cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超 + group := &Group{ID: 10, RPMLimit: 5} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超 + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5 + + require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls)) + // 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回") +} + +func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) { + cache := &userRPMCacheStub{userGroupCounts: []int{3}} + repo := &rpmOverrideRepoStub{err: errors.New("db down")} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 10} + + // override 查询失败后应继续尝试 group 分支(不直接拒绝) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls)) +} + +func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) { + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} + group := &Group{ID: 10, RPMLimit: 0} // 分组未设限 + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + 
require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded) + + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键") + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) { + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 0} + + for i := 0; i < 10; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + } + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) { + cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 5} + + // Redis 故障时应 fail-open,不拒绝请求 + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls)) +} + +func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) { + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} + + // 无 group(纯用户级限流场景),不应查询 rpm_override。 + require.NoError(t, svc.checkRPM(context.Background(), user, nil)) + require.NoError(t, svc.checkRPM(context.Background(), user, nil)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded) + + require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override") + require.EqualValues(t, 3, 
atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) { + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{} + svc := newBillingServiceForRPM(t, cache, repo) + + require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10})) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls)) +} diff --git a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 0eaf4570..962becf0 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { delay: 80 * time.Millisecond, balance: 12.34, } - svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) const goroutines = 16 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 7d7045e2..849e24b8 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) start := time.Now() @@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) 
{ cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 3c6888b8..1b7d7a9e 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -170,9 +170,10 @@ const ( SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 - SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制) // 第三方认证来源默认授予配置 SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 64434ae1..bb4c5aa1 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -59,6 +59,10 @@ type Group struct { DefaultMappedModel string MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit is the group-level requests-per-minute cap (0 = unlimited). + // It is counted per (user, group); a user-group rpm_override replaces it for that user, and the user-level rpm_limit still applies on top of it. + RPMLimit int + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index f2b644be..df7c86e7 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -1060,6 +1060,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context,
setting // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { return nil, fmt.Errorf("marshal default subscriptions: %w", err) @@ -1422,6 +1423,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。 +func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int { + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit) + if err != nil || value == "" { + return 0 + } + if v, err := strconv.Atoi(value); err == nil && v >= 0 { + return v + } + return 0 +} + // GetDefaultSubscriptions 获取新用户默认订阅配置列表。 func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) @@ -1590,6 +1603,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOIDCConnectUserInfoUsernamePath: "", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultUserRPMLimit: "0", SettingKeyDefaultSubscriptions: "[]", SettingKeyAuthSourceDefaultEmailBalance: "0", SettingKeyAuthSourceDefaultEmailConcurrency: "5", @@ -1699,6 +1713,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin result.DefaultConcurrency = s.cfg.Default.UserConcurrency } + if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 { + result.DefaultUserRPMLimit = rpm + } + // 解析浮点数类型 if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); 
err == nil { result.DefaultBalance = balance diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index d2ef8fae..a9de5eee 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -106,6 +106,7 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + DefaultUserRPMLimit int DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 9dc13381..f9833611 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -49,6 +49,15 @@ type User struct { BalanceNotifyExtraEmails []NotifyEmailEntry TotalRecharged float64 + // RPMLimit is the per-user requests-per-minute cap (0 = unlimited). checkRPM enforces it as a global hard cap + // whenever it is > 0, regardless of any group rpm_limit or (user, group) rpm_override; counter key rpm:u:{userID}:{min}. + RPMLimit int + + // UserGroupRPMOverride is the (user, group) RPM override carried in the auth cache snapshot. + // nil = the snapshot holds no override (none configured, or the lookup failed), in which case checkRPM queries the DB; + // a non-nil value is used by checkRPM directly. The field is not persisted to the database. + UserGroupRPMOverride *int + APIKeys []APIKey Subscriptions []UserSubscription } diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go index 3d221a25..f069eb7e 100644 --- a/backend/internal/service/user_group_rate.go +++ b/backend/internal/service/user_group_rate.go @@ -2,14 +2,16 @@ package service import "context" -// UserGroupRateEntry 分组下用户专属倍率条目 +// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。 +// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。 type UserGroupRateEntry struct { - UserID int64 `json:"user_id"` - UserName string `json:"user_name"` - UserEmail string `json:"user_email"` - UserNotes string `json:"user_notes"` - UserStatus string `json:"user_status"` - RateMultiplier float64 `json:"rate_multiplier"` + UserID int64 `json:"user_id"` + UserName string `json:"user_name"` + UserEmail string `json:"user_email"` + UserNotes string
`json:"user_notes"` + UserStatus string `json:"user_status"` + RateMultiplier *float64 `json:"rate_multiplier,omitempty"` + RPMOverride *int `json:"rpm_override,omitempty"` } // GroupRateMultiplierInput 批量设置分组倍率的输入条目 @@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct { RateMultiplier float64 `json:"rate_multiplier"` } -// UserGroupRateRepository 用户专属分组倍率仓储接口 -// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +// GroupRPMOverrideInput 批量设置分组 RPM override 的输入条目。 +// RPMOverride 为 *int 以支持清除(nil)语义。 +type GroupRPMOverrideInput struct { + UserID int64 `json:"user_id"` + RPMOverride *int `json:"rpm_override"` +} + +// UserGroupRateRepository 用户专属分组倍率/RPM 仓储接口。 +// 允许管理员为特定用户设置分组的专属计费倍率与 RPM 上限,覆盖分组默认值。 type UserGroupRateRepository interface { - // GetByUserID 获取用户的所有专属分组倍率 - // 返回 map[groupID]rateMultiplier + // GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) - // GetByUserAndGroup 获取用户在特定分组的专属倍率 - // 如果未设置专属倍率,返回 nil + // GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) - // GetByGroupID 获取指定分组下所有用户的专属倍率 + // GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil) + GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) + + // GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回) GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) - // SyncUserGroupRates 同步用户的分组专属倍率 - // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 + // SyncUserGroupRates 同步用户的分组专属倍率;nil 表示清空该分组的 rate_multiplier SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error - // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据) + // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组 rate 部分) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error - // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) + // 
SyncGroupRPMOverrides 批量同步分组的用户专属 RPM(替换整组 rpm_override 部分)。 + // 条目中 RPMOverride 为 nil 时清空对应行的 rpm_override;非 nil 时 upsert。 + SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error + + // ClearGroupRPMOverrides 清空指定分组的所有 rpm_override(整组 rpm 部分归 NULL) + ClearGroupRPMOverrides(ctx context.Context, groupID int64) error + + // DeleteByGroupID 删除指定分组的所有用户专属条目(分组删除时调用) DeleteByGroupID(ctx context.Context, groupID int64) error - // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用) + // DeleteByUserID 删除指定用户的所有专属条目(用户删除时调用) DeleteByUserID(ctx context.Context, userID int64) error } diff --git a/backend/internal/service/user_rpm_cache.go b/backend/internal/service/user_rpm_cache.go new file mode 100644 index 00000000..b8857311 --- /dev/null +++ b/backend/internal/service/user_rpm_cache.go @@ -0,0 +1,25 @@ +package service + +import "context" + +// UserRPMCache 用户/分组级 RPM 计数器接口。 +// +// 与账号级 RPMCache 的区别: +// - RPMCache —— 按外部 AI provider 账号聚合(key: rpm:{accountID}:{min})。 +// - UserRPMCache —— 按用户或 (用户, 分组) 聚合,杜绝"同一用户创建多个 API Key 绕过 RPM"的路径。 +// key 形如 rpm:ug:{userID}:{groupID}:{min} 或 rpm:u:{userID}:{min}。 +type UserRPMCache interface { + // IncrementUserGroupRPM 原子递增 (user, group) 级分钟计数并返回最新值。 + // 用于分组 rpm_limit 与 user-group rpm_override 两种命中分支。 + IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error) + + // IncrementUserRPM atomically increments the per-user minute counter and returns the new value. + // Used for the global user rpm_limit check, which checkRPM runs whenever user.rpm_limit > 0, independent of any group limit or override. + IncrementUserRPM(ctx context.Context, userID int64) (count int, err error) + + // GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读,不递增)。 + GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error) + + // GetUserRPM 获取用户当前分钟已用 RPM(只读,不递增)。 + GetUserRPM(ctx context.Context, userID int64) (count int, err error) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index e6c4c074..6b405724 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go
@@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { return NewEmailQueueService(emailService, 3) } +// ProvideOAuthRefreshAPI creates OAuthRefreshAPI with the default lock TTL. +func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { + return NewOAuthRefreshAPI(accountRepo, tokenCache) +} + // ProvideTokenRefreshService creates and starts TokenRefreshService func ProvideTokenRefreshService( accountRepo AccountRepository, @@ -383,6 +388,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit return svc } +// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies. +func ProvideBillingCacheService( + cache BillingCache, + userRepo UserRepository, + subRepo UserSubscriptionRepository, + apiKeyRepo APIKeyRepository, + rpmCache UserRPMCache, + rateRepo UserGroupRateRepository, + cfg *config.Config, +) *BillingCacheService { + return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -399,7 +417,7 @@ var ProviderSet = wire.NewSet( NewDashboardService, ProvidePricingService, NewBillingService, - NewBillingCacheService, + ProvideBillingCacheService, NewAnnouncementService, NewAdminService, NewGatewayService, @@ -411,7 +429,7 @@ var ProviderSet = wire.NewSet( NewCompositeTokenCacheInvalidator, wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), NewAntigravityOAuthService, - NewOAuthRefreshAPI, + ProvideOAuthRefreshAPI, ProvideGeminiTokenProvider, NewGeminiMessagesCompatService, ProvideAntigravityTokenProvider, diff --git a/backend/migrations/125_add_group_rpm_limit.sql b/backend/migrations/125_add_group_rpm_limit.sql new file mode 100644 index 00000000..fbde1b20 --- /dev/null +++ b/backend/migrations/125_add_group_rpm_limit.sql @@ -0,0 +1,7 @@ +-- Add 
per-group Requests-Per-Minute limit. +-- rpm_limit: RPM cap shared by the whole group (0 = unlimited). +-- Once set, it governs rate limiting for users in this group. NOTE(review): the column COMMENT below says it overrides users.rpm_limit, while the commit message and migration 127 say users.rpm_limit keeps applying independently. Confirm. +-- Counter key: rpm:ug:{user_id}:{group_id}:{minute}. +ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0; + +COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。'; diff --git a/backend/migrations/126_add_user_rpm_limit.sql b/backend/migrations/126_add_user_rpm_limit.sql new file mode 100644 index 00000000..64a8b977 --- /dev/null +++ b/backend/migrations/126_add_user_rpm_limit.sql @@ -0,0 +1,7 @@ +-- Add per-user Requests-Per-Minute cap. +-- rpm_limit: user-global RPM cap (0 = unlimited). +-- NOTE(review): originally documented as a fallback used only when the accessed group has no rpm_limit and no user-group rpm_override. The commit message and migration 127 describe it as always enforced. Confirm which applies. +-- Counter key: rpm:u:{user_id}:{minute}. +ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0; + +COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。'; diff --git a/backend/migrations/127_add_user_group_rpm_override.sql b/backend/migrations/127_add_user_group_rpm_override.sql new file mode 100644 index 00000000..1d674258 --- /dev/null +++ b/backend/migrations/127_add_user_group_rpm_override.sql @@ -0,0 +1,16 @@ +-- Extend the existing per-user group rate multiplier table with an rpm_override column, and relax rate_multiplier to be nullable, +-- so that one row can override only the rate, only the RPM, or both. +-- Semantics: +-- - rate_multiplier NULL → the user falls back to the groups.rate_multiplier default in this group +-- - rate_multiplier non-NULL → overrides the default billing multiplier of the group +-- - rpm_override NULL → the user falls back to the groups.rpm_limit default in this group +-- - rpm_override non-NULL → overrides the default RPM of the group (0 = unlimited) +-- The user-level users.rpm_limit still applies independently (total quota across groups). +ALTER TABLE user_group_rate_multipliers + ADD COLUMN IF NOT EXISTS rpm_override integer NULL; + +ALTER TABLE user_group_rate_multipliers + ALTER COLUMN rate_multiplier DROP NOT NULL; + +COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。'; +COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。'; diff --git a/frontend/src/api/admin/groups.ts 
b/frontend/src/api/admin/groups.ts index 8739d5cb..6b94b799 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -164,7 +164,8 @@ export interface GroupRateMultiplierEntry { user_email: string user_notes: string user_status: string - rate_multiplier: number + rate_multiplier?: number | null + rpm_override?: number | null } /** @@ -205,9 +206,7 @@ export async function clearGroupRateMultipliers(id: number): Promise<{ message: /** * Batch set rate multipliers for users in a group - * @param id - Group ID - * @param entries - Array of { user_id, rate_multiplier } - * @returns Success confirmation + * Only touches rate_multiplier column; preserves rpm_override on existing rows. */ export async function batchSetGroupRateMultipliers( id: number, @@ -220,6 +219,60 @@ export async function batchSetGroupRateMultipliers( return data } +/** + * RPM override entry for a user in a group + */ +export interface GroupRPMOverrideEntry { + user_id: number + user_name: string + user_email: string + user_notes: string + user_status: string + rpm_override: number +} + +/** + * Get RPM overrides for users in a group: reads the rate-multipliers endpoint and keeps only rows with rpm_override set. + */ +export async function getGroupRPMOverrides(id: number): Promise<GroupRPMOverrideEntry[]> { + const { data } = await apiClient.get<GroupRateMultiplierEntry[]>( + `/admin/groups/${id}/rate-multipliers` + ) + return data + .filter(e => e.rpm_override != null) + .map(e => ({ + user_id: e.user_id, + user_name: e.user_name, + user_email: e.user_email, + user_notes: e.user_notes, + user_status: e.user_status, + rpm_override: e.rpm_override as number + })) +} + +/** + * Batch set RPM overrides for users in a group. + * Only touches rpm_override column; preserves rate_multiplier on existing rows. 
+ */ +export async function batchSetGroupRPMOverrides( + id: number, + entries: Array<{ user_id: number; rpm_override: number }> +): Promise<{ message: string }> { + const { data } = await apiClient.put<{ message: string }>( + `/admin/groups/${id}/rpm-overrides`, + { entries } + ) + return data +} + +/** + * Clear all RPM overrides for a group (preserves rate_multiplier). + */ +export async function clearGroupRPMOverrides(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/groups/${id}/rpm-overrides`) + return data +} + /** * Get usage summary (today + cumulative cost) for all groups * @param timezone - IANA timezone string (e.g. "Asia/Shanghai") @@ -262,6 +315,9 @@ export const groupsAPI = { getGroupRateMultipliers, clearGroupRateMultipliers, batchSetGroupRateMultipliers, + getGroupRPMOverrides, + clearGroupRPMOverrides, + batchSetGroupRPMOverrides, updateSortOrder, getUsageSummary, getCapacitySummary diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index 0403b0f3..c7b9031e 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -309,6 +309,7 @@ export interface SystemSettings { // Default settings default_balance: number; default_concurrency: number; + default_user_rpm_limit: number; default_subscriptions: DefaultSubscriptionSetting[]; auth_source_default_email_balance?: number; auth_source_default_email_concurrency?: number; @@ -482,6 +483,7 @@ export interface UpdateSettingsRequest { totp_enabled?: boolean; // TOTP 双因素认证 default_balance?: number; default_concurrency?: number; + default_user_rpm_limit?: number; default_subscriptions?: DefaultSubscriptionSetting[]; auth_source_default_email_balance?: number; auth_source_default_email_concurrency?: number; diff --git a/frontend/src/components/admin/group/GroupRPMOverridesModal.vue b/frontend/src/components/admin/group/GroupRPMOverridesModal.vue new file mode 100644 index 
00000000..a4b4e536 --- /dev/null +++ b/frontend/src/components/admin/group/GroupRPMOverridesModal.vue @@ -0,0 +1,434 @@ + + + + + diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index 41b2e63c..d68f3aa5 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -168,7 +168,8 @@ step="0.001" min="0.001" autocomplete="off" - :value="entry.rate_multiplier" + :value="entry.rate_multiplier ?? ''" + :placeholder="String(props.group?.rate_multiplier ?? 1)" class="hide-spinner w-20 rounded border border-gray-200 bg-white px-2 py-1 text-center text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500" @change="updateLocalRate(entry.user_id, ($event.target as HTMLInputElement).value)" /> @@ -294,19 +295,17 @@ const showFinalRate = computed(() => { }) // 计算最终倍率预览 -const computeFinalRate = (rate: number) => { - if (!batchFactor.value) return rate - return parseFloat((rate * batchFactor.value).toFixed(6)) +const computeFinalRate = (rate: number | null | undefined) => { + const base = rate ?? props.group?.rate_multiplier ?? 1 + if (!batchFactor.value) return base + return parseFloat((base * batchFactor.value).toFixed(6)) } // 检测是否有未保存的修改 const isDirty = computed(() => { if (localEntries.value.length !== serverEntries.value.length) return true - const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier])) - return localEntries.value.some(e => { - const serverRate = serverMap.get(e.user_id) - return serverRate === undefined || serverRate !== e.rate_multiplier - }) + const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier ?? 
null])) + return localEntries.value.some(e => serverMap.get(e.user_id) !== (e.rate_multiplier ?? null)) }) const paginatedLocalEntries = computed(() => { @@ -322,7 +321,9 @@ const loadEntries = async () => { if (!props.group) return loading.value = true try { - serverEntries.value = await adminAPI.groups.getGroupRateMultipliers(props.group.id) + const raw = await adminAPI.groups.getGroupRateMultipliers(props.group.id) + // 仅显示已设置 rate_multiplier 的条目;rpm_override 在另一个弹窗管理,保留不动 + serverEntries.value = raw.filter(e => e.rate_multiplier != null) localEntries.value = cloneEntries(serverEntries.value) adjustPage() } catch (error) { @@ -394,7 +395,8 @@ const handleAddLocal = () => { user_email: user.email, user_notes: user.notes || '', user_status: user.status || 'active', - rate_multiplier: newRate.value + rate_multiplier: newRate.value, + rpm_override: null } if (idx >= 0) { localEntries.value[idx] = entry @@ -409,12 +411,15 @@ const handleAddLocal = () => { // 本地修改倍率 const updateLocalRate = (userId: number, value: string) => { + const entry = localEntries.value.find(e => e.user_id === userId) + if (!entry) return + if (value.trim() === '') { + entry.rate_multiplier = null + return + } const num = parseFloat(value) if (isNaN(num)) return - const entry = localEntries.value.find(e => e.user_id === userId) - if (entry) { - entry.rate_multiplier = num - } + entry.rate_multiplier = num } // 本地删除 @@ -427,7 +432,9 @@ const removeLocal = (userId: number) => { const applyBatchFactor = () => { if (!batchFactor.value || batchFactor.value <= 0) return for (const entry of localEntries.value) { - entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6)) + if (entry.rate_multiplier != null) { + entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6)) + } } batchFactor.value = null } @@ -444,15 +451,17 @@ const handleCancel = () => { adjustPage() } -// 保存:一次性提交所有数据 +// 保存:一次性提交所有数据(只提交 rate_multiplier;rpm_override 
由独立弹窗管理) const handleSave = async () => { if (!props.group) return saving.value = true try { - const entries = localEntries.value.map(e => ({ - user_id: e.user_id, - rate_multiplier: e.rate_multiplier - })) + const entries = localEntries.value + .filter(e => e.rate_multiplier != null) + .map(e => ({ + user_id: e.user_id, + rate_multiplier: e.rate_multiplier as number + })) await adminAPI.groups.batchSetGroupRateMultipliers(props.group.id, entries) appStore.showSuccess(t('admin.groups.rateSaved')) emit('success') diff --git a/frontend/src/components/admin/user/UserCreateModal.vue b/frontend/src/components/admin/user/UserCreateModal.vue index 0e44d81e..2966a23b 100644 --- a/frontend/src/components/admin/user/UserCreateModal.vue +++ b/frontend/src/components/admin/user/UserCreateModal.vue @@ -35,6 +35,18 @@ +
+ + +

{{ t('admin.users.form.rpmLimitHint') }}

+