diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b729c575..26ed8524 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -246,10 +246,10 @@ jobs: if [ -n "$DOCKERHUB_USERNAME" ]; then DOCKER_IMAGE="${DOCKERHUB_USERNAME}/sub2api" MESSAGE+="# Docker Hub"$'\n' - MESSAGE+="docker pull ${DOCKER_IMAGE}:${TAG_NAME}"$'\n' + MESSAGE+="docker pull ${DOCKER_IMAGE}:${VERSION}"$'\n' MESSAGE+="# GitHub Container Registry"$'\n' fi - MESSAGE+="docker pull ${GHCR_IMAGE}:${TAG_NAME}"$'\n' + MESSAGE+="docker pull ${GHCR_IMAGE}:${VERSION}"$'\n' MESSAGE+="\`\`\`"$'\n'$'\n' MESSAGE+="🔗 *相关链接:*"$'\n' MESSAGE+="• [GitHub Release](https://github.com/${REPO}/releases/tag/${TAG_NAME})"$'\n' diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 72400828..793d0e31 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.115 +0.1.116 diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index bcfb4e1f..93270e7e 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -61,8 +61,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { billingCache := repository.NewBillingCache(redisClient) userSubscriptionRepository := repository.NewUserSubscriptionRepository(client) apiKeyRepository := repository.NewAPIKeyRepository(client, db) - billingCacheService := service.NewBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, configConfig) + userRPMCache := repository.NewUserRPMCache(redisClient) userGroupRateRepository := repository.NewUserGroupRateRepository(db) + billingCacheService := service.ProvideBillingCacheService(billingCache, userRepository, userSubscriptionRepository, apiKeyRepository, userRPMCache, userGroupRateRepository, configConfig) apiKeyCache := repository.NewAPIKeyCache(redisClient) apiKeyService := service.NewAPIKeyService(apiKeyRepository, 
userRepository, groupRepository, userSubscriptionRepository, userGroupRateRepository, apiKeyCache, configConfig) apiKeyAuthCacheInvalidator := service.ProvideAPIKeyAuthCacheInvalidator(apiKeyService) @@ -104,7 +105,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { proxyExitInfoProber := repository.NewProxyExitInfoProber(configConfig) proxyLatencyCache := repository.NewProxyLatencyCache(redisClient) privacyClientFactory := providePrivacyClientFactory() - adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) + adminService := service.NewAdminService(userRepository, groupRepository, accountRepository, proxyRepository, apiKeyRepository, redeemCodeRepository, userGroupRateRepository, userRPMCache, billingCacheService, proxyExitInfoProber, proxyLatencyCache, apiKeyAuthCacheInvalidator, client, settingService, subscriptionService, userSubscriptionRepository, privacyClientFactory) concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig) concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig) adminUserHandler := admin.NewUserHandler(adminService, concurrencyService) @@ -124,9 +125,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) tempUnschedCache := repository.NewTempUnschedCache(redisClient) timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient) + openAI403CounterCache := repository.NewOpenAI403CounterCache(redisClient) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) compositeTokenCacheInvalidator := 
service.NewCompositeTokenCacheInvalidator(geminiTokenCache) - rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator) + rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, openAI403CounterCache, settingService, compositeTokenCacheInvalidator) httpUpstream := repository.NewHTTPUpstream(configConfig) claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream) antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository) @@ -136,7 +138,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { tlsFingerprintProfileCache := repository.NewTLSFingerprintProfileCache(redisClient) tlsFingerprintProfileService := service.NewTLSFingerprintProfileService(tlsFingerprintProfileRepository, tlsFingerprintProfileCache) accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache, tlsFingerprintProfileService) - oAuthRefreshAPI := service.NewOAuthRefreshAPI(accountRepository, geminiTokenCache) + oAuthRefreshAPI := service.ProvideOAuthRefreshAPI(accountRepository, geminiTokenCache) geminiTokenProvider := service.ProvideGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService, oAuthRefreshAPI) gatewayCache := repository.NewGatewayCache(redisClient) schedulerOutboxRepository := repository.NewSchedulerOutboxRepository(db) @@ -183,6 +185,15 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, 
antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) + encryptionKey, err := payment.ProvideEncryptionKey(configConfig) + if err != nil { + return nil, err + } + paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) + registry := payment.ProvideRegistry() + defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) + paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) + settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) opsHandler := admin.NewOpsHandler(opsService) updateCache := repository.NewUpdateCache(redisClient) gitHubReleaseClient := repository.ProvideGitHubReleaseClient(configConfig) @@ -222,16 +233,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { channelMonitorHandler := admin.NewChannelMonitorHandler(channelMonitorService) channelMonitorUserHandler := handler.NewChannelMonitorUserHandler(channelMonitorService, settingService) channelMonitorRunner := service.ProvideChannelMonitorRunner(channelMonitorService, settingService) - registry := payment.ProvideRegistry() - encryptionKey, err := payment.ProvideEncryptionKey(configConfig) - if err != nil { - return nil, err - } - defaultLoadBalancer := payment.ProvideDefaultLoadBalancer(client, encryptionKey) - paymentConfigService := service.ProvidePaymentConfigService(client, settingRepository, encryptionKey) - paymentService := service.NewPaymentService(client, registry, defaultLoadBalancer, 
redeemService, subscriptionService, paymentConfigService, userRepository, groupRepository) - settingHandler := admin.NewSettingHandler(settingService, emailService, turnstileService, opsService, paymentConfigService, paymentService) - paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) paymentHandler := admin.NewPaymentHandler(paymentService, paymentConfigService) availableChannelUserHandler := handler.NewAvailableChannelHandler(channelService, apiKeyService, settingService) adminHandlers := handler.ProvideAdminHandlers(dashboardHandler, adminUserHandler, groupHandler, accountHandler, adminAnnouncementHandler, dataManagementHandler, backupHandler, oAuthHandler, openAIOAuthHandler, geminiOAuthHandler, antigravityOAuthHandler, proxyHandler, adminRedeemHandler, promoHandler, settingHandler, opsHandler, systemHandler, adminSubscriptionHandler, adminUsageHandler, userAttributeHandler, errorPassthroughHandler, tlsFingerprintProfileHandler, adminAPIKeyHandler, scheduledTestHandler, channelHandler, channelMonitorHandler, channelMonitorRequestTemplateHandler, paymentHandler) @@ -261,6 +262,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { accountExpiryService := service.ProvideAccountExpiryService(accountRepository) subscriptionExpiryService := service.ProvideSubscriptionExpiryService(userSubscriptionRepository) scheduledTestRunnerService := service.ProvideScheduledTestRunnerService(scheduledTestPlanRepository, scheduledTestService, accountTestService, rateLimitService, configConfig) + paymentOrderExpiryService := service.ProvidePaymentOrderExpiryService(paymentService) v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, opsSystemLogSink, schedulerSnapshotService, tokenRefreshService, accountExpiryService, subscriptionExpiryService, usageCleanupService, idempotencyCleanupService, pricingService, 
emailQueueService, billingCacheService, usageRecordWorkerPool, subscriptionService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, openAIGatewayService, scheduledTestRunnerService, backupService, paymentOrderExpiryService, channelMonitorRunner) application := &Application{ Server: httpServer, diff --git a/backend/cmd/server/wire_gen_test.go b/backend/cmd/server/wire_gen_test.go index 39aa85de..5ccd67fb 100644 --- a/backend/cmd/server/wire_gen_test.go +++ b/backend/cmd/server/wire_gen_test.go @@ -43,7 +43,7 @@ func TestProvideCleanup_WithMinimalDependencies_NoPanic(t *testing.T) { subscriptionExpirySvc := service.NewSubscriptionExpiryService(nil, time.Second) pricingSvc := service.NewPricingService(cfg, nil) emailQueueSvc := service.NewEmailQueueService(nil, 1) - billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) idempotencyCleanupSvc := service.NewIdempotencyCleanupService(nil, cfg) schedulerSnapshotSvc := service.NewSchedulerSnapshotService(nil, nil, nil, nil, cfg) opsSystemLogSinkSvc := service.NewOpsSystemLogSink(nil) diff --git a/backend/ent/group.go b/backend/ent/group.go index f10b50c3..5d9ae2ed 100644 --- a/backend/ent/group.go +++ b/backend/ent/group.go @@ -79,6 +79,8 @@ type Group struct { DefaultMappedModel string `json:"default_mapped_model,omitempty"` // OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型 MessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + // 分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流 + RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the GroupQuery when eager-loading is set. 
Edges GroupEdges `json:"edges"` @@ -191,7 +193,7 @@ func (*Group) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case group.FieldRateMultiplier, group.FieldDailyLimitUsd, group.FieldWeeklyLimitUsd, group.FieldMonthlyLimitUsd, group.FieldImagePrice1k, group.FieldImagePrice2k, group.FieldImagePrice4k: values[i] = new(sql.NullFloat64) - case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder: + case group.FieldID, group.FieldDefaultValidityDays, group.FieldFallbackGroupID, group.FieldFallbackGroupIDOnInvalidRequest, group.FieldSortOrder, group.FieldRpmLimit: values[i] = new(sql.NullInt64) case group.FieldName, group.FieldDescription, group.FieldStatus, group.FieldPlatform, group.FieldSubscriptionType, group.FieldDefaultMappedModel: values[i] = new(sql.NullString) @@ -414,6 +416,12 @@ func (_m *Group) assignValues(columns []string, values []any) error { return fmt.Errorf("unmarshal field messages_dispatch_model_config: %w", err) } } + case group.FieldRpmLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) + } else if value.Valid { + _m.RpmLimit = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -599,6 +607,9 @@ func (_m *Group) String() string { builder.WriteString(", ") builder.WriteString("messages_dispatch_model_config=") builder.WriteString(fmt.Sprintf("%v", _m.MessagesDispatchModelConfig)) + builder.WriteString(", ") + builder.WriteString("rpm_limit=") + builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/group/group.go b/backend/ent/group/group.go index b1371630..24bd9c13 100644 --- a/backend/ent/group/group.go +++ b/backend/ent/group/group.go @@ -76,6 +76,8 @@ const ( FieldDefaultMappedModel = "default_mapped_model" // FieldMessagesDispatchModelConfig holds the string 
denoting the messages_dispatch_model_config field in the database. FieldMessagesDispatchModelConfig = "messages_dispatch_model_config" + // FieldRpmLimit holds the string denoting the rpm_limit field in the database. + FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -181,6 +183,7 @@ var Columns = []string{ FieldRequirePrivacySet, FieldDefaultMappedModel, FieldMessagesDispatchModelConfig, + FieldRpmLimit, } var ( @@ -258,6 +261,8 @@ var ( DefaultMappedModelValidator func(string) error // DefaultMessagesDispatchModelConfig holds the default value on creation for the "messages_dispatch_model_config" field. DefaultMessagesDispatchModelConfig domain.OpenAIMessagesDispatchModelConfig + // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. + DefaultRpmLimit int ) // OrderOption defines the ordering options for the Group queries. @@ -403,6 +408,11 @@ func ByDefaultMappedModel(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldDefaultMappedModel, opts...).ToFunc() } +// ByRpmLimit orders the results by the rpm_limit field. +func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRpmLimit, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/group/where.go b/backend/ent/group/where.go index cba2ce5f..2814d130 100644 --- a/backend/ent/group/where.go +++ b/backend/ent/group/where.go @@ -190,6 +190,11 @@ func DefaultMappedModel(v string) predicate.Group { return predicate.Group(sql.FieldEQ(FieldDefaultMappedModel, v)) } +// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ. 
+func RpmLimit(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRpmLimit, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. func CreatedAtEQ(v time.Time) predicate.Group { return predicate.Group(sql.FieldEQ(FieldCreatedAt, v)) @@ -1320,6 +1325,46 @@ func DefaultMappedModelContainsFold(v string) predicate.Group { return predicate.Group(sql.FieldContainsFold(FieldDefaultMappedModel, v)) } +// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field. +func RpmLimitEQ(v int) predicate.Group { + return predicate.Group(sql.FieldEQ(FieldRpmLimit, v)) +} + +// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field. +func RpmLimitNEQ(v int) predicate.Group { + return predicate.Group(sql.FieldNEQ(FieldRpmLimit, v)) +} + +// RpmLimitIn applies the In predicate on the "rpm_limit" field. +func RpmLimitIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldIn(FieldRpmLimit, vs...)) +} + +// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field. +func RpmLimitNotIn(vs ...int) predicate.Group { + return predicate.Group(sql.FieldNotIn(FieldRpmLimit, vs...)) +} + +// RpmLimitGT applies the GT predicate on the "rpm_limit" field. +func RpmLimitGT(v int) predicate.Group { + return predicate.Group(sql.FieldGT(FieldRpmLimit, v)) +} + +// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field. +func RpmLimitGTE(v int) predicate.Group { + return predicate.Group(sql.FieldGTE(FieldRpmLimit, v)) +} + +// RpmLimitLT applies the LT predicate on the "rpm_limit" field. +func RpmLimitLT(v int) predicate.Group { + return predicate.Group(sql.FieldLT(FieldRpmLimit, v)) +} + +// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field. +func RpmLimitLTE(v int) predicate.Group { + return predicate.Group(sql.FieldLTE(FieldRpmLimit, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. 
func HasAPIKeys() predicate.Group { return predicate.Group(func(s *sql.Selector) { diff --git a/backend/ent/group_create.go b/backend/ent/group_create.go index f412fa40..20ea0a0f 100644 --- a/backend/ent/group_create.go +++ b/backend/ent/group_create.go @@ -425,6 +425,20 @@ func (_c *GroupCreate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _c } +// SetRpmLimit sets the "rpm_limit" field. +func (_c *GroupCreate) SetRpmLimit(v int) *GroupCreate { + _c.mutation.SetRpmLimit(v) + return _c +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_c *GroupCreate) SetNillableRpmLimit(v *int) *GroupCreate { + if v != nil { + _c.SetRpmLimit(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *GroupCreate) AddAPIKeyIDs(ids ...int64) *GroupCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -630,6 +644,10 @@ func (_c *GroupCreate) defaults() error { v := group.DefaultMessagesDispatchModelConfig _c.mutation.SetMessagesDispatchModelConfig(v) } + if _, ok := _c.mutation.RpmLimit(); !ok { + v := group.DefaultRpmLimit + _c.mutation.SetRpmLimit(v) + } return nil } @@ -717,6 +735,9 @@ func (_c *GroupCreate) check() error { if _, ok := _c.mutation.MessagesDispatchModelConfig(); !ok { return &ValidationError{Name: "messages_dispatch_model_config", err: errors.New(`ent: missing required field "Group.messages_dispatch_model_config"`)} } + if _, ok := _c.mutation.RpmLimit(); !ok { + return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "Group.rpm_limit"`)} + } return nil } @@ -864,6 +885,10 @@ func (_c *GroupCreate) createSpec() (*Group, *sqlgraph.CreateSpec) { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) _node.MessagesDispatchModelConfig = value } + if value, ok := _c.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + _node.RpmLimit = value + } if nodes := _c.mutation.APIKeysIDs(); 
len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1500,6 +1525,24 @@ func (u *GroupUpsert) UpdateMessagesDispatchModelConfig() *GroupUpsert { return u } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsert) SetRpmLimit(v int) *GroupUpsert { + u.Set(group.FieldRpmLimit, v) + return u +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsert) UpdateRpmLimit() *GroupUpsert { + u.SetExcluded(group.FieldRpmLimit) + return u +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsert) AddRpmLimit(v int) *GroupUpsert { + u.Add(group.FieldRpmLimit, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -2105,6 +2148,27 @@ func (u *GroupUpsertOne) UpdateMessagesDispatchModelConfig() *GroupUpsertOne { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsertOne) SetRpmLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *GroupUpsertOne) AddRpmLimit(v int) *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsertOne) UpdateRpmLimit() *GroupUpsertOne { + return u.Update(func(s *GroupUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *GroupUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2876,6 +2940,27 @@ func (u *GroupUpsertBulk) UpdateMessagesDispatchModelConfig() *GroupUpsertBulk { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *GroupUpsertBulk) SetRpmLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. 
+func (u *GroupUpsertBulk) AddRpmLimit(v int) *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *GroupUpsertBulk) UpdateRpmLimit() *GroupUpsertBulk { + return u.Update(func(s *GroupUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *GroupUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/group_update.go b/backend/ent/group_update.go index 7b6d6193..cc14f897 100644 --- a/backend/ent/group_update.go +++ b/backend/ent/group_update.go @@ -567,6 +567,27 @@ func (_u *GroupUpdate) SetNillableMessagesDispatchModelConfig(v *domain.OpenAIMe return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *GroupUpdate) SetRpmLimit(v int) *GroupUpdate { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *GroupUpdate) SetNillableRpmLimit(v *int) *GroupUpdate { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *GroupUpdate) AddRpmLimit(v int) *GroupUpdate { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdate) AddAPIKeyIDs(ids ...int64) *GroupUpdate { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -1030,6 +1051,12 @@ func (_u *GroupUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(group.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1875,6 +1902,27 @@ func (_u *GroupUpdateOne) SetNillableMessagesDispatchModelConfig(v *domain.OpenA return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *GroupUpdateOne) SetRpmLimit(v int) *GroupUpdateOne { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *GroupUpdateOne) SetNillableRpmLimit(v *int) *GroupUpdateOne { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *GroupUpdateOne) AddRpmLimit(v int) *GroupUpdateOne { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *GroupUpdateOne) AddAPIKeyIDs(ids ...int64) *GroupUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -2368,6 +2416,12 @@ func (_u *GroupUpdateOne) sqlSave(ctx context.Context) (_node *Group, err error) if value, ok := _u.mutation.MessagesDispatchModelConfig(); ok { _spec.SetField(group.FieldMessagesDispatchModelConfig, field.TypeJSON, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(group.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(group.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go index 38366e95..178ae170 100644 --- a/backend/ent/migrate/schema.go +++ b/backend/ent/migrate/schema.go @@ -654,6 +654,7 @@ var ( {Name: "require_privacy_set", Type: field.TypeBool, Default: false}, {Name: "default_mapped_model", Type: field.TypeString, Size: 100, Default: ""}, {Name: "messages_dispatch_model_config", Type: field.TypeJSON, SchemaType: map[string]string{"postgres": "jsonb"}}, + {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // GroupsTable holds the schema information for the "groups" table. GroupsTable = &schema.Table{ @@ -1447,6 +1448,7 @@ var ( {Name: "balance_notify_threshold", Type: field.TypeFloat64, Nullable: true, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, {Name: "balance_notify_extra_emails", Type: field.TypeString, Default: "[]", SchemaType: map[string]string{"postgres": "text"}}, {Name: "total_recharged", Type: field.TypeFloat64, Default: 0, SchemaType: map[string]string{"postgres": "decimal(20,8)"}}, + {Name: "rpm_limit", Type: field.TypeInt, Default: 0}, } // UsersTable holds the schema information for the "users" table. 
UsersTable = &schema.Table{ diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go index 568b3eb5..d616e4ae 100644 --- a/backend/ent/mutation.go +++ b/backend/ent/mutation.go @@ -14787,6 +14787,8 @@ type GroupMutation struct { require_privacy_set *bool default_mapped_model *string messages_dispatch_model_config *domain.OpenAIMessagesDispatchModelConfig + rpm_limit *int + addrpm_limit *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -16375,6 +16377,62 @@ func (m *GroupMutation) ResetMessagesDispatchModelConfig() { m.messages_dispatch_model_config = nil } +// SetRpmLimit sets the "rpm_limit" field. +func (m *GroupMutation) SetRpmLimit(i int) { + m.rpm_limit = &i + m.addrpm_limit = nil +} + +// RpmLimit returns the value of the "rpm_limit" field in the mutation. +func (m *GroupMutation) RpmLimit() (r int, exists bool) { + v := m.rpm_limit + if v == nil { + return + } + return *v, true +} + +// OldRpmLimit returns the old "rpm_limit" field's value of the Group entity. +// If the Group object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *GroupMutation) OldRpmLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRpmLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err) + } + return oldValue.RpmLimit, nil +} + +// AddRpmLimit adds i to the "rpm_limit" field. +func (m *GroupMutation) AddRpmLimit(i int) { + if m.addrpm_limit != nil { + *m.addrpm_limit += i + } else { + m.addrpm_limit = &i + } +} + +// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation. 
+func (m *GroupMutation) AddedRpmLimit() (r int, exists bool) { + v := m.addrpm_limit + if v == nil { + return + } + return *v, true +} + +// ResetRpmLimit resets all changes to the "rpm_limit" field. +func (m *GroupMutation) ResetRpmLimit() { + m.rpm_limit = nil + m.addrpm_limit = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *GroupMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -16733,7 +16791,7 @@ func (m *GroupMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). func (m *GroupMutation) Fields() []string { - fields := make([]string, 0, 30) + fields := make([]string, 0, 31) if m.created_at != nil { fields = append(fields, group.FieldCreatedAt) } @@ -16824,6 +16882,9 @@ func (m *GroupMutation) Fields() []string { if m.messages_dispatch_model_config != nil { fields = append(fields, group.FieldMessagesDispatchModelConfig) } + if m.rpm_limit != nil { + fields = append(fields, group.FieldRpmLimit) + } return fields } @@ -16892,6 +16953,8 @@ func (m *GroupMutation) Field(name string) (ent.Value, bool) { return m.DefaultMappedModel() case group.FieldMessagesDispatchModelConfig: return m.MessagesDispatchModelConfig() + case group.FieldRpmLimit: + return m.RpmLimit() } return nil, false } @@ -16961,6 +17024,8 @@ func (m *GroupMutation) OldField(ctx context.Context, name string) (ent.Value, e return m.OldDefaultMappedModel(ctx) case group.FieldMessagesDispatchModelConfig: return m.OldMessagesDispatchModelConfig(ctx) + case group.FieldRpmLimit: + return m.OldRpmLimit(ctx) } return nil, fmt.Errorf("unknown Group field %s", name) } @@ -17180,6 +17245,13 @@ func (m *GroupMutation) SetField(name string, value ent.Value) error { } m.SetMessagesDispatchModelConfig(v) return nil + case group.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRpmLimit(v) + return nil } return 
fmt.Errorf("unknown Group field %s", name) } @@ -17221,6 +17293,9 @@ func (m *GroupMutation) AddedFields() []string { if m.addsort_order != nil { fields = append(fields, group.FieldSortOrder) } + if m.addrpm_limit != nil { + fields = append(fields, group.FieldRpmLimit) + } return fields } @@ -17251,6 +17326,8 @@ func (m *GroupMutation) AddedField(name string) (ent.Value, bool) { return m.AddedFallbackGroupIDOnInvalidRequest() case group.FieldSortOrder: return m.AddedSortOrder() + case group.FieldRpmLimit: + return m.AddedRpmLimit() } return nil, false } @@ -17337,6 +17414,13 @@ func (m *GroupMutation) AddField(name string, value ent.Value) error { } m.AddSortOrder(v) return nil + case group.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRpmLimit(v) + return nil } return fmt.Errorf("unknown Group numeric field %s", name) } @@ -17523,6 +17607,9 @@ func (m *GroupMutation) ResetField(name string) error { case group.FieldMessagesDispatchModelConfig: m.ResetMessagesDispatchModelConfig() return nil + case group.FieldRpmLimit: + m.ResetRpmLimit() + return nil } return fmt.Errorf("unknown Group field %s", name) } @@ -37366,6 +37453,8 @@ type UserMutation struct { balance_notify_extra_emails *string total_recharged *float64 addtotal_recharged *float64 + rpm_limit *int + addrpm_limit *int clearedFields map[string]struct{} api_keys map[int64]struct{} removedapi_keys map[int64]struct{} @@ -38457,6 +38546,62 @@ func (m *UserMutation) ResetTotalRecharged() { m.addtotal_recharged = nil } +// SetRpmLimit sets the "rpm_limit" field. +func (m *UserMutation) SetRpmLimit(i int) { + m.rpm_limit = &i + m.addrpm_limit = nil +} + +// RpmLimit returns the value of the "rpm_limit" field in the mutation. +func (m *UserMutation) RpmLimit() (r int, exists bool) { + v := m.rpm_limit + if v == nil { + return + } + return *v, true +} + +// OldRpmLimit returns the old "rpm_limit" field's value of the User entity. 
+// If the User object wasn't provided to the builder, the object is fetched from the database. +// An error is returned if the mutation operation is not UpdateOne, or the database query fails. +func (m *UserMutation) OldRpmLimit(ctx context.Context) (v int, err error) { + if !m.op.Is(OpUpdateOne) { + return v, errors.New("OldRpmLimit is only allowed on UpdateOne operations") + } + if m.id == nil || m.oldValue == nil { + return v, errors.New("OldRpmLimit requires an ID field in the mutation") + } + oldValue, err := m.oldValue(ctx) + if err != nil { + return v, fmt.Errorf("querying old value for OldRpmLimit: %w", err) + } + return oldValue.RpmLimit, nil +} + +// AddRpmLimit adds i to the "rpm_limit" field. +func (m *UserMutation) AddRpmLimit(i int) { + if m.addrpm_limit != nil { + *m.addrpm_limit += i + } else { + m.addrpm_limit = &i + } +} + +// AddedRpmLimit returns the value that was added to the "rpm_limit" field in this mutation. +func (m *UserMutation) AddedRpmLimit() (r int, exists bool) { + v := m.addrpm_limit + if v == nil { + return + } + return *v, true +} + +// ResetRpmLimit resets all changes to the "rpm_limit" field. +func (m *UserMutation) ResetRpmLimit() { + m.rpm_limit = nil + m.addrpm_limit = nil +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by ids. func (m *UserMutation) AddAPIKeyIDs(ids ...int64) { if m.api_keys == nil { @@ -39139,7 +39284,7 @@ func (m *UserMutation) Type() string { // order to get all numeric fields that were incremented/decremented, call // AddedFields(). 
func (m *UserMutation) Fields() []string { - fields := make([]string, 0, 22) + fields := make([]string, 0, 23) if m.created_at != nil { fields = append(fields, user.FieldCreatedAt) } @@ -39206,6 +39351,9 @@ func (m *UserMutation) Fields() []string { if m.total_recharged != nil { fields = append(fields, user.FieldTotalRecharged) } + if m.rpm_limit != nil { + fields = append(fields, user.FieldRpmLimit) + } return fields } @@ -39258,6 +39406,8 @@ func (m *UserMutation) Field(name string) (ent.Value, bool) { return m.BalanceNotifyExtraEmails() case user.FieldTotalRecharged: return m.TotalRecharged() + case user.FieldRpmLimit: + return m.RpmLimit() } return nil, false } @@ -39311,6 +39461,8 @@ func (m *UserMutation) OldField(ctx context.Context, name string) (ent.Value, er return m.OldBalanceNotifyExtraEmails(ctx) case user.FieldTotalRecharged: return m.OldTotalRecharged(ctx) + case user.FieldRpmLimit: + return m.OldRpmLimit(ctx) } return nil, fmt.Errorf("unknown User field %s", name) } @@ -39474,6 +39626,13 @@ func (m *UserMutation) SetField(name string, value ent.Value) error { } m.SetTotalRecharged(v) return nil + case user.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.SetRpmLimit(v) + return nil } return fmt.Errorf("unknown User field %s", name) } @@ -39494,6 +39653,9 @@ func (m *UserMutation) AddedFields() []string { if m.addtotal_recharged != nil { fields = append(fields, user.FieldTotalRecharged) } + if m.addrpm_limit != nil { + fields = append(fields, user.FieldRpmLimit) + } return fields } @@ -39510,6 +39672,8 @@ func (m *UserMutation) AddedField(name string) (ent.Value, bool) { return m.AddedBalanceNotifyThreshold() case user.FieldTotalRecharged: return m.AddedTotalRecharged() + case user.FieldRpmLimit: + return m.AddedRpmLimit() } return nil, false } @@ -39547,6 +39711,13 @@ func (m *UserMutation) AddField(name string, value ent.Value) error { } m.AddTotalRecharged(v) return nil + 
case user.FieldRpmLimit: + v, ok := value.(int) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddRpmLimit(v) + return nil } return fmt.Errorf("unknown User numeric field %s", name) } @@ -39679,6 +39850,9 @@ func (m *UserMutation) ResetField(name string) error { case user.FieldTotalRecharged: m.ResetTotalRecharged() return nil + case user.FieldRpmLimit: + m.ResetRpmLimit() + return nil } return fmt.Errorf("unknown User field %s", name) } diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index aaa939c5..6b344a55 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -845,6 +845,10 @@ func init() { groupDescMessagesDispatchModelConfig := groupFields[26].Descriptor() // group.DefaultMessagesDispatchModelConfig holds the default value on creation for the messages_dispatch_model_config field. group.DefaultMessagesDispatchModelConfig = groupDescMessagesDispatchModelConfig.Default.(domain.OpenAIMessagesDispatchModelConfig) + // groupDescRpmLimit is the schema descriptor for rpm_limit field. + groupDescRpmLimit := groupFields[27].Descriptor() + // group.DefaultRpmLimit holds the default value on creation for the rpm_limit field. + group.DefaultRpmLimit = groupDescRpmLimit.Default.(int) idempotencyrecordMixin := schema.IdempotencyRecord{}.Mixin() idempotencyrecordMixinFields0 := idempotencyrecordMixin[0].Fields() _ = idempotencyrecordMixinFields0 @@ -1825,6 +1829,10 @@ func init() { userDescTotalRecharged := userFields[18].Descriptor() // user.DefaultTotalRecharged holds the default value on creation for the total_recharged field. user.DefaultTotalRecharged = userDescTotalRecharged.Default.(float64) + // userDescRpmLimit is the schema descriptor for rpm_limit field. + userDescRpmLimit := userFields[19].Descriptor() + // user.DefaultRpmLimit holds the default value on creation for the rpm_limit field. 
+ user.DefaultRpmLimit = userDescRpmLimit.Default.(int) userallowedgroupFields := schema.UserAllowedGroup{}.Fields() _ = userallowedgroupFields // userallowedgroupDescCreatedAt is the schema descriptor for created_at field. diff --git a/backend/ent/schema/group.go b/backend/ent/schema/group.go index d78a6898..11f38d66 100644 --- a/backend/ent/schema/group.go +++ b/backend/ent/schema/group.go @@ -145,6 +145,11 @@ func (Group) Fields() []ent.Field { Default(domain.OpenAIMessagesDispatchModelConfig{}). SchemaType(map[string]string{dialect.Postgres: "jsonb"}). Comment("OpenAI Messages 调度模型配置:按 Claude 系列/精确模型映射到目标 GPT 模型"), + + // 分组级每分钟请求数上限(0 = 不限制)。设置后优先于用户级兜底生效。 + field.Int("rpm_limit"). + Default(0). + Comment("分组 RPM 上限,0 表示不限制;设置后接管该分组用户的限流"), } } diff --git a/backend/ent/schema/user.go b/backend/ent/schema/user.go index c0f0bdc1..83da5c32 100644 --- a/backend/ent/schema/user.go +++ b/backend/ent/schema/user.go @@ -108,6 +108,10 @@ func (User) Fields() []ent.Field { field.Float("total_recharged"). SchemaType(map[string]string{dialect.Postgres: "decimal(20,8)"}). Default(0), + + // 用户级每分钟请求数上限(0 = 不限制)。仅当所在分组未设置 rpm_limit 时作为兜底生效。 + field.Int("rpm_limit"). + Default(0), } } diff --git a/backend/ent/user.go b/backend/ent/user.go index 66f33623..06670444 100644 --- a/backend/ent/user.go +++ b/backend/ent/user.go @@ -61,6 +61,8 @@ type User struct { BalanceNotifyExtraEmails string `json:"balance_notify_extra_emails,omitempty"` // TotalRecharged holds the value of the "total_recharged" field. TotalRecharged float64 `json:"total_recharged,omitempty"` + // RpmLimit holds the value of the "rpm_limit" field. + RpmLimit int `json:"rpm_limit,omitempty"` // Edges holds the relations/edges for other nodes in the graph. // The values are being populated by the UserQuery when eager-loading is set. 
Edges UserEdges `json:"edges"` @@ -226,7 +228,7 @@ func (*User) scanValues(columns []string) ([]any, error) { values[i] = new(sql.NullBool) case user.FieldBalance, user.FieldBalanceNotifyThreshold, user.FieldTotalRecharged: values[i] = new(sql.NullFloat64) - case user.FieldID, user.FieldConcurrency: + case user.FieldID, user.FieldConcurrency, user.FieldRpmLimit: values[i] = new(sql.NullInt64) case user.FieldEmail, user.FieldPasswordHash, user.FieldRole, user.FieldStatus, user.FieldUsername, user.FieldNotes, user.FieldTotpSecretEncrypted, user.FieldSignupSource, user.FieldBalanceNotifyThresholdType, user.FieldBalanceNotifyExtraEmails: values[i] = new(sql.NullString) @@ -391,6 +393,12 @@ func (_m *User) assignValues(columns []string, values []any) error { } else if value.Valid { _m.TotalRecharged = value.Float64 } + case user.FieldRpmLimit: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field rpm_limit", values[i]) + } else if value.Valid { + _m.RpmLimit = int(value.Int64) + } default: _m.selectValues.Set(columns[i], values[i]) } @@ -569,6 +577,9 @@ func (_m *User) String() string { builder.WriteString(", ") builder.WriteString("total_recharged=") builder.WriteString(fmt.Sprintf("%v", _m.TotalRecharged)) + builder.WriteString(", ") + builder.WriteString("rpm_limit=") + builder.WriteString(fmt.Sprintf("%v", _m.RpmLimit)) builder.WriteByte(')') return builder.String() } diff --git a/backend/ent/user/user.go b/backend/ent/user/user.go index 567e3b14..e11a8a32 100644 --- a/backend/ent/user/user.go +++ b/backend/ent/user/user.go @@ -59,6 +59,8 @@ const ( FieldBalanceNotifyExtraEmails = "balance_notify_extra_emails" // FieldTotalRecharged holds the string denoting the total_recharged field in the database. FieldTotalRecharged = "total_recharged" + // FieldRpmLimit holds the string denoting the rpm_limit field in the database. 
+ FieldRpmLimit = "rpm_limit" // EdgeAPIKeys holds the string denoting the api_keys edge name in mutations. EdgeAPIKeys = "api_keys" // EdgeRedeemCodes holds the string denoting the redeem_codes edge name in mutations. @@ -203,6 +205,7 @@ var Columns = []string{ FieldBalanceNotifyThreshold, FieldBalanceNotifyExtraEmails, FieldTotalRecharged, + FieldRpmLimit, } var ( @@ -271,6 +274,8 @@ var ( DefaultBalanceNotifyExtraEmails string // DefaultTotalRecharged holds the default value on creation for the "total_recharged" field. DefaultTotalRecharged float64 + // DefaultRpmLimit holds the default value on creation for the "rpm_limit" field. + DefaultRpmLimit int ) // OrderOption defines the ordering options for the User queries. @@ -391,6 +396,11 @@ func ByTotalRecharged(opts ...sql.OrderTermOption) OrderOption { return sql.OrderByField(FieldTotalRecharged, opts...).ToFunc() } +// ByRpmLimit orders the results by the rpm_limit field. +func ByRpmLimit(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldRpmLimit, opts...).ToFunc() +} + // ByAPIKeysCount orders the results by api_keys count. func ByAPIKeysCount(opts ...sql.OrderTermOption) OrderOption { return func(s *sql.Selector) { diff --git a/backend/ent/user/where.go b/backend/ent/user/where.go index cbcfcc26..05d3b35b 100644 --- a/backend/ent/user/where.go +++ b/backend/ent/user/where.go @@ -165,6 +165,11 @@ func TotalRecharged(v float64) predicate.User { return predicate.User(sql.FieldEQ(FieldTotalRecharged, v)) } +// RpmLimit applies equality check predicate on the "rpm_limit" field. It's identical to RpmLimitEQ. +func RpmLimit(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldRpmLimit, v)) +} + // CreatedAtEQ applies the EQ predicate on the "created_at" field. 
func CreatedAtEQ(v time.Time) predicate.User { return predicate.User(sql.FieldEQ(FieldCreatedAt, v)) @@ -1295,6 +1300,46 @@ func TotalRechargedLTE(v float64) predicate.User { return predicate.User(sql.FieldLTE(FieldTotalRecharged, v)) } +// RpmLimitEQ applies the EQ predicate on the "rpm_limit" field. +func RpmLimitEQ(v int) predicate.User { + return predicate.User(sql.FieldEQ(FieldRpmLimit, v)) +} + +// RpmLimitNEQ applies the NEQ predicate on the "rpm_limit" field. +func RpmLimitNEQ(v int) predicate.User { + return predicate.User(sql.FieldNEQ(FieldRpmLimit, v)) +} + +// RpmLimitIn applies the In predicate on the "rpm_limit" field. +func RpmLimitIn(vs ...int) predicate.User { + return predicate.User(sql.FieldIn(FieldRpmLimit, vs...)) +} + +// RpmLimitNotIn applies the NotIn predicate on the "rpm_limit" field. +func RpmLimitNotIn(vs ...int) predicate.User { + return predicate.User(sql.FieldNotIn(FieldRpmLimit, vs...)) +} + +// RpmLimitGT applies the GT predicate on the "rpm_limit" field. +func RpmLimitGT(v int) predicate.User { + return predicate.User(sql.FieldGT(FieldRpmLimit, v)) +} + +// RpmLimitGTE applies the GTE predicate on the "rpm_limit" field. +func RpmLimitGTE(v int) predicate.User { + return predicate.User(sql.FieldGTE(FieldRpmLimit, v)) +} + +// RpmLimitLT applies the LT predicate on the "rpm_limit" field. +func RpmLimitLT(v int) predicate.User { + return predicate.User(sql.FieldLT(FieldRpmLimit, v)) +} + +// RpmLimitLTE applies the LTE predicate on the "rpm_limit" field. +func RpmLimitLTE(v int) predicate.User { + return predicate.User(sql.FieldLTE(FieldRpmLimit, v)) +} + // HasAPIKeys applies the HasEdge predicate on the "api_keys" edge. 
func HasAPIKeys() predicate.User { return predicate.User(func(s *sql.Selector) { diff --git a/backend/ent/user_create.go b/backend/ent/user_create.go index db95e813..b4161128 100644 --- a/backend/ent/user_create.go +++ b/backend/ent/user_create.go @@ -325,6 +325,20 @@ func (_c *UserCreate) SetNillableTotalRecharged(v *float64) *UserCreate { return _c } +// SetRpmLimit sets the "rpm_limit" field. +func (_c *UserCreate) SetRpmLimit(v int) *UserCreate { + _c.mutation.SetRpmLimit(v) + return _c +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_c *UserCreate) SetNillableRpmLimit(v *int) *UserCreate { + if v != nil { + _c.SetRpmLimit(*v) + } + return _c +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_c *UserCreate) AddAPIKeyIDs(ids ...int64) *UserCreate { _c.mutation.AddAPIKeyIDs(ids...) @@ -604,6 +618,10 @@ func (_c *UserCreate) defaults() error { v := user.DefaultTotalRecharged _c.mutation.SetTotalRecharged(v) } + if _, ok := _c.mutation.RpmLimit(); !ok { + v := user.DefaultRpmLimit + _c.mutation.SetRpmLimit(v) + } return nil } @@ -687,6 +705,9 @@ func (_c *UserCreate) check() error { if _, ok := _c.mutation.TotalRecharged(); !ok { return &ValidationError{Name: "total_recharged", err: errors.New(`ent: missing required field "User.total_recharged"`)} } + if _, ok := _c.mutation.RpmLimit(); !ok { + return &ValidationError{Name: "rpm_limit", err: errors.New(`ent: missing required field "User.rpm_limit"`)} + } return nil } @@ -802,6 +823,10 @@ func (_c *UserCreate) createSpec() (*User, *sqlgraph.CreateSpec) { _spec.SetField(user.FieldTotalRecharged, field.TypeFloat64, value) _node.TotalRecharged = value } + if value, ok := _c.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + _node.RpmLimit = value + } if nodes := _c.mutation.APIKeysIDs(); len(nodes) > 0 { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1362,6 +1387,24 @@ func (u *UserUpsert) 
AddTotalRecharged(v float64) *UserUpsert { return u } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsert) SetRpmLimit(v int) *UserUpsert { + u.Set(user.FieldRpmLimit, v) + return u +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsert) UpdateRpmLimit() *UserUpsert { + u.SetExcluded(user.FieldRpmLimit) + return u +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsert) AddRpmLimit(v int) *UserUpsert { + u.Add(user.FieldRpmLimit, v) + return u +} + // UpdateNewValues updates the mutable fields using the new values that were set on create. // Using this option is equivalent to using: // @@ -1771,6 +1814,27 @@ func (u *UserUpsertOne) UpdateTotalRecharged() *UserUpsertOne { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsertOne) SetRpmLimit(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. +func (u *UserUpsertOne) AddRpmLimit(v int) *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsertOne) UpdateRpmLimit() *UserUpsertOne { + return u.Update(func(s *UserUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *UserUpsertOne) Exec(ctx context.Context) error { if len(u.create.conflict) == 0 { @@ -2346,6 +2410,27 @@ func (u *UserUpsertBulk) UpdateTotalRecharged() *UserUpsertBulk { }) } +// SetRpmLimit sets the "rpm_limit" field. +func (u *UserUpsertBulk) SetRpmLimit(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.SetRpmLimit(v) + }) +} + +// AddRpmLimit adds v to the "rpm_limit" field. 
+func (u *UserUpsertBulk) AddRpmLimit(v int) *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.AddRpmLimit(v) + }) +} + +// UpdateRpmLimit sets the "rpm_limit" field to the value that was provided on create. +func (u *UserUpsertBulk) UpdateRpmLimit() *UserUpsertBulk { + return u.Update(func(s *UserUpsert) { + s.UpdateRpmLimit() + }) +} + // Exec executes the query. func (u *UserUpsertBulk) Exec(ctx context.Context) error { if u.create.err != nil { diff --git a/backend/ent/user_update.go b/backend/ent/user_update.go index 677eeb6b..f1d759ce 100644 --- a/backend/ent/user_update.go +++ b/backend/ent/user_update.go @@ -389,6 +389,27 @@ func (_u *UserUpdate) AddTotalRecharged(v float64) *UserUpdate { return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *UserUpdate) SetRpmLimit(v int) *UserUpdate { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *UserUpdate) SetNillableRpmLimit(v *int) *UserUpdate { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *UserUpdate) AddRpmLimit(v int) *UserUpdate { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdate) AddAPIKeyIDs(ids ...int64) *UserUpdate { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -1008,6 +1029,12 @@ func (_u *UserUpdate) sqlSave(ctx context.Context) (_node int, err error) { if value, ok := _u.mutation.AddedTotalRecharged(); ok { _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(user.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, @@ -1930,6 +1957,27 @@ func (_u *UserUpdateOne) AddTotalRecharged(v float64) *UserUpdateOne { return _u } +// SetRpmLimit sets the "rpm_limit" field. +func (_u *UserUpdateOne) SetRpmLimit(v int) *UserUpdateOne { + _u.mutation.ResetRpmLimit() + _u.mutation.SetRpmLimit(v) + return _u +} + +// SetNillableRpmLimit sets the "rpm_limit" field if the given value is not nil. +func (_u *UserUpdateOne) SetNillableRpmLimit(v *int) *UserUpdateOne { + if v != nil { + _u.SetRpmLimit(*v) + } + return _u +} + +// AddRpmLimit adds value to the "rpm_limit" field. +func (_u *UserUpdateOne) AddRpmLimit(v int) *UserUpdateOne { + _u.mutation.AddRpmLimit(v) + return _u +} + // AddAPIKeyIDs adds the "api_keys" edge to the APIKey entity by IDs. func (_u *UserUpdateOne) AddAPIKeyIDs(ids ...int64) *UserUpdateOne { _u.mutation.AddAPIKeyIDs(ids...) 
@@ -2579,6 +2627,12 @@ func (_u *UserUpdateOne) sqlSave(ctx context.Context) (_node *User, err error) { if value, ok := _u.mutation.AddedTotalRecharged(); ok { _spec.AddField(user.FieldTotalRecharged, field.TypeFloat64, value) } + if value, ok := _u.mutation.RpmLimit(); ok { + _spec.SetField(user.FieldRpmLimit, field.TypeInt, value) + } + if value, ok := _u.mutation.AddedRpmLimit(); ok { + _spec.AddField(user.FieldRpmLimit, field.TypeInt, value) + } if _u.mutation.APIKeysCleared() { edge := &sqlgraph.EdgeSpec{ Rel: sqlgraph.O2M, diff --git a/backend/go.mod b/backend/go.mod index 627851bf..982bf91b 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -104,6 +104,7 @@ require ( github.com/goccy/go-json v0.10.2 // indirect github.com/google/go-cmp v0.7.0 // indirect github.com/google/go-querystring v1.1.0 // indirect + github.com/google/subcommands v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 // indirect github.com/hashicorp/hcl v1.0.0 // indirect github.com/hashicorp/hcl/v2 v2.18.1 // indirect diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go index 3a395342..2fe29fa3 100644 --- a/backend/internal/handler/admin/admin_service_stub_test.go +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -183,6 +183,17 @@ func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, return map[string]any{"user_id": userID}, nil } +func (s *stubAdminService) GetUserRPMStatus(ctx context.Context, userID int64) (*service.UserRPMStatus, error) { + user, err := s.GetUser(ctx, userID) + if err != nil { + return nil, err + } + return &service.UserRPMStatus{ + UserRPMUsed: 0, + UserRPMLimit: user.RPMLimit, + }, nil +} + func (s *stubAdminService) BindUserAuthIdentity(ctx context.Context, userID int64, input service.AdminBindAuthIdentityInput) (*service.AdminBoundAuthIdentity, error) { s.boundAuthIdentityFor = userID copied := input @@ -276,6 +287,14 @@ func 
(s *stubAdminService) BatchSetGroupRateMultipliers(_ context.Context, _ int return nil } +func (s *stubAdminService) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + return nil +} + +func (s *stubAdminService) BatchSetGroupRPMOverrides(_ context.Context, _ int64, _ []service.GroupRPMOverrideInput) error { + return nil +} + func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string, groupID int64, privacyMode string, sortBy, sortOrder string) ([]service.Account, int64, error) { s.lastListAccounts.platform = platform s.lastListAccounts.accountType = accountType diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index cb2bd201..65e5ec78 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -110,6 +110,8 @@ type CreateGroupRequest struct { RequirePrivacySet bool `json:"require_privacy_set"` DefaultMappedModel string `json:"default_mapped_model"` MessagesDispatchModelConfig service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + // 分组 RPM 上限(0 = 不限制) + RPMLimit int `json:"rpm_limit"` // 从指定分组复制账号(创建后自动绑定) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -145,6 +147,8 @@ type UpdateGroupRequest struct { RequirePrivacySet *bool `json:"require_privacy_set"` DefaultMappedModel *string `json:"default_mapped_model"` MessagesDispatchModelConfig *service.OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config"` + // 分组 RPM 上限(0 = 不限制);nil 表示未提供不改动 + RPMLimit *int `json:"rpm_limit"` // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 `json:"copy_accounts_from_group_ids"` } @@ -262,6 +266,7 @@ func (h *GroupHandler) Create(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + 
RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -313,6 +318,7 @@ func (h *GroupHandler) Update(c *gin.Context) { RequirePrivacySet: req.RequirePrivacySet, DefaultMappedModel: req.DefaultMappedModel, MessagesDispatchModelConfig: req.MessagesDispatchModelConfig, + RPMLimit: req.RPMLimit, CopyAccountsFromGroupIDs: req.CopyAccountsFromGroupIDs, }) if err != nil { @@ -477,6 +483,51 @@ func (h *GroupHandler) BatchSetGroupRateMultipliers(c *gin.Context) { response.Success(c, gin.H{"message": "Rate multipliers updated successfully"}) } +// BatchSetGroupRPMOverridesRequest represents batch set rpm_override request +type BatchSetGroupRPMOverridesRequest struct { + Entries []service.GroupRPMOverrideInput `json:"entries" binding:"required"` +} + +// BatchSetGroupRPMOverrides handles batch setting rpm_override for users in a group +// PUT /api/v1/admin/groups/:id/rpm-overrides +func (h *GroupHandler) BatchSetGroupRPMOverrides(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + var req BatchSetGroupRPMOverridesRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + + if err := h.adminService.BatchSetGroupRPMOverrides(c.Request.Context(), groupID, req.Entries); err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, gin.H{"message": "RPM overrides updated successfully"}) +} + +// ClearGroupRPMOverrides handles clearing all rpm_override for a group +// DELETE /api/v1/admin/groups/:id/rpm-overrides +func (h *GroupHandler) ClearGroupRPMOverrides(c *gin.Context) { + groupID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid group ID") + return + } + + if err := h.adminService.ClearGroupRPMOverrides(c.Request.Context(), groupID); err != nil { + response.ErrorFrom(c, err) + return + } + 
+ response.Success(c, gin.H{"message": "RPM overrides cleared successfully"}) +} + // UpdateSortOrderRequest represents the request to update group sort orders type UpdateSortOrderRequest struct { Updates []struct { diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 09dc8251..4277f0f1 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -185,6 +185,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(settings.CustomEndpoints), DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, + DefaultUserRPMLimit: settings.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: settings.EnableModelFallback, FallbackModelAnthropic: settings.FallbackModelAnthropic, @@ -337,6 +338,7 @@ type UpdateSettingsRequest struct { // 默认配置 DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []dto.DefaultSubscriptionSetting `json:"default_subscriptions"` AuthSourceDefaultEmailBalance *float64 `json:"auth_source_default_email_balance"` AuthSourceDefaultEmailConcurrency *int `json:"auth_source_default_email_concurrency"` @@ -1117,6 +1119,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: customEndpointsJSON, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, + DefaultUserRPMLimit: req.DefaultUserRPMLimit, DefaultSubscriptions: defaultSubscriptions, EnableModelFallback: req.EnableModelFallback, FallbackModelAnthropic: req.FallbackModelAnthropic, @@ -1430,6 +1433,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { CustomEndpoints: dto.ParseCustomEndpoints(updatedSettings.CustomEndpoints), DefaultConcurrency: updatedSettings.DefaultConcurrency, 
DefaultBalance: updatedSettings.DefaultBalance, + DefaultUserRPMLimit: updatedSettings.DefaultUserRPMLimit, DefaultSubscriptions: updatedDefaultSubscriptions, EnableModelFallback: updatedSettings.EnableModelFallback, FallbackModelAnthropic: updatedSettings.FallbackModelAnthropic, diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index b2ed9d18..3d80107f 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -40,6 +40,7 @@ type CreateUserRequest struct { Notes string `json:"notes"` Balance float64 `json:"balance"` Concurrency int `json:"concurrency"` + RPMLimit int `json:"rpm_limit"` AllowedGroups []int64 `json:"allowed_groups"` } @@ -52,6 +53,7 @@ type UpdateUserRequest struct { Notes *string `json:"notes"` Balance *float64 `json:"balance"` Concurrency *int `json:"concurrency"` + RPMLimit *int `json:"rpm_limit"` Status string `json:"status" binding:"omitempty,oneof=active disabled"` AllowedGroups *[]int64 `json:"allowed_groups"` // GroupRates 用户专属分组倍率配置 @@ -243,6 +245,7 @@ func (h *UserHandler) Create(c *gin.Context) { Notes: req.Notes, Balance: req.Balance, Concurrency: req.Concurrency, + RPMLimit: req.RPMLimit, AllowedGroups: req.AllowedGroups, }) if err != nil { @@ -276,6 +279,7 @@ func (h *UserHandler) Update(c *gin.Context) { Notes: req.Notes, Balance: req.Balance, Concurrency: req.Concurrency, + RPMLimit: req.RPMLimit, Status: req.Status, AllowedGroups: req.AllowedGroups, GroupRates: req.GroupRates, @@ -455,3 +459,21 @@ func (h *UserHandler) ReplaceGroup(c *gin.Context) { "migrated_keys": result.MigratedKeys, }) } + +// GetUserRPMStatus 返回指定用户当前分钟的 RPM 用量 +// GET /api/v1/admin/users/:id/rpm-status +func (h *UserHandler) GetUserRPMStatus(c *gin.Context) { + userID, err := strconv.ParseInt(c.Param("id"), 10, 64) + if err != nil { + response.BadRequest(c, "Invalid user ID") + return + } + + status, err := 
h.adminService.GetUserRPMStatus(c.Request.Context(), userID) + if err != nil { + response.ErrorFrom(c, err) + return + } + + response.Success(c, status) +} diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index 9780ff79..f7503c2e 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -29,6 +29,7 @@ func UserFromServiceShallow(u *service.User) *User { BalanceNotifyThreshold: u.BalanceNotifyThreshold, BalanceNotifyExtraEmails: NotifyEmailEntriesFromService(u.BalanceNotifyExtraEmails), TotalRecharged: u.TotalRecharged, + RPMLimit: u.RPMLimit, } } @@ -184,6 +185,7 @@ func groupFromServiceBase(g *service.Group) Group { AllowMessagesDispatch: g.AllowMessagesDispatch, RequireOAuthOnly: g.RequireOAuthOnly, RequirePrivacySet: g.RequirePrivacySet, + RPMLimit: g.RPMLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 193dc940..2affbc46 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -108,6 +108,7 @@ type SystemSettings struct { DefaultConcurrency int `json:"default_concurrency"` DefaultBalance float64 `json:"default_balance"` + DefaultUserRPMLimit int `json:"default_user_rpm_limit"` DefaultSubscriptions []DefaultSubscriptionSetting `json:"default_subscriptions"` // Model fallback configuration diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go index c0bce40b..5cc2f8e4 100644 --- a/backend/internal/handler/dto/types.go +++ b/backend/internal/handler/dto/types.go @@ -26,6 +26,9 @@ type User struct { BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails"` TotalRecharged float64 `json:"total_recharged"` + // RPMLimit 用户级每分钟请求数上限(0 = 不限制),仅在所用分组未设置 rpm_limit 时作为兜底生效。 + RPMLimit int `json:"rpm_limit"` + APIKeys []APIKey `json:"api_keys,omitempty"` Subscriptions 
[]UserSubscription `json:"subscriptions,omitempty"` } @@ -108,6 +111,9 @@ type Group struct { RequireOAuthOnly bool `json:"require_oauth_only"` RequirePrivacySet bool `json:"require_privacy_set"` + // RPMLimit 分组级每分钟请求数上限(0 = 不限制),设置后覆盖用户级 rpm_limit。 + RPMLimit int `json:"rpm_limit"` + CreatedAt time.Time `json:"created_at"` UpdatedAt time.Time `json:"updated_at"` } diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 24ef9c3d..ef532559 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -243,7 +243,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { // 2. 【新增】Wait后二次检查余额/订阅 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -758,7 +761,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } fallbackAPIKey := cloneAPIKeyWithGroup(apiKey, fallbackGroup) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), fallbackAPIKey.User, fallbackAPIKey, fallbackGroup, nil); err != nil { - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -1464,7 +1470,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) { // 校验 billing eligibility(订阅/余额) // 【注意】不计算并发,但需要校验订阅/余额 if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, 
subscription); err != nil { - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.errorResponse(c, status, code, message) return } @@ -1707,25 +1716,32 @@ func sendMockInterceptResponse(c *gin.Context, model string, interceptType Inter c.JSON(http.StatusOK, response) } -func billingErrorDetails(err error) (status int, code, message string) { +func billingErrorDetails(err error) (status int, code, message string, retryAfter int) { if errors.Is(err, service.ErrBillingServiceUnavailable) { msg := pkgerrors.Message(err) if msg == "" { msg = "Billing service temporarily unavailable. Please retry later." } - return http.StatusServiceUnavailable, "billing_service_error", msg + return http.StatusServiceUnavailable, "billing_service_error", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit5hExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit1dExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 } if errors.Is(err, service.ErrAPIKeyRateLimit7dExceeded) { msg := pkgerrors.Message(err) - return http.StatusTooManyRequests, "rate_limit_exceeded", msg + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, 0 + } + // 用户/分组 RPM 超限统一映射为 HTTP 429;保留与其它 rate_limit 一致的错误码便于客户端分类。 + // 返回 Retry-After 秒数(当前分钟剩余秒数),让 SDK 自动退避。 + if errors.Is(err, service.ErrGroupRPMExceeded) || errors.Is(err, service.ErrUserRPMExceeded) { + msg := pkgerrors.Message(err) + retrySeconds := 60 - int(time.Now().Unix()%60) + return http.StatusTooManyRequests, "rate_limit_exceeded", msg, retrySeconds } msg := pkgerrors.Message(err) if msg == "" { @@ -1735,7 +1751,7 @@ func 
billingErrorDetails(err error) (status int, code, message string) { ).Warn("gateway.billing_error_missing_message") msg = "Billing error" } - return http.StatusForbidden, "billing_error", msg + return http.StatusForbidden, "billing_error", msg, 0 } func (h *GatewayHandler) metadataBridgeEnabled() bool { diff --git a/backend/internal/handler/gateway_handler_billing_error_test.go b/backend/internal/handler/gateway_handler_billing_error_test.go new file mode 100644 index 00000000..e8a88802 --- /dev/null +++ b/backend/internal/handler/gateway_handler_billing_error_test.go @@ -0,0 +1,54 @@ +package handler + +import ( + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestBillingErrorDetails_MapsGroupRPMExceededToTooManyRequests(t *testing.T) { + status, code, msg, retryAfter := billingErrorDetails(service.ErrGroupRPMExceeded) + require.Equal(t, http.StatusTooManyRequests, status) + require.Equal(t, "rate_limit_exceeded", code) + require.NotEmpty(t, msg) + require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After") + require.LessOrEqual(t, retryAfter, 60) +} + +func TestBillingErrorDetails_MapsUserRPMExceededToTooManyRequests(t *testing.T) { + status, code, msg, retryAfter := billingErrorDetails(service.ErrUserRPMExceeded) + require.Equal(t, http.StatusTooManyRequests, status) + require.Equal(t, "rate_limit_exceeded", code) + require.NotEmpty(t, msg) + require.Greater(t, retryAfter, 0, "RPM exceeded should return positive Retry-After") + require.LessOrEqual(t, retryAfter, 60) +} + +func TestBillingErrorDetails_APIKeyRateLimitStillMaps(t *testing.T) { + // 回归保护:加 RPM 分支后不应影响已有 APIKey rate limit 的映射。 + for _, err := range []error{ + service.ErrAPIKeyRateLimit5hExceeded, + service.ErrAPIKeyRateLimit1dExceeded, + service.ErrAPIKeyRateLimit7dExceeded, + } { + status, code, _, _ := billingErrorDetails(err) + require.Equal(t, http.StatusTooManyRequests, status, "status for %v", 
err) + require.Equal(t, "rate_limit_exceeded", code) + } +} + +func TestBillingErrorDetails_BillingServiceUnavailableMapsTo503(t *testing.T) { + status, code, _, retryAfter := billingErrorDetails(service.ErrBillingServiceUnavailable) + require.Equal(t, http.StatusServiceUnavailable, status) + require.Equal(t, "billing_service_error", code) + require.Equal(t, 0, retryAfter, "non-RPM errors should not set Retry-After") +} + +func TestBillingErrorDetails_UnknownErrorFallsBackTo403(t *testing.T) { + status, code, msg, _ := billingErrorDetails(service.ErrInsufficientBalance) + require.Equal(t, http.StatusForbidden, status) + require.Equal(t, "billing_error", code) + require.NotEmpty(t, msg) +} diff --git a/backend/internal/handler/gateway_handler_chat_completions.go b/backend/internal/handler/gateway_handler_chat_completions.go index be267332..4290e54b 100644 --- a/backend/internal/handler/gateway_handler_chat_completions.go +++ b/backend/internal/handler/gateway_handler_chat_completions.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -136,7 +137,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) { // 2. 
Re-check billing if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.cc.billing_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.chatCompletionsErrorResponse(c, status, code, message) return } diff --git a/backend/internal/handler/gateway_handler_responses.go b/backend/internal/handler/gateway_handler_responses.go index e908eb9e..683cf2b7 100644 --- a/backend/internal/handler/gateway_handler_responses.go +++ b/backend/internal/handler/gateway_handler_responses.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -141,7 +142,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) { // 2. Re-check billing if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gateway.responses.billing_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.responsesErrorResponse(c, status, code, message) return } diff --git a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go index 1fdc46ba..71030140 100644 --- a/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go +++ b/backend/internal/handler/gateway_handler_warmup_intercept_unit_test.go @@ -173,7 +173,7 @@ func newTestGatewayHandler(t *testing.T, group *service.Group, accounts []*servi // RunModeSimple:跳过计费检查,避免引入 repo/cache 依赖。 cfg := &config.Config{RunMode: config.RunModeSimple} - 
billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, cfg) + billingCacheSvc := service.NewBillingCacheService(nil, nil, nil, nil, nil, nil, cfg) concurrencySvc := service.NewConcurrencyService(&fakeConcurrencyCache{}) concurrencyHelper := NewConcurrencyHelper(concurrencySvc, SSEPingFormatClaude, 0) diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go index d200c17c..2a34e3f0 100644 --- a/backend/internal/handler/gemini_v1beta_handler.go +++ b/backend/internal/handler/gemini_v1beta_handler.go @@ -9,6 +9,7 @@ import ( "errors" "net/http" "regexp" + "strconv" "strings" "github.com/Wei-Shaw/sub2api/internal/domain" @@ -241,7 +242,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) { // 2) billing eligibility check (after wait) if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("gemini.billing_eligibility_check_failed", zap.Error(err)) - status, _, message := billingErrorDetails(err) + status, _, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } googleError(c, status, message) return } diff --git a/backend/internal/handler/openai_chat_completions.go b/backend/internal/handler/openai_chat_completions.go index 991cbb91..3c4e6251 100644 --- a/backend/internal/handler/openai_chat_completions.go +++ b/backend/internal/handler/openai_chat_completions.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + "strconv" "time" pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil" @@ -101,7 +102,10 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err)) - status, 
code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index 43999a01..1c975573 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -228,7 +228,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // 2. Re-check billing eligibility after wait if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } @@ -594,7 +597,10 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai_messages.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.anthropicStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/handler/openai_images.go b/backend/internal/handler/openai_images.go index 8dbf8935..403b41ef 100644 --- a/backend/internal/handler/openai_images.go +++ b/backend/internal/handler/openai_images.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net/http" + 
"strconv" "strings" "time" @@ -108,7 +109,10 @@ func (h *OpenAIGatewayHandler) Images(c *gin.Context) { if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil { reqLog.Info("openai.images.billing_eligibility_check_failed", zap.Error(err)) - status, code, message := billingErrorDetails(err) + status, code, message, retryAfter := billingErrorDetails(err) + if retryAfter > 0 { + c.Header("Retry-After", strconv.Itoa(retryAfter)) + } h.handleStreamingAwareError(c, status, code, message, streamStarted) return } diff --git a/backend/internal/repository/api_key_repo.go b/backend/internal/repository/api_key_repo.go index 36d80309..3a527405 100644 --- a/backend/internal/repository/api_key_repo.go +++ b/backend/internal/repository/api_key_repo.go @@ -152,6 +152,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se user.FieldSignupSource, user.FieldLastLoginAt, user.FieldLastActiveAt, + user.FieldRpmLimit, ) }). WithGroup(func(q *dbent.GroupQuery) { @@ -178,6 +179,7 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se group.FieldAllowMessagesDispatch, group.FieldDefaultMappedModel, group.FieldMessagesDispatchModelConfig, + group.FieldRpmLimit, ) }). 
Only(ctx) @@ -669,6 +671,7 @@ func userEntityToService(u *dbent.User) *service.User { BalanceNotifyThresholdType: u.BalanceNotifyThresholdType, BalanceNotifyThreshold: u.BalanceNotifyThreshold, TotalRecharged: u.TotalRecharged, + RPMLimit: u.RpmLimit, CreatedAt: u.CreatedAt, UpdatedAt: u.UpdatedAt, } @@ -713,6 +716,7 @@ func groupEntityToService(g *dbent.Group) *service.Group { RequirePrivacySet: g.RequirePrivacySet, DefaultMappedModel: g.DefaultMappedModel, MessagesDispatchModelConfig: g.MessagesDispatchModelConfig, + RPMLimit: g.RpmLimit, CreatedAt: g.CreatedAt, UpdatedAt: g.UpdatedAt, } diff --git a/backend/internal/repository/group_repo.go b/backend/internal/repository/group_repo.go index c17e3365..5e16475a 100644 --- a/backend/internal/repository/group_repo.go +++ b/backend/internal/repository/group_repo.go @@ -63,7 +63,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). - SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). + SetRpmLimit(groupIn.RPMLimit) // 设置模型路由配置 if groupIn.ModelRouting != nil { @@ -130,7 +131,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er SetRequireOauthOnly(groupIn.RequireOAuthOnly). SetRequirePrivacySet(groupIn.RequirePrivacySet). SetDefaultMappedModel(groupIn.DefaultMappedModel). - SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig) + SetMessagesDispatchModelConfig(groupIn.MessagesDispatchModelConfig). 
+ SetRpmLimit(groupIn.RPMLimit) // 显式处理可空字段:nil 需要 clear,非 nil 需要 set。 if groupIn.DailyLimitUSD != nil { diff --git a/backend/internal/repository/openai_403_counter_cache.go b/backend/internal/repository/openai_403_counter_cache.go new file mode 100644 index 00000000..a68d2518 --- /dev/null +++ b/backend/internal/repository/openai_403_counter_cache.go @@ -0,0 +1,51 @@ +package repository + +import ( + "context" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +const openAI403CounterPrefix = "openai_403_count:account:" + +var openAI403CounterIncrScript = redis.NewScript(` + local key = KEYS[1] + local ttl = tonumber(ARGV[1]) + + local count = redis.call('INCR', key) + if count == 1 then + redis.call('EXPIRE', key, ttl) + end + + return count +`) + +type openAI403CounterCache struct { + rdb *redis.Client +} + +func NewOpenAI403CounterCache(rdb *redis.Client) service.OpenAI403CounterCache { + return &openAI403CounterCache{rdb: rdb} +} + +func (c *openAI403CounterCache) IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) { + key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID) + + ttlSeconds := windowMinutes * 60 + if ttlSeconds < 60 { + ttlSeconds = 60 + } + + result, err := openAI403CounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64() + if err != nil { + return 0, fmt.Errorf("increment openai 403 count: %w", err) + } + return result, nil +} + +func (c *openAI403CounterCache) ResetOpenAI403Count(ctx context.Context, accountID int64) error { + key := fmt.Sprintf("%s%d", openAI403CounterPrefix, accountID) + return c.rdb.Del(ctx, key).Err() +} diff --git a/backend/internal/repository/openai_oauth_service.go b/backend/internal/repository/openai_oauth_service.go index dca0b612..acb270a3 100644 --- a/backend/internal/repository/openai_oauth_service.go +++ b/backend/internal/repository/openai_oauth_service.go @@ -2,6 +2,7 @@ package repository import ( 
"context" + "errors" "net/http" "net/url" "strings" @@ -53,6 +54,9 @@ func (s *openaiOAuthService) ExchangeCode(ctx context.Context, code, codeVerifie Post(s.tokenURL) if err != nil { + if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) { + return nil, newOpenAINoProxyHintError(err) + } return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err) } @@ -98,6 +102,9 @@ func (s *openaiOAuthService) refreshTokenWithClientID(ctx context.Context, refre Post(s.tokenURL) if err != nil { + if shouldReturnOpenAINoProxyHint(ctx, proxyURL, err) { + return nil, newOpenAINoProxyHintError(err) + } return nil, infraerrors.Newf(http.StatusBadGateway, "OPENAI_OAUTH_REQUEST_FAILED", "request failed: %v", err) } @@ -114,3 +121,21 @@ func createOpenAIReqClient(proxyURL string) (*req.Client, error) { Timeout: 120 * time.Second, }) } + +func shouldReturnOpenAINoProxyHint(ctx context.Context, proxyURL string, err error) bool { + if strings.TrimSpace(proxyURL) != "" || err == nil { + return false + } + if ctx != nil && ctx.Err() != nil { + return false + } + return !errors.Is(err, context.Canceled) +} + +func newOpenAINoProxyHintError(cause error) error { + return infraerrors.New( + http.StatusBadGateway, + "OPENAI_OAUTH_PROXY_REQUIRED", + "OpenAI OAuth request failed: no proxy is configured and this server could not reach OpenAI directly. 
Select a proxy that can access OpenAI, then retry; if the authorization code has expired, regenerate the authorization URL.", + ).WithCause(cause) +} diff --git a/backend/internal/repository/openai_oauth_service_test.go b/backend/internal/repository/openai_oauth_service_test.go index c1901d71..b43e2b52 100644 --- a/backend/internal/repository/openai_oauth_service_test.go +++ b/backend/internal/repository/openai_oauth_service_test.go @@ -8,6 +8,7 @@ import ( "net/url" "testing" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/Wei-Shaw/sub2api/internal/pkg/openai" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -204,6 +205,17 @@ func (s *OpenAIOAuthServiceSuite) TestRequestError_ClosedServer() { require.ErrorContains(s.T(), err, "request failed") } +func (s *OpenAIOAuthServiceSuite) TestExchangeCode_RequestErrorWithoutProxyReturnsProxyHint() { + s.setupServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + s.srv.Close() + + _, err := s.svc.ExchangeCode(s.ctx, "code", "ver", openai.DefaultRedirectURI, "", "") + + require.Error(s.T(), err) + require.Equal(s.T(), "OPENAI_OAUTH_PROXY_REQUIRED", infraerrors.Reason(err)) + require.Contains(s.T(), infraerrors.Message(err), "no proxy is configured") +} + func (s *OpenAIOAuthServiceSuite) TestContextCancel() { started := make(chan struct{}) block := make(chan struct{}) diff --git a/backend/internal/repository/usage_billing_repo.go b/backend/internal/repository/usage_billing_repo.go index 2b6edad3..62f48b58 100644 --- a/backend/internal/repository/usage_billing_repo.go +++ b/backend/internal/repository/usage_billing_repo.go @@ -290,7 +290,6 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI if err != nil { return nil, err } - defer func() { _ = rows.Close() }() var state service.AccountQuotaState if rows.Next() { @@ -299,18 +298,36 @@ func incrementUsageBillingAccountQuota(ctx context.Context, tx *sql.Tx, accountI 
&state.DailyUsed, &state.DailyLimit, &state.WeeklyUsed, &state.WeeklyLimit, ); err != nil { + _ = rows.Close() return nil, err } } else { if err := rows.Err(); err != nil { + _ = rows.Close() return nil, err } + _ = rows.Close() return nil, service.ErrAccountNotFound } if err := rows.Err(); err != nil { + _ = rows.Close() return nil, err } - if state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit { + // 必须在执行下一条 SQL 前显式关闭 rows:pq 驱动在同一连接上 + // 不允许前一条查询的结果集未耗尽时启动新查询,否则会返回 + // "unexpected Parse response" 错误。 + if err := rows.Close(); err != nil { + return nil, err + } + // 任意维度额度在本次递增中从"未超"跨越到"已超"时,必须刷新调度快照, + // 否则 Redis 中缓存的 Account 仍显示旧的 used 值,后续请求会继续选中本账号, + // 最终观察到 daily_used / weekly_used 大幅超过配置的 limit。 + // 对于日/周额度,即使本次触发了周期重置(pre=0、post=amount), + // 判定式 (post-amount) < limit 同样成立,逻辑与总额度保持一致。 + crossedTotal := state.TotalLimit > 0 && state.TotalUsed >= state.TotalLimit && (state.TotalUsed-amount) < state.TotalLimit + crossedDaily := state.DailyLimit > 0 && state.DailyUsed >= state.DailyLimit && (state.DailyUsed-amount) < state.DailyLimit + crossedWeekly := state.WeeklyLimit > 0 && state.WeeklyUsed >= state.WeeklyLimit && (state.WeeklyUsed-amount) < state.WeeklyLimit + if crossedTotal || crossedDaily || crossedWeekly { if err := enqueueSchedulerOutbox(ctx, tx, service.SchedulerOutboxEventAccountChanged, &accountID, nil, nil); err != nil { logger.LegacyPrintf("repository.usage_billing", "[SchedulerOutbox] enqueue quota exceeded failed: account=%d err=%v", accountID, err) return nil, err diff --git a/backend/internal/repository/usage_billing_repo_integration_test.go b/backend/internal/repository/usage_billing_repo_integration_test.go index eda34cc9..e8d4d327 100644 --- a/backend/internal/repository/usage_billing_repo_integration_test.go +++ b/backend/internal/repository/usage_billing_repo_integration_test.go @@ -199,6 +199,94 @@ func TestUsageBillingRepositoryApply_UpdatesAccountQuota(t *testing.T) { 
require.InDelta(t, 3.5, quotaUsed, 0.000001) } +func TestUsageBillingRepositoryApply_EnqueuesSchedulerOutboxOnQuotaCrossing(t *testing.T) { + ctx := context.Background() + client := testEntClient(t) + repo := NewUsageBillingRepository(client, integrationDB) + + newFixture := func(t *testing.T, extra map[string]any) (int64, int64) { + t.Helper() + user := mustCreateUser(t, client, &service.User{ + Email: fmt.Sprintf("usage-billing-outbox-user-%d-%s@example.com", time.Now().UnixNano(), uuid.NewString()), + PasswordHash: "hash", + }) + apiKey := mustCreateApiKey(t, client, &service.APIKey{ + UserID: user.ID, + Key: "sk-usage-billing-outbox-" + uuid.NewString(), + Name: "billing-outbox", + }) + account := mustCreateAccount(t, client, &service.Account{ + Name: "usage-billing-outbox-" + uuid.NewString(), + Type: service.AccountTypeAPIKey, + Extra: extra, + }) + return apiKey.ID, account.ID + } + + outboxCountFor := func(t *testing.T, accountID int64) int { + t.Helper() + var count int + require.NoError(t, integrationDB.QueryRowContext(ctx, + "SELECT COUNT(*) FROM scheduler_outbox WHERE event_type = $1 AND account_id = $2", + service.SchedulerOutboxEventAccountChanged, accountID, + ).Scan(&count)) + return count + } + + t.Run("daily_first_crossing_enqueues", func(t *testing.T) { + apiKeyID, accountID := newFixture(t, map[string]any{ + "quota_daily_limit": 10.0, + }) + // 第一次低于日限额:不应入队 outbox + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 4, + }) + require.NoError(t, err) + require.Equal(t, 0, outboxCountFor(t, accountID), "below limit should not enqueue") + + // 第二次跨越日限额:应入队一次 outbox + _, err = repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 8, + }) + require.NoError(t, err) + 
require.Equal(t, 1, outboxCountFor(t, accountID), "crossing daily limit should enqueue once") + + // 再次递增(已超):不应重复入队 + _, err = repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 2, + }) + require.NoError(t, err) + require.Equal(t, 1, outboxCountFor(t, accountID), "subsequent increments beyond limit should not re-enqueue") + }) + + t.Run("weekly_first_crossing_enqueues", func(t *testing.T) { + apiKeyID, accountID := newFixture(t, map[string]any{ + "quota_weekly_limit": 10.0, + }) + _, err := repo.Apply(ctx, &service.UsageBillingCommand{ + RequestID: uuid.NewString(), + APIKeyID: apiKeyID, + AccountID: accountID, + AccountType: service.AccountTypeAPIKey, + AccountQuotaCost: 15, // 单次即跨越 + }) + require.NoError(t, err) + require.Equal(t, 1, outboxCountFor(t, accountID), "single-shot crossing weekly limit should enqueue once") + }) +} + func TestDashboardAggregationRepositoryCleanupUsageBillingDedup_BatchDeletesOldRows(t *testing.T) { ctx := context.Background() repo := newDashboardAggregationRepositoryWithSQL(integrationDB) diff --git a/backend/internal/repository/user_group_rate_repo.go b/backend/internal/repository/user_group_rate_repo.go index eca5313f..74d25cb0 100644 --- a/backend/internal/repository/user_group_rate_repo.go +++ b/backend/internal/repository/user_group_rate_repo.go @@ -13,14 +13,14 @@ type userGroupRateRepository struct { sql sqlExecutor } -// NewUserGroupRateRepository 创建用户专属分组倍率仓储 +// NewUserGroupRateRepository 创建用户专属分组倍率/RPM 仓储 func NewUserGroupRateRepository(sqlDB *sql.DB) service.UserGroupRateRepository { return &userGroupRateRepository{sql: sqlDB} } -// GetByUserID 获取用户的所有专属分组倍率 +// GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目) func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) { - query := `SELECT group_id, rate_multiplier FROM 
user_group_rate_multipliers WHERE user_id = $1` + query := `SELECT group_id, rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NOT NULL` rows, err := r.sql.QueryContext(ctx, query, userID) if err != nil { return nil, err @@ -42,8 +42,7 @@ func (r *userGroupRateRepository) GetByUserID(ctx context.Context, userID int64) return result, nil } -// GetByUserIDs 批量获取多个用户的专属分组倍率。 -// 返回结构:map[userID]map[groupID]rate +// GetByUserIDs 批量获取多个用户的专属分组 rate_multiplier(仅返回非 NULL 的条目) func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []int64) (map[int64]map[int64]float64, error) { result := make(map[int64]map[int64]float64, len(userIDs)) if len(userIDs) == 0 { @@ -70,7 +69,7 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in rows, err := r.sql.QueryContext(ctx, ` SELECT user_id, group_id, rate_multiplier FROM user_group_rate_multipliers - WHERE user_id = ANY($1) + WHERE user_id = ANY($1) AND rate_multiplier IS NOT NULL `, pq.Array(uniqueIDs)) if err != nil { return nil, err @@ -95,10 +94,10 @@ func (r *userGroupRateRepository) GetByUserIDs(ctx context.Context, userIDs []in return result, nil } -// GetByGroupID 获取指定分组下所有用户的专属倍率 +// GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回) func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int64) ([]service.UserGroupRateEntry, error) { query := ` - SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier + SELECT ugr.user_id, u.username, u.email, COALESCE(u.notes, ''), u.status, ugr.rate_multiplier, ugr.rpm_override FROM user_group_rate_multipliers ugr JOIN users u ON u.id = ugr.user_id AND u.deleted_at IS NULL WHERE ugr.group_id = $1 @@ -113,9 +112,19 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 var result []service.UserGroupRateEntry for rows.Next() { var entry service.UserGroupRateEntry - if err := rows.Scan(&entry.UserID, 
&entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &entry.RateMultiplier); err != nil { + var rate sql.NullFloat64 + var rpm sql.NullInt32 + if err := rows.Scan(&entry.UserID, &entry.UserName, &entry.UserEmail, &entry.UserNotes, &entry.UserStatus, &rate, &rpm); err != nil { return nil, err } + if rate.Valid { + v := rate.Float64 + entry.RateMultiplier = &v + } + if rpm.Valid { + v := int(rpm.Int32) + entry.RPMOverride = &v + } result = append(result, entry) } if err := rows.Err(); err != nil { @@ -124,10 +133,10 @@ func (r *userGroupRateRepository) GetByGroupID(ctx context.Context, groupID int6 return result, nil } -// GetByUserAndGroup 获取用户在特定分组的专属倍率 +// GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil) func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) { query := `SELECT rate_multiplier FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` - var rate float64 + var rate sql.NullFloat64 err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rate) if err == sql.ErrNoRows { return nil, nil @@ -135,42 +144,79 @@ func (r *userGroupRateRepository) GetByUserAndGroup(ctx context.Context, userID, if err != nil { return nil, err } - return &rate, nil + if !rate.Valid { + return nil, nil + } + v := rate.Float64 + return &v, nil } -// SyncUserGroupRates 同步用户的分组专属倍率 +// GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil) +func (r *userGroupRateRepository) GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) { + query := `SELECT rpm_override FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = $2` + var rpm sql.NullInt32 + err := scanSingleRow(ctx, r.sql, query, []any{userID, groupID}, &rpm) + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, err + } + if !rpm.Valid { + return nil, nil + } + v := int(rpm.Int32) + return &v, nil +} + +// SyncUserGroupRates 同步用户的分组专属 
rate_multiplier。 +// - 传入空 map:清空该用户所有行的 rate_multiplier;若 rpm_override 也为 NULL 则整行删除。 +// - 值为 nil:清空对应行的 rate_multiplier(保留 rpm_override)。 +// - 值非 nil:upsert rate_multiplier(保留已有 rpm_override)。 func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID int64, rates map[int64]*float64) error { if len(rates) == 0 { - // 如果传入空 map,删除该用户的所有专属倍率 - _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE user_id = $1 + `, userID); err != nil { + return err + } + _, err := r.sql.ExecContext(ctx, + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL`, + userID) return err } - // 分离需要删除和需要 upsert 的记录 - var toDelete []int64 + var clearGroupIDs []int64 upsertGroupIDs := make([]int64, 0, len(rates)) upsertRates := make([]float64, 0, len(rates)) for groupID, rate := range rates { if rate == nil { - toDelete = append(toDelete, groupID) + clearGroupIDs = append(clearGroupIDs, groupID) } else { upsertGroupIDs = append(upsertGroupIDs, groupID) upsertRates = append(upsertRates, *rate) } } - // 删除指定的记录 - if len(toDelete) > 0 { + if len(clearGroupIDs) > 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE user_id = $1 AND group_id = ANY($2) + `, userID, pq.Array(clearGroupIDs)); err != nil { + return err + } if _, err := r.sql.ExecContext(ctx, - `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2)`, - userID, pq.Array(toDelete)); err != nil { + `DELETE FROM user_group_rate_multipliers WHERE user_id = $1 AND group_id = ANY($2) AND rate_multiplier IS NULL AND rpm_override IS NULL`, + userID, pq.Array(clearGroupIDs)); err != nil { return err } } - // Upsert 记录 - now := time.Now() if len(upsertGroupIDs) > 0 { + now 
:= time.Now() _, err := r.sql.ExecContext(ctx, ` INSERT INTO user_group_rate_multipliers (user_id, group_id, rate_multiplier, created_at, updated_at) SELECT @@ -193,14 +239,47 @@ func (r *userGroupRateRepository) SyncUserGroupRates(ctx context.Context, userID return nil } -// SyncGroupRateMultipliers 批量同步分组的用户专属倍率(先删后插) +// SyncGroupRateMultipliers 同步分组的 rate_multiplier 部分(不触动 rpm_override)。 +// 语义: +// - 未出现在 entries 中的用户行:rate_multiplier 归 NULL;若 rpm_override 也为 NULL 则整行删除。 +// - 出现的用户行:upsert rate_multiplier。 func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []service.GroupRateMultiplierInput) error { - if _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID); err != nil { + keepUserIDs := make([]int64, 0, len(entries)) + for _, e := range entries { + keepUserIDs = append(keepUserIDs, e.UserID) + } + + // 未在 entries 列表中的行:清空 rate_multiplier。 + if len(keepUserIDs) == 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + } else { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rate_multiplier = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id <> ALL($2) + `, groupID, pq.Array(keepUserIDs)); err != nil { + return err + } + } + + // 清空后若整行 NULL 则删除。 + if _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID); err != nil { return err } + if len(entries) == 0 { return nil } + userIDs := make([]int64, len(entries)) rates := make([]float64, len(entries)) for i, e := range entries { @@ -218,13 +297,103 @@ func (r *userGroupRateRepository) SyncGroupRateMultipliers(ctx context.Context, return err } -// DeleteByGroupID 删除指定分组的所有用户专属倍率 +// SyncGroupRPMOverrides 同步分组的 
rpm_override 部分(不触动 rate_multiplier)。 +// 语义: +// - 未出现的用户行:rpm_override 归 NULL;若 rate_multiplier 也为 NULL 则整行删除。 +// - 出现的用户行:若 RPMOverride 为 nil 则清空;非 nil 则 upsert。 +func (r *userGroupRateRepository) SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []service.GroupRPMOverrideInput) error { + keepUserIDs := make([]int64, 0, len(entries)) + var clearUserIDs []int64 + upsertUserIDs := make([]int64, 0, len(entries)) + upsertValues := make([]int32, 0, len(entries)) + for _, e := range entries { + keepUserIDs = append(keepUserIDs, e.UserID) + if e.RPMOverride == nil { + clearUserIDs = append(clearUserIDs, e.UserID) + } else { + upsertUserIDs = append(upsertUserIDs, e.UserID) + upsertValues = append(upsertValues, int32(*e.RPMOverride)) + } + } + + // 未在 entries 列表中的行:清空 rpm_override。 + if len(keepUserIDs) == 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + } else { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id <> ALL($2) + `, groupID, pq.Array(keepUserIDs)); err != nil { + return err + } + } + + // 显式 clear 的行。 + if len(clearUserIDs) > 0 { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 AND user_id = ANY($2) + `, groupID, pq.Array(clearUserIDs)); err != nil { + return err + } + } + + // 清空后若整行 NULL 则删除。 + if _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID); err != nil { + return err + } + + if len(upsertUserIDs) > 0 { + now := time.Now() + _, err := r.sql.ExecContext(ctx, ` + INSERT INTO user_group_rate_multipliers (user_id, group_id, rpm_override, created_at, updated_at) + SELECT 
data.user_id, $1::bigint, data.rpm_override, $2::timestamptz, $2::timestamptz + FROM unnest($3::bigint[], $4::integer[]) AS data(user_id, rpm_override) + ON CONFLICT (user_id, group_id) + DO UPDATE SET rpm_override = EXCLUDED.rpm_override, updated_at = EXCLUDED.updated_at + `, groupID, now, pq.Array(upsertUserIDs), pq.Array(upsertValues)) + if err != nil { + return err + } + } + + return nil +} + +// ClearGroupRPMOverrides 清空指定分组所有行的 rpm_override。 +func (r *userGroupRateRepository) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { + if _, err := r.sql.ExecContext(ctx, ` + UPDATE user_group_rate_multipliers + SET rpm_override = NULL, updated_at = NOW() + WHERE group_id = $1 + `, groupID); err != nil { + return err + } + _, err := r.sql.ExecContext(ctx, ` + DELETE FROM user_group_rate_multipliers + WHERE group_id = $1 AND rate_multiplier IS NULL AND rpm_override IS NULL + `, groupID) + return err +} + +// DeleteByGroupID 删除指定分组的所有用户专属条目 func (r *userGroupRateRepository) DeleteByGroupID(ctx context.Context, groupID int64) error { _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE group_id = $1`, groupID) return err } -// DeleteByUserID 删除指定用户的所有专属倍率 +// DeleteByUserID 删除指定用户的所有专属条目 func (r *userGroupRateRepository) DeleteByUserID(ctx context.Context, userID int64) error { _, err := r.sql.ExecContext(ctx, `DELETE FROM user_group_rate_multipliers WHERE user_id = $1`, userID) return err diff --git a/backend/internal/repository/user_repo.go b/backend/internal/repository/user_repo.go index c5db3dc4..d1f10cbd 100644 --- a/backend/internal/repository/user_repo.go +++ b/backend/internal/repository/user_repo.go @@ -93,6 +93,7 @@ func (r *userRepository) Create(ctx context.Context, userIn *service.User) error SetSignupSource(userSignupSourceOrDefault(userIn.SignupSource)). SetNillableLastLoginAt(userIn.LastLoginAt). SetNillableLastActiveAt(userIn.LastActiveAt). + SetRpmLimit(userIn.RPMLimit). 
Save(txCtx) if err != nil { return translatePersistenceError(err, nil, service.ErrEmailExists) @@ -219,7 +220,8 @@ func (r *userRepository) Update(ctx context.Context, userIn *service.User) error SetBalanceNotifyThresholdType(userIn.BalanceNotifyThresholdType). SetNillableBalanceNotifyThreshold(userIn.BalanceNotifyThreshold). SetBalanceNotifyExtraEmails(marshalExtraEmails(userIn.BalanceNotifyExtraEmails)). - SetTotalRecharged(userIn.TotalRecharged) + SetTotalRecharged(userIn.TotalRecharged). + SetRpmLimit(userIn.RPMLimit) if userIn.SignupSource != "" { updateOp = updateOp.SetSignupSource(userIn.SignupSource) } diff --git a/backend/internal/repository/user_rpm_cache.go b/backend/internal/repository/user_rpm_cache.go new file mode 100644 index 00000000..42bf9332 --- /dev/null +++ b/backend/internal/repository/user_rpm_cache.go @@ -0,0 +1,108 @@ +package repository + +import ( + "context" + "fmt" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/redis/go-redis/v9" +) + +// 用户/分组级 RPM 计数器 Redis 实现。 +// +// 设计说明: +// - key 形式:rpm:ug:{uid}:{gid}:{minute}、rpm:u:{uid}:{minute} +// - 时间来源:rdb.Time()(Redis 服务端时间),避免多实例时钟漂移。 +// - 原子操作:TxPipeline (MULTI/EXEC) 执行 INCR+EXPIRE,兼容 Redis Cluster。 +// - TTL:120s,覆盖当前分钟窗口 + 少量冗余。 +// - 返回值语义:超限判断由调用方(billing_cache_service.checkRPM)与 RPMLimit 比较完成。 +const ( + userGroupRPMKeyPrefix = "rpm:ug:" + userRPMKeyPrefix = "rpm:u:" + + userRPMKeyTTL = 120 * time.Second +) + +type userRPMCacheImpl struct { + rdb *redis.Client +} + +// NewUserRPMCache 创建用户/分组级 RPM 计数器。 +func NewUserRPMCache(rdb *redis.Client) service.UserRPMCache { + return &userRPMCacheImpl{rdb: rdb} +} + +// minuteTS 获取当前 Redis 服务端分钟时间戳。 +func (c *userRPMCacheImpl) minuteTS(ctx context.Context) (int64, error) { + t, err := c.rdb.Time(ctx).Result() + if err != nil { + return 0, fmt.Errorf("redis TIME: %w", err) + } + return t.Unix() / 60, nil +} + +// atomicIncr 原子 INCR+EXPIRE。 +func (c *userRPMCacheImpl) atomicIncr(ctx context.Context, key string) (int, 
error) { + pipe := c.rdb.TxPipeline() + incr := pipe.Incr(ctx, key) + pipe.Expire(ctx, key, userRPMKeyTTL) + if _, err := pipe.Exec(ctx); err != nil { + return 0, fmt.Errorf("user rpm increment: %w", err) + } + return int(incr.Val()), nil +} + +// IncrementUserGroupRPM 递增 (user, group) 分钟计数。 +func (c *userRPMCacheImpl) IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute) + return c.atomicIncr(ctx, key) +} + +// IncrementUserRPM 递增用户分钟计数。 +func (c *userRPMCacheImpl) IncrementUserRPM(ctx context.Context, userID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute) + return c.atomicIncr(ctx, key) +} + +// GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读)。 +func (c *userRPMCacheImpl) GetUserGroupRPM(ctx context.Context, userID, groupID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d:%d", userGroupRPMKeyPrefix, userID, groupID, minute) + val, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("user group rpm get: %w", err) + } + return val, nil +} + +// GetUserRPM 获取用户当前分钟已用 RPM(只读)。 +func (c *userRPMCacheImpl) GetUserRPM(ctx context.Context, userID int64) (int, error) { + minute, err := c.minuteTS(ctx) + if err != nil { + return 0, err + } + key := fmt.Sprintf("%s%d:%d", userRPMKeyPrefix, userID, minute) + val, err := c.rdb.Get(ctx, key).Int() + if err == redis.Nil { + return 0, nil + } + if err != nil { + return 0, fmt.Errorf("user rpm get: %w", err) + } + return val, nil +} diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index b1d5e36a..6d24d312 100644 --- a/backend/internal/repository/wire.go +++ 
b/backend/internal/repository/wire.go @@ -98,10 +98,12 @@ var ProviderSet = wire.NewSet( NewAPIKeyCache, NewTempUnschedCache, NewTimeoutCounterCache, + NewOpenAI403CounterCache, NewInternal500CounterCache, ProvideConcurrencyCache, ProvideSessionLimitCache, NewRPMCache, + NewUserRPMCache, NewUserMsgQueueCache, NewDashboardCache, NewEmailCache, diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 686b835d..e89ef3d9 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -55,6 +55,7 @@ func TestAPIContracts(t *testing.T) { "role": "user", "balance": 12.5, "concurrency": 5, + "rpm_limit": 0, "status": "active", "allowed_groups": null, "created_at": "2025-01-02T03:04:05Z", @@ -333,6 +334,7 @@ func TestAPIContracts(t *testing.T) { "fallback_group_id_on_invalid_request": null, "require_oauth_only": false, "require_privacy_set": false, + "rpm_limit": 0, "created_at": "2025-01-02T03:04:05Z", "updated_at": "2025-01-02T03:04:05Z" } @@ -713,6 +715,7 @@ func TestAPIContracts(t *testing.T) { "force_email_on_third_party_signup": false, "default_concurrency": 5, "default_balance": 1.25, + "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", @@ -892,6 +895,7 @@ func TestAPIContracts(t *testing.T) { "custom_endpoints": [], "default_concurrency": 0, "default_balance": 0, + "default_user_rpm_limit": 0, "default_subscriptions": [], "enable_model_fallback": false, "fallback_model_anthropic": "claude-3-5-sonnet-20241022", @@ -1090,7 +1094,7 @@ func newContractDeps(t *testing.T) *contractDeps { settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) - adminService := service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminService := 
service.NewAdminService(userRepo, groupRepo, &accountRepo, proxyRepo, apiKeyRepo, redeemRepo, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) authHandler := handler.NewAuthHandler(cfg, nil, userService, settingService, nil, redeemService, nil) apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index 4b796d55..70160f7e 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -224,6 +224,7 @@ func registerUserManagementRoutes(admin *gin.RouterGroup, h *handler.Handlers) { users.GET("/:id/usage", h.Admin.User.GetUserUsage) users.GET("/:id/balance-history", h.Admin.User.GetBalanceHistory) users.POST("/:id/replace-group", h.Admin.User.ReplaceGroup) + users.GET("/:id/rpm-status", h.Admin.User.GetUserRPMStatus) // User attribute values users.GET("/:id/attributes", h.Admin.UserAttribute.GetUserAttributes) @@ -247,6 +248,8 @@ func registerGroupRoutes(admin *gin.RouterGroup, h *handler.Handlers) { groups.GET("/:id/rate-multipliers", h.Admin.Group.GetGroupRateMultipliers) groups.PUT("/:id/rate-multipliers", h.Admin.Group.BatchSetGroupRateMultipliers) groups.DELETE("/:id/rate-multipliers", h.Admin.Group.ClearGroupRateMultipliers) + groups.PUT("/:id/rpm-overrides", h.Admin.Group.BatchSetGroupRPMOverrides) + groups.DELETE("/:id/rpm-overrides", h.Admin.Group.ClearGroupRPMOverrides) groups.GET("/:id/api-keys", h.Admin.Group.GetGroupAPIKeys) } } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 801eac1b..0fb6e18f 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -930,10 +930,8 @@ func (a *Account) SupportsOpenAIImageCapability(capability OpenAIImagesCapabilit return false } switch capability { - case OpenAIImagesCapabilityBasic: + case OpenAIImagesCapabilityBasic, 
OpenAIImagesCapabilityNative: return a.Type == AccountTypeOAuth || a.Type == AccountTypeAPIKey - case OpenAIImagesCapabilityNative: - return a.Type == AccountTypeAPIKey default: return true } diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 52d53013..e5bc93ca 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -5,7 +5,6 @@ import ( "bytes" "context" "crypto/rand" - "encoding/base64" "encoding/hex" "encoding/json" "errors" @@ -1138,7 +1137,7 @@ func (s *AccountTestService) testOpenAIImageAPIKey(c *gin.Context, ctx context.C return nil } -// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via ChatGPT backend API. +// testOpenAIImageOAuth tests OpenAI image generation using an OAuth account via Codex /responses API. func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Context, account *Account, modelID, prompt string) error { authToken := account.GetOpenAIAccessToken() if authToken == "" { @@ -1153,69 +1152,46 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co c.Writer.Flush() s.sendEvent(c, TestEvent{Type: "test_start", Model: modelID}) - s.sendEvent(c, TestEvent{Type: "content", Text: "Initializing ChatGPT backend...\n"}) + s.sendEvent(c, TestEvent{Type: "content", Text: "Calling Codex /responses image tool...\n"}) - // Build headers (replicating buildOpenAIBackendAPIHeaders logic) - headers := buildOpenAIBackendAPIHeadersForTest(ctx, account, authToken, s.accountRepo) + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: strings.TrimSpace(modelID), + Prompt: prompt, + } + applyOpenAIImagesDefaults(parsed) + + responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, parsed.Model) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to build image request: %s", err.Error())) + } + + req, err := 
http.NewRequestWithContext(ctx, http.MethodPost, chatgptCodexAPIURL, bytes.NewReader(responsesBody)) + if err != nil { + return s.sendErrorAndEnd(c, "Failed to create request") + } + req.Host = "chatgpt.com" + req.Header.Set("Authorization", "Bearer "+authToken) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "text/event-stream") + req.Header.Set("OpenAI-Beta", "responses=experimental") + req.Header.Set("originator", "opencode") + if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { + req.Header.Set("User-Agent", customUA) + } else { + req.Header.Set("User-Agent", codexCLIUserAgent) + } + if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { + req.Header.Set("chatgpt-account-id", chatgptAccountID) + } proxyURL := "" if account.ProxyID != nil && account.Proxy != nil { proxyURL = account.Proxy.URL() } - - client, err := newOpenAIBackendAPIClient(proxyURL) + resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency) if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to create client: %s", err.Error())) - } - - // Bootstrap - if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil { - log.Printf("OpenAI image test bootstrap warning: %v", bootstrapErr) - } - - // Fetch chat requirements - s.sendEvent(c, TestEvent{Type: "content", Text: "Fetching chat requirements...\n"}) - chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Chat requirements failed: %s", err.Error())) - } - if chatReqs.Arkose.Required { - return s.sendErrorAndEnd(c, "Unsupported challenge: arkose required") - } - - // Initialize and prepare conversation - s.sendEvent(c, TestEvent{Type: "content", Text: "Preparing image conversation...\n"}) - parentMessageID := uuid.NewString() - proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, 
chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent")) - _ = initializeOpenAIImageConversation(ctx, client, headers) - conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, prompt, parentMessageID, chatReqs.Token, proofToken) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation prepare failed: %s", err.Error())) - } - - // Build simplified conversation request (no file uploads) - convReq := buildOpenAIImageTestConversationRequest(prompt, parentMessageID) - convHeaders := cloneHTTPHeader(headers) - convHeaders.Set("Accept", "text/event-stream") - convHeaders.Set("Content-Type", "application/json") - convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token) - if conduitToken != "" { - convHeaders.Set("x-conduit-token", conduitToken) - } - if proofToken != "" { - convHeaders.Set("openai-sentinel-proof-token", proofToken) - } - - s.sendEvent(c, TestEvent{Type: "content", Text: "Generating image...\n"}) - - resp, err := client.R(). - SetContext(ctx). - DisableAutoReadResponse(). - SetHeaders(headerToMap(convHeaders)). - SetBodyJsonMarshal(convReq). 
- Post(openAIChatGPTConversationURL) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation request failed: %s", err.Error())) + return s.sendErrorAndEnd(c, fmt.Sprintf("Responses API request failed: %s", err.Error())) } defer func() { if resp != nil && resp.Body != nil { @@ -1223,49 +1199,35 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co } }() if resp.StatusCode >= 400 { - return s.sendErrorAndEnd(c, fmt.Sprintf("Conversation API returned %d", resp.StatusCode)) + body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + message := strings.TrimSpace(extractUpstreamErrorMessage(body)) + if message == "" { + message = fmt.Sprintf("Responses API returned %d", resp.StatusCode) + } + return s.sendErrorAndEnd(c, message) } - startTime := time.Now() - conversationID, pointerInfos, _, _, err := readOpenAIImageConversationStream(resp, startTime) + body, err := io.ReadAll(resp.Body) if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Stream read failed: %s", err.Error())) + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to read image response: %s", err.Error())) } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) - if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { - s.sendEvent(c, TestEvent{Type: "content", Text: "Waiting for image generation to complete...\n"}) - polledPointers, pollErr := pollOpenAIImageConversation(ctx, client, headers, conversationID) - if pollErr != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Poll failed: %s", pollErr.Error())) - } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers) + results, _, _, _, _, err := collectOpenAIImagesFromResponsesBody(body) + if err != nil { + return s.sendErrorAndEnd(c, fmt.Sprintf("Failed to parse image response: %s", err.Error())) } - pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) - if len(pointerInfos) == 0 { - return s.sendErrorAndEnd(c, "No images returned from 
conversation") + if len(results) == 0 { + return s.sendErrorAndEnd(c, "No images returned from responses API") } - s.sendEvent(c, TestEvent{Type: "content", Text: "Downloading generated image...\n"}) - - // Download and encode each image - for _, pointer := range pointerInfos { - downloadURL, err := fetchOpenAIImageDownloadURL(ctx, client, headers, conversationID, pointer.Pointer) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Download URL fetch failed: %s", err.Error())) - } - data, err := downloadOpenAIImageBytes(ctx, client, headers, downloadURL) - if err != nil { - return s.sendErrorAndEnd(c, fmt.Sprintf("Image download failed: %s", err.Error())) - } - b64 := base64.StdEncoding.EncodeToString(data) - mimeType := http.DetectContentType(data) - if pointer.Prompt != "" { - s.sendEvent(c, TestEvent{Type: "content", Text: pointer.Prompt}) + for _, item := range results { + if item.RevisedPrompt != "" { + s.sendEvent(c, TestEvent{Type: "content", Text: item.RevisedPrompt}) } + mimeType := openAIImageOutputMIMEType(item.OutputFormat) s.sendEvent(c, TestEvent{ Type: "image", - ImageURL: "data:" + mimeType + ";base64," + b64, + ImageURL: "data:" + mimeType + ";base64," + item.Result, MimeType: mimeType, }) } @@ -1274,107 +1236,6 @@ func (s *AccountTestService) testOpenAIImageOAuth(c *gin.Context, ctx context.Co return nil } -// buildOpenAIBackendAPIHeadersForTest builds ChatGPT backend API headers for test purposes. -// Replicates the logic from OpenAIGatewayService.buildOpenAIBackendAPIHeaders without -// requiring the full gateway service dependency. 
-func buildOpenAIBackendAPIHeadersForTest(ctx context.Context, account *Account, token string, repo AccountRepository) http.Header { - // Ensure device and session IDs exist - deviceID := account.GetOpenAIDeviceID() - sessionID := account.GetOpenAISessionID() - if deviceID == "" || sessionID == "" { - updates := map[string]any{} - if deviceID == "" { - deviceID = uuid.NewString() - updates["openai_device_id"] = deviceID - } - if sessionID == "" { - sessionID = uuid.NewString() - updates["openai_session_id"] = sessionID - } - if account.Extra == nil { - account.Extra = map[string]any{} - } - for key, value := range updates { - account.Extra[key] = value - } - if repo != nil { - updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - _ = repo.UpdateExtra(updateCtx, account.ID, updates) - } - } - - headers := make(http.Header) - headers.Set("Authorization", "Bearer "+token) - headers.Set("Accept", "application/json") - headers.Set("Origin", "https://chatgpt.com") - headers.Set("Referer", "https://chatgpt.com/") - headers.Set("Sec-Fetch-Dest", "empty") - headers.Set("Sec-Fetch-Mode", "cors") - headers.Set("Sec-Fetch-Site", "same-origin") - headers.Set("User-Agent", openAIImageBackendUserAgent) - if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { - headers.Set("User-Agent", customUA) - } - if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { - headers.Set("chatgpt-account-id", chatgptAccountID) - } - if deviceID != "" { - headers.Set("oai-device-id", deviceID) - headers.Set("Cookie", "oai-did="+deviceID) - } - if sessionID != "" { - headers.Set("oai-session-id", sessionID) - } - return headers -} - -// buildOpenAIImageTestConversationRequest creates a simplified image generation conversation request. 
-func buildOpenAIImageTestConversationRequest(prompt, parentMessageID string) map[string]any { - promptText := strings.TrimSpace(prompt) - if promptText == "" { - promptText = "Generate an image." - } - metadata := map[string]any{ - "developer_mode_connector_ids": []any{}, - "selected_github_repos": []any{}, - "selected_all_github_repos": false, - "system_hints": []string{"picture_v2"}, - "serialization_metadata": map[string]any{ - "custom_symbol_offsets": []any{}, - }, - } - message := map[string]any{ - "id": uuid.NewString(), - "author": map[string]any{"role": "user"}, - "content": map[string]any{ - "content_type": "text", - "parts": []any{promptText}, - }, - "metadata": metadata, - "create_time": float64(time.Now().UnixMilli()) / 1000, - } - return map[string]any{ - "action": "next", - "client_prepare_state": "sent", - "parent_message_id": parentMessageID, - "messages": []any{message}, - "model": "auto", - "timezone_offset_min": openAITimezoneOffsetMinutes(), - "timezone": openAITimezoneName(), - "conversation_mode": map[string]any{"kind": "primary_assistant"}, - "system_hints": []string{"picture_v2"}, - "supports_buffering": true, - "supported_encodings": []string{"v1"}, - "client_contextual_info": map[string]any{"app_name": "chatgpt.com"}, - "force_nulligen": false, - "force_paragen": false, - "force_paragen_model_slug": "", - "force_rate_limit": false, - "websocket_request_id": uuid.NewString(), - } -} - func (s *AccountTestService) sendEvent(c *gin.Context, event TestEvent) { eventJSON, _ := json.Marshal(event) if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", eventJSON); err != nil { diff --git a/backend/internal/service/account_test_service_openai_image_test.go b/backend/internal/service/account_test_service_openai_image_test.go new file mode 100644 index 00000000..80a2fc31 --- /dev/null +++ b/backend/internal/service/account_test_service_openai_image_test.go @@ -0,0 +1,50 @@ +package service + +import ( + "context" + "io" + "net/http" + 
"net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAccountTestService_OpenAIImageOAuthHandlesOutputItemDoneFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000006,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc := &AccountTestService{httpUpstream: upstream} + account := &Account{ + ID: 53, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + err := svc.testOpenAIImageOAuth(c, context.Background(), account, "gpt-image-2", "draw a cat") + require.NoError(t, err) + require.Contains(t, rec.Body.String(), "Calling Codex /responses image tool") + require.Contains(t, rec.Body.String(), "data:image/png;base64,aGVsbG8=") + require.Contains(t, rec.Body.String(), "\"success\":true") +} diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index 4ae66613..434f1f38 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -8,6 +8,7 @@ import ( "io" "log/slog" "net/http" + "sort" "strings" "time" @@ -32,6 +33,7 @@ type AdminService interface { UpdateUserBalance(ctx context.Context, userID int64, balance 
float64, operation string, notes string) (*User, error) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int, sortBy, sortOrder string) ([]APIKey, int64, error) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) + GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) // GetUserBalanceHistory returns paginated balance/concurrency change records for a user. // codeType is optional - pass empty string to return all types. // Also returns totalRecharged (sum of all positive balance top-ups). @@ -50,6 +52,8 @@ type AdminService interface { GetGroupRateMultipliers(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) ClearGroupRateMultipliers(ctx context.Context, groupID int64) error BatchSetGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error + ClearGroupRPMOverrides(ctx context.Context, groupID int64) error + BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error // API Key management (admin) @@ -114,6 +118,7 @@ type CreateUserInput struct { Notes string Balance float64 Concurrency int + RPMLimit int AllowedGroups []int64 } @@ -124,6 +129,7 @@ type UpdateUserInput struct { Notes *string Balance *float64 // 使用指针区分"未提供"和"设置为0" Concurrency *int // 使用指针区分"未提供"和"设置为0" + RPMLimit *int // 使用指针区分"未提供"和"设置为0" Status string AllowedGroups *[]int64 // 使用指针区分"未提供"和"设置为空数组" // GroupRates 用户专属分组倍率配置 @@ -199,6 +205,8 @@ type CreateGroupInput struct { RequireOAuthOnly bool RequirePrivacySet bool MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit 分组 RPM 上限(0 = 不限制) + RPMLimit int // 从指定分组复制账号(创建分组后在同一事务内绑定) CopyAccountsFromGroupIDs []int64 } @@ -234,6 +242,8 @@ type UpdateGroupInput struct { RequireOAuthOnly *bool RequirePrivacySet *bool MessagesDispatchModelConfig *OpenAIMessagesDispatchModelConfig + // RPMLimit 
分组 RPM 上限(0 = 不限制),nil 表示未提供不改动。 + RPMLimit *int // 从指定分组复制账号(同步操作:先清空当前分组的账号绑定,再绑定源分组的账号) CopyAccountsFromGroupIDs []int64 } @@ -317,6 +327,22 @@ type ReplaceUserGroupResult struct { MigratedKeys int64 // 迁移的 Key 数量 } +// UserRPMStatus describes a user's current per-minute RPM usage. +type UserRPMStatus struct { + UserRPMUsed int `json:"user_rpm_used"` + UserRPMLimit int `json:"user_rpm_limit"` + PerGroup []UserGroupRPMStatus `json:"per_group"` +} + +// UserGroupRPMStatus describes current per-minute RPM usage for one user/group pair. +type UserGroupRPMStatus struct { + GroupID int64 `json:"group_id"` + GroupName string `json:"group_name"` + Used int `json:"used"` + Limit int `json:"limit"` + Source string `json:"source"` // "group" | "override" +} + // BulkUpdateAccountsResult is the aggregated response for bulk updates. type BulkUpdateAccountsResult struct { Success int `json:"success"` @@ -463,6 +489,8 @@ const ( proxyQualityClientUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36" ) +var ErrRPMStatusUnavailable = infraerrors.New(http.StatusNotImplemented, "RPM_STATUS_UNAVAILABLE", "RPM cache not available") + // adminServiceImpl implements AdminService type adminServiceImpl struct { userRepo UserRepository @@ -472,6 +500,7 @@ type adminServiceImpl struct { apiKeyRepo APIKeyRepository redeemCodeRepo RedeemCodeRepository userGroupRateRepo UserGroupRateRepository + userRPMCache UserRPMCache billingCacheService *BillingCacheService proxyProber ProxyExitInfoProber proxyLatencyCache ProxyLatencyCache @@ -496,6 +525,7 @@ func NewAdminService( apiKeyRepo APIKeyRepository, redeemCodeRepo RedeemCodeRepository, userGroupRateRepo UserGroupRateRepository, + userRPMCache UserRPMCache, billingCacheService *BillingCacheService, proxyProber ProxyExitInfoProber, proxyLatencyCache ProxyLatencyCache, @@ -514,6 +544,7 @@ func NewAdminService( apiKeyRepo: apiKeyRepo, redeemCodeRepo: redeemCodeRepo, 
userGroupRateRepo: userGroupRateRepo, + userRPMCache: userRPMCache, billingCacheService: billingCacheService, proxyProber: proxyProber, proxyLatencyCache: proxyLatencyCache, @@ -617,6 +648,7 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu Role: RoleUser, // Always create as regular user, never admin Balance: input.Balance, Concurrency: input.Concurrency, + RPMLimit: input.RPMLimit, Status: StatusActive, AllowedGroups: input.AllowedGroups, } @@ -670,6 +702,7 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda oldConcurrency := user.Concurrency oldStatus := user.Status oldRole := user.Role + oldRPMLimit := user.RPMLimit if input.Email != "" { user.Email = input.Email @@ -695,6 +728,10 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda user.Concurrency = *input.Concurrency } + if input.RPMLimit != nil { + user.RPMLimit = *input.RPMLimit + } + if input.AllowedGroups != nil { user.AllowedGroups = *input.AllowedGroups } @@ -711,7 +748,9 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda } if s.authCacheInvalidator != nil { - if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole { + // RPMLimit 直接参与 billing_cache_service.checkRPM 的三级级联, + // 不失效缓存会让修改在一个 L2 TTL 内失去效果。 + if user.Concurrency != oldConcurrency || user.Status != oldStatus || user.Role != oldRole || user.RPMLimit != oldRPMLimit { s.authCacheInvalidator.InvalidateAuthCacheByUserID(ctx, user.ID) } } @@ -833,6 +872,81 @@ func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, pag return keys, result.Total, nil } +func (s *adminServiceImpl) GetUserRPMStatus(ctx context.Context, userID int64) (*UserRPMStatus, error) { + if s.userRPMCache == nil { + return nil, ErrRPMStatusUnavailable + } + + user, err := s.userRepo.GetByID(ctx, userID) + if err != nil { + return nil, err + } + + userRPMUsed, err := 
s.userRPMCache.GetUserRPM(ctx, userID) + if err != nil { + logger.LegacyPrintf("service.admin", "failed to get user rpm: user_id=%d err=%v", userID, err) + } + + keys, _, err := s.GetUserAPIKeys(ctx, userID, 1, 1000, "", "") + if err != nil { + return nil, err + } + + groupIDSet := make(map[int64]struct{}) + for _, key := range keys { + if key.GroupID != nil && *key.GroupID > 0 { + groupIDSet[*key.GroupID] = struct{}{} + } + } + + groupIDs := make([]int64, 0, len(groupIDSet)) + for groupID := range groupIDSet { + groupIDs = append(groupIDs, groupID) + } + sort.Slice(groupIDs, func(i, j int) bool { return groupIDs[i] < groupIDs[j] }) + + var perGroup []UserGroupRPMStatus + for _, groupID := range groupIDs { + used, getErr := s.userRPMCache.GetUserGroupRPM(ctx, userID, groupID) + if getErr != nil { + logger.LegacyPrintf("service.admin", "failed to get user group rpm: user_id=%d group_id=%d err=%v", userID, groupID, getErr) + } + + entry := UserGroupRPMStatus{ + GroupID: groupID, + Used: used, + } + + if s.groupRepo != nil { + if group, groupErr := s.groupRepo.GetByIDLite(ctx, groupID); groupErr == nil && group != nil { + entry.GroupName = group.Name + entry.Limit = group.RPMLimit + entry.Source = "group" + } else if groupErr != nil { + logger.LegacyPrintf("service.admin", "failed to get group rpm status metadata: group_id=%d err=%v", groupID, groupErr) + } + } + + if s.userGroupRateRepo != nil { + override, overrideErr := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, userID, groupID) + if overrideErr != nil { + logger.LegacyPrintf("service.admin", "failed to get rpm override: user_id=%d group_id=%d err=%v", userID, groupID, overrideErr) + } else if override != nil { + entry.Limit = *override + entry.Source = "override" + } + } + + perGroup = append(perGroup, entry) + } + + return &UserRPMStatus{ + UserRPMUsed: userRPMUsed, + UserRPMLimit: user.RPMLimit, + PerGroup: perGroup, + }, nil +} + func (s *adminServiceImpl) GetUserUsageStats(ctx context.Context, 
userID int64, period string) (any, error) { // Return mock data for now return map[string]any{ @@ -1314,6 +1428,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn RequirePrivacySet: input.RequirePrivacySet, DefaultMappedModel: input.DefaultMappedModel, MessagesDispatchModelConfig: normalizeOpenAIMessagesDispatchModelConfig(input.MessagesDispatchModelConfig), + RPMLimit: input.RPMLimit, } sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Create(ctx, group); err != nil { @@ -1548,12 +1663,19 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd if input.MessagesDispatchModelConfig != nil { group.MessagesDispatchModelConfig = normalizeOpenAIMessagesDispatchModelConfig(*input.MessagesDispatchModelConfig) } + if input.RPMLimit != nil { + group.RPMLimit = *input.RPMLimit + } sanitizeGroupMessagesDispatchFields(group) if err := s.groupRepo.Update(ctx, group); err != nil { return nil, err } + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) + } + // 如果指定了复制账号的源分组,同步绑定(替换当前分组的账号) if len(input.CopyAccountsFromGroupIDs) > 0 { // 去重源分组 IDs @@ -1622,9 +1744,6 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd } } - if s.authCacheInvalidator != nil { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, id) - } return group, nil } @@ -1700,6 +1819,39 @@ func (s *adminServiceImpl) BatchSetGroupRateMultipliers(ctx context.Context, gro return s.userGroupRateRepo.SyncGroupRateMultipliers(ctx, groupID, entries) } +func (s *adminServiceImpl) ClearGroupRPMOverrides(ctx context.Context, groupID int64) error { + if s.userGroupRateRepo == nil { + return nil + } + if err := s.userGroupRateRepo.ClearGroupRPMOverrides(ctx, groupID); err != nil { + return err + } + // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } 
+ return nil +} + +func (s *adminServiceImpl) BatchSetGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + if s.userGroupRateRepo == nil { + return nil + } + for _, e := range entries { + if e.RPMOverride != nil && *e.RPMOverride < 0 { + return infraerrors.BadRequest("INVALID_RPM_OVERRIDE", fmt.Sprintf("rpm_override must be >= 0 (user_id=%d)", e.UserID)) + } + } + if err := s.userGroupRateRepo.SyncGroupRPMOverrides(ctx, groupID, entries); err != nil { + return err + } + // RPM override 已嵌入 auth cache snapshot (v7),变更后必须失效相关缓存。 + if s.authCacheInvalidator != nil { + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, groupID) + } + return nil +} + func (s *adminServiceImpl) UpdateGroupSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error { return s.groupRepo.UpdateSortOrders(ctx, updates) } diff --git a/backend/internal/service/admin_service_group_rate_test.go b/backend/internal/service/admin_service_group_rate_test.go index 77635247..d2efb644 100644 --- a/backend/internal/service/admin_service_group_rate_test.go +++ b/backend/internal/service/admin_service_group_rate_test.go @@ -5,8 +5,10 @@ package service import ( "context" "errors" + "net/http" "testing" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" "github.com/stretchr/testify/require" ) @@ -21,6 +23,10 @@ type userGroupRateRepoStubForGroupRate struct { syncedGroupID int64 syncedEntries []GroupRateMultiplierInput syncGroupErr error + + rpmSyncedGroupID int64 + rpmSyncedEntries []GroupRPMOverrideInput + rpmSyncErr error } func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) { @@ -31,6 +37,10 @@ func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, panic("unexpected GetByUserAndGroup call") } +func (s *userGroupRateRepoStubForGroupRate) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + panic("unexpected 
GetRPMOverrideByUserAndGroup call") +} + func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) { if s.getByGroupIDErr != nil { return nil, s.getByGroupIDErr @@ -48,6 +58,16 @@ func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.C return s.syncGroupErr } +func (s *userGroupRateRepoStubForGroupRate) SyncGroupRPMOverrides(_ context.Context, groupID int64, entries []GroupRPMOverrideInput) error { + s.rpmSyncedGroupID = groupID + s.rpmSyncedEntries = entries + return s.rpmSyncErr +} + +func (s *userGroupRateRepoStubForGroupRate) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + panic("unexpected ClearGroupRPMOverrides call") +} + func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error { s.deletedGroupIDs = append(s.deletedGroupIDs, groupID) return s.deleteByGroupErr @@ -62,8 +82,8 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { repo := &userGroupRateRepoStubForGroupRate{ getByGroupIDData: map[int64][]UserGroupRateEntry{ 10: { - {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5}, - {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8}, + {UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: ptrFloat(1.5)}, + {UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: ptrFloat(0.8)}, }, }, } @@ -74,9 +94,11 @@ func TestAdminService_GetGroupRateMultipliers(t *testing.T) { require.Len(t, entries, 2) require.Equal(t, int64(1), entries[0].UserID) require.Equal(t, "alice", entries[0].UserName) - require.Equal(t, 1.5, entries[0].RateMultiplier) + require.NotNil(t, entries[0].RateMultiplier) + require.Equal(t, 1.5, *entries[0].RateMultiplier) require.Equal(t, int64(2), entries[1].UserID) - require.Equal(t, 0.8, entries[1].RateMultiplier) + require.NotNil(t, entries[1].RateMultiplier) + require.Equal(t, 0.8, 
*entries[1].RateMultiplier) }) t.Run("returns nil when repo is nil", func(t *testing.T) { @@ -174,3 +196,30 @@ func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) { require.Contains(t, err.Error(), "sync failed") }) } + +func TestAdminService_BatchSetGroupRPMOverrides(t *testing.T) { + t.Run("syncs entries to repo", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + override := 20 + entries := []GroupRPMOverrideInput{{UserID: 2, RPMOverride: &override}} + + err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, entries) + require.NoError(t, err) + require.Equal(t, int64(10), repo.rpmSyncedGroupID) + require.Equal(t, entries, repo.rpmSyncedEntries) + }) + + t.Run("rejects negative override as bad request", func(t *testing.T) { + repo := &userGroupRateRepoStubForGroupRate{} + svc := &adminServiceImpl{userGroupRateRepo: repo} + negative := -1 + + err := svc.BatchSetGroupRPMOverrides(context.Background(), 10, []GroupRPMOverrideInput{ + {UserID: 2, RPMOverride: &negative}, + }) + require.Error(t, err) + require.Equal(t, http.StatusBadRequest, infraerrors.Code(err)) + require.Zero(t, repo.rpmSyncedGroupID) + }) +} diff --git a/backend/internal/service/admin_service_group_test.go b/backend/internal/service/admin_service_group_test.go index 41d2c26a..eef02240 100644 --- a/backend/internal/service/admin_service_group_test.go +++ b/backend/internal/service/admin_service_group_test.go @@ -266,6 +266,31 @@ func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) { require.Nil(t, repo.updated.ImagePrice4K) } +func TestAdminService_UpdateGroup_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { + existingGroup := &Group{ + ID: 1, + Name: "existing-group", + Platform: PlatformAnthropic, + Status: StatusActive, + RPMLimit: 10, + } + repo := &groupRepoStubForAdmin{getByID: existingGroup} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + groupRepo: repo, + 
authCacheInvalidator: invalidator, + } + + rpmLimit := 60 + group, err := svc.UpdateGroup(context.Background(), 1, &UpdateGroupInput{ + RPMLimit: &rpmLimit, + }) + require.NoError(t, err) + require.NotNil(t, group) + require.Equal(t, 60, repo.updated.RPMLimit) + require.Equal(t, []int64{1}, invalidator.groupIDs, "分组 RPMLimit 写入 auth snapshot,变更后必须失效 API Key 认证缓存") +} + func TestAdminService_CreateGroup_NormalizesMessagesDispatchModelConfig(t *testing.T) { repo := &groupRepoStubForAdmin{} svc := &adminServiceImpl{groupRepo: repo} diff --git a/backend/internal/service/admin_service_list_users_test.go b/backend/internal/service/admin_service_list_users_test.go index 657616c4..ff3f65a8 100644 --- a/backend/internal/service/admin_service_list_users_test.go +++ b/backend/internal/service/admin_service_list_users_test.go @@ -89,6 +89,10 @@ func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, panic("unexpected GetByUserAndGroup call") } +func (s *userGroupRateRepoStubForListUsers) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + panic("unexpected GetRPMOverrideByUserAndGroup call") +} + func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error { panic("unexpected SyncUserGroupRates call") } @@ -101,6 +105,14 @@ func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.C panic("unexpected SyncGroupRateMultipliers call") } +func (s *userGroupRateRepoStubForListUsers) SyncGroupRPMOverrides(_ context.Context, _ int64, _ []GroupRPMOverrideInput) error { + panic("unexpected SyncGroupRPMOverrides call") +} + +func (s *userGroupRateRepoStubForListUsers) ClearGroupRPMOverrides(_ context.Context, _ int64) error { + panic("unexpected ClearGroupRPMOverrides call") +} + func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error { panic("unexpected DeleteByGroupID call") } diff --git 
a/backend/internal/service/admin_service_rpm_status_test.go b/backend/internal/service/admin_service_rpm_status_test.go new file mode 100644 index 00000000..c298f69b --- /dev/null +++ b/backend/internal/service/admin_service_rpm_status_test.go @@ -0,0 +1,112 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type rpmStatusUserRepoStub struct { + UserRepository + user *User +} + +func (s *rpmStatusUserRepoStub) GetByID(_ context.Context, _ int64) (*User, error) { + return s.user, nil +} + +type rpmStatusAPIKeyRepoStub struct { + APIKeyRepository + keys []APIKey +} + +func (s *rpmStatusAPIKeyRepoStub) ListByUserID(_ context.Context, _ int64, _ pagination.PaginationParams, _ APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) { + return s.keys, &pagination.PaginationResult{Total: int64(len(s.keys))}, nil +} + +type rpmStatusGroupRepoStub struct { + GroupRepository + groups map[int64]*Group +} + +func (s *rpmStatusGroupRepoStub) GetByIDLite(_ context.Context, id int64) (*Group, error) { + return s.groups[id], nil +} + +type rpmStatusRateRepoStub struct { + UserGroupRateRepository + overrides map[int64]*int +} + +func (s *rpmStatusRateRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, groupID int64) (*int, error) { + return s.overrides[groupID], nil +} + +type rpmStatusCacheStub struct { + UserRPMCache + userUsed int + groupUsed map[int64]int +} + +func (s *rpmStatusCacheStub) IncrementUserGroupRPM(context.Context, int64, int64) (int, error) { + return 0, nil +} + +func (s *rpmStatusCacheStub) IncrementUserRPM(context.Context, int64) (int, error) { + return 0, nil +} + +func (s *rpmStatusCacheStub) GetUserGroupRPM(_ context.Context, _, groupID int64) (int, error) { + return s.groupUsed[groupID], nil +} + +func (s *rpmStatusCacheStub) GetUserRPM(context.Context, int64) (int, error) { + return s.userUsed, nil +} + +func 
TestAdminService_GetUserRPMStatus_AggregatesUserAndGroupLimits(t *testing.T) { + groupOneID := int64(1) + groupTwoID := int64(2) + override := 7 + svc := &adminServiceImpl{ + userRepo: &rpmStatusUserRepoStub{user: &User{ + ID: 42, + RPMLimit: 20, + }}, + apiKeyRepo: &rpmStatusAPIKeyRepoStub{keys: []APIKey{ + {ID: 100, UserID: 42, GroupID: &groupTwoID}, + {ID: 101, UserID: 42, GroupID: &groupOneID}, + {ID: 102, UserID: 42, GroupID: &groupTwoID}, + {ID: 103, UserID: 42}, + }}, + groupRepo: &rpmStatusGroupRepoStub{groups: map[int64]*Group{ + groupOneID: {ID: groupOneID, Name: "group-one", RPMLimit: 10}, + groupTwoID: {ID: groupTwoID, Name: "group-two", RPMLimit: 60}, + }}, + userGroupRateRepo: &rpmStatusRateRepoStub{overrides: map[int64]*int{ + groupTwoID: &override, + }}, + userRPMCache: &rpmStatusCacheStub{ + userUsed: 5, + groupUsed: map[int64]int{ + groupOneID: 3, + groupTwoID: 4, + }, + }, + } + + status, err := svc.GetUserRPMStatus(context.Background(), 42) + require.NoError(t, err) + require.Equal(t, &UserRPMStatus{ + UserRPMUsed: 5, + UserRPMLimit: 20, + PerGroup: []UserGroupRPMStatus{ + {GroupID: groupOneID, GroupName: "group-one", Used: 3, Limit: 10, Source: "group"}, + {GroupID: groupTwoID, GroupName: "group-two", Used: 4, Limit: 7, Source: "override"}, + }, + }, status) +} diff --git a/backend/internal/service/admin_service_update_user_rpm_test.go b/backend/internal/service/admin_service_update_user_rpm_test.go new file mode 100644 index 00000000..cb4c3986 --- /dev/null +++ b/backend/internal/service/admin_service_update_user_rpm_test.go @@ -0,0 +1,69 @@ +//go:build unit + +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +// rpmUserRepoStub 复用 admin_service_update_balance_test.go 的基础 stub 结构, +// 只在 Update 时把入参克隆一份,便于断言修改后的 RPMLimit。 +type rpmUserRepoStub struct { + *userRepoStub + lastUpdated *User +} + +func (s *rpmUserRepoStub) Update(_ context.Context, user *User) error { + if user == nil { + return 
nil + } + clone := *user + s.lastUpdated = &clone + if s.userRepoStub != nil { + s.userRepoStub.user = &clone + } + return nil +} + +func TestAdminService_UpdateUser_InvalidatesAuthCacheOnRPMLimitChange(t *testing.T) { + base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10}} + repo := &rpmUserRepoStub{userRepoStub: base} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: &redeemRepoStub{}, + authCacheInvalidator: invalidator, + } + + newRPM := 60 + updated, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{ + RPMLimit: &newRPM, + }) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, 60, updated.RPMLimit) + require.Equal(t, []int64{42}, invalidator.userIDs, "仅修改 RPMLimit 也应失效 API Key 认证缓存") +} + +func TestAdminService_UpdateUser_NoInvalidateWhenRPMLimitUnchanged(t *testing.T) { + base := &userRepoStub{user: &User{ID: 42, Email: "u@example.com", RPMLimit: 10, Username: "old"}} + repo := &rpmUserRepoStub{userRepoStub: base} + invalidator := &authCacheInvalidatorStub{} + svc := &adminServiceImpl{ + userRepo: repo, + redeemCodeRepo: &redeemRepoStub{}, + authCacheInvalidator: invalidator, + } + + newName := "new" + sameRPM := 10 + _, err := svc.UpdateUser(context.Background(), 42, &UpdateUserInput{ + Username: &newName, + RPMLimit: &sameRPM, + }) + require.NoError(t, err) + require.Empty(t, invalidator.userIDs, "只改 username 不应触发认证缓存失效") +} diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index b1660ea7..1a1c78b8 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -43,6 +43,13 @@ type APIKeyAuthUserSnapshot struct { BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` TotalRecharged float64 `json:"total_recharged"` + + 
// RPMLimit 用户级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 兜底判断。 + RPMLimit int `json:"rpm_limit"` + + // UserGroupRPMOverride 该 API Key 对应的 (user, group) 专属 RPM 覆盖值。 + // nil = 无 override(回退到 group/user 级);0 = 不限流;>0 = 专属上限。 + UserGroupRPMOverride *int `json:"user_group_rpm_override,omitempty"` } // APIKeyAuthGroupSnapshot 分组快照 @@ -76,6 +83,9 @@ type APIKeyAuthGroupSnapshot struct { AllowMessagesDispatch bool `json:"allow_messages_dispatch"` DefaultMappedModel string `json:"default_mapped_model,omitempty"` MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig `json:"messages_dispatch_model_config,omitempty"` + + // RPMLimit 分组级每分钟请求数上限(0 = 不限制);用于 billing_cache_service.checkRPM 级联判断。 + RPMLimit int `json:"rpm_limit"` } // APIKeyAuthCacheEntry 缓存条目,支持负缓存 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 2bd9a091..974ea66e 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -14,7 +14,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold +const apiKeyAuthSnapshotVersion = 7 // v7: added UserGroupRPMOverride on user snapshot type apiKeyAuthCacheConfig struct { l1Size int @@ -176,7 +176,7 @@ func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey st return nil, fmt.Errorf("get api key: %w", err) } apiKey.Key = key - snapshot := s.snapshotFromAPIKey(apiKey) + snapshot := s.snapshotFromAPIKey(ctx, apiKey) if snapshot == nil { return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound) } @@ -201,7 +201,7 @@ func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEn return s.snapshotToAPIKey(key, entry.Snapshot), true, nil } -func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { +func (s *APIKeyService) snapshotFromAPIKey(ctx context.Context, apiKey 
*APIKey) *APIKeyAuthSnapshot { if apiKey == nil || apiKey.User == nil { return nil } @@ -232,8 +232,18 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, TotalRecharged: apiKey.User.TotalRecharged, + RPMLimit: apiKey.User.RPMLimit, }, } + + // 填充 (user, group) RPM override —— snapshot 构建时查一次 DB,后续请求零 DB 往返。 + if apiKey.GroupID != nil && *apiKey.GroupID > 0 && s.userGroupRateRepo != nil { + override, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, apiKey.UserID, *apiKey.GroupID) + if err == nil && override != nil { + snapshot.User.UserGroupRPMOverride = override + } + // 查询失败或无 override 时留 nil,checkRPM 会回退到 DB 查询 + } if apiKey.Group != nil { snapshot.Group = &APIKeyAuthGroupSnapshot{ ID: apiKey.Group.ID, @@ -258,6 +268,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch, DefaultMappedModel: apiKey.Group.DefaultMappedModel, MessagesDispatchModelConfig: apiKey.Group.MessagesDispatchModelConfig, + RPMLimit: apiKey.Group.RPMLimit, } } return snapshot @@ -294,6 +305,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, TotalRecharged: snapshot.User.TotalRecharged, + RPMLimit: snapshot.User.RPMLimit, + UserGroupRPMOverride: snapshot.User.UserGroupRPMOverride, }, } if snapshot.Group != nil { @@ -321,6 +334,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch, DefaultMappedModel: snapshot.Group.DefaultMappedModel, MessagesDispatchModelConfig: snapshot.Group.MessagesDispatchModelConfig, + RPMLimit: snapshot.Group.RPMLimit, } } s.compileAPIKeyIPRules(apiKey) diff --git 
a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 3c2f7dbb..8cb1b8c4 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -263,7 +263,7 @@ func TestAPIKeyService_SnapshotRoundTrip_PreservesMessagesDispatchModelConfig(t }, } - snapshot := svc.snapshotFromAPIKey(apiKey) + snapshot := svc.snapshotFromAPIKey(context.Background(), apiKey) roundTrip := svc.snapshotToAPIKey(apiKey.Key, snapshot) require.NotNil(t, roundTrip) diff --git a/backend/internal/service/auth_service.go b/backend/internal/service/auth_service.go index 3bf9da3d..e45d8d66 100644 --- a/backend/internal/service/auth_service.go +++ b/backend/internal/service/auth_service.go @@ -196,6 +196,12 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw grantPlan := s.resolveSignupGrantPlan(ctx, "email") + // 新用户默认 RPM(0 = 不限制)。注册时写入,后续作为用户级兜底。 + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } + // 创建用户 user := &User{ Email: email, @@ -203,6 +209,7 @@ func (s *AuthService) RegisterWithVerification(ctx context.Context, email, passw Role: RoleUser, Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, } @@ -481,6 +488,10 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username signupSource := inferLegacySignupSource(email) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } newUser := &User{ Email: email, @@ -489,6 +500,7 @@ func (s *AuthService) LoginOrRegisterOAuth(ctx context.Context, email, username Role: RoleUser, Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, SignupSource: 
signupSource, } @@ -592,6 +604,10 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema signupSource := inferLegacySignupSource(email) grantPlan := s.resolveSignupGrantPlan(ctx, signupSource) + var defaultRPMLimit int + if s.settingService != nil { + defaultRPMLimit = s.settingService.GetDefaultUserRPMLimit(ctx) + } newUser := &User{ Email: email, @@ -600,6 +616,7 @@ func (s *AuthService) LoginOrRegisterOAuthWithTokenPair(ctx context.Context, ema Role: RoleUser, Balance: grantPlan.Balance, Concurrency: grantPlan.Concurrency, + RPMLimit: defaultRPMLimit, Status: StatusActive, SignupSource: signupSource, } diff --git a/backend/internal/service/billing_cache_service.go b/backend/internal/service/billing_cache_service.go index f2ad0a3d..4e695eb9 100644 --- a/backend/internal/service/billing_cache_service.go +++ b/backend/internal/service/billing_cache_service.go @@ -20,6 +20,9 @@ import ( var ( ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired") ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. 
Please retry later.") + // RPM 超限错误。gateway_handler 负责映射为 HTTP 429。 + ErrGroupRPMExceeded = infraerrors.TooManyRequests("GROUP_RPM_EXCEEDED", "group requests-per-minute limit exceeded") + ErrUserRPMExceeded = infraerrors.TooManyRequests("USER_RPM_EXCEEDED", "user requests-per-minute limit exceeded") ) // subscriptionCacheData 订阅缓存数据结构(内部使用) @@ -87,6 +90,8 @@ type BillingCacheService struct { userRepo UserRepository subRepo UserSubscriptionRepository apiKeyRateLimitLoader apiKeyRateLimitLoader + userRPMCache UserRPMCache + userGroupRateRepo UserGroupRateRepository cfg *config.Config circuitBreaker *billingCircuitBreaker @@ -104,12 +109,22 @@ type BillingCacheService struct { } // NewBillingCacheService 创建计费缓存服务 -func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService { +func NewBillingCacheService( + cache BillingCache, + userRepo UserRepository, + subRepo UserSubscriptionRepository, + apiKeyRepo APIKeyRepository, + userRPMCache UserRPMCache, + userGroupRateRepo UserGroupRateRepository, + cfg *config.Config, +) *BillingCacheService { svc := &BillingCacheService{ cache: cache, userRepo: userRepo, subRepo: subRepo, apiKeyRateLimitLoader: apiKeyRepo, + userRPMCache: userRPMCache, + userGroupRateRepo: userGroupRateRepo, cfg: cfg, } svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker) @@ -664,6 +679,95 @@ func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user } } + // RPM 限流:级联回落(Override → Group → User),放在最后以避免为注定失败的请求增加计数。 + if err := s.checkRPM(ctx, user, group); err != nil { + return err + } + + return nil +} + +// checkRPM 执行并行 RPM 限流,所有适用的限制同时生效,任一超限即拒绝: +// +// 1. (用户, 分组) rpm_override — 最细粒度:管理员为特定用户在特定分组设定的专属限额。 +// override=0 表示该用户在该分组免检(绿灯),但 user 级全局上限仍然生效。 +// 2. group.rpm_limit — 分组级:该分组的统一 RPM 容量(仅当无 override 时生效)。 +// 3. 
user.rpm_limit — 用户级全局硬上限:无论 override/group 如何配置,始终生效。 +// +// 与旧版"级联互斥"设计不同,新版确保 user.rpm_limit 作为全局天花板不会被 group 或 override 覆盖。 +// Redis 故障一律 fail-open(打 warning,不阻塞业务)。 +func (s *BillingCacheService) checkRPM(ctx context.Context, user *User, group *Group) error { + if s == nil || s.userRPMCache == nil || user == nil { + return nil + } + + // ── 第一层:分组级检查(override 或 group.rpm_limit) ── + if group != nil { + // 解析 override:优先从 auth cache snapshot,nil 时回退 DB。 + var override *int + if user.UserGroupRPMOverride != nil { + override = user.UserGroupRPMOverride + } else if s.userGroupRateRepo != nil { + dbOverride, err := s.userGroupRateRepo.GetRPMOverrideByUserAndGroup(ctx, user.ID, group.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm override lookup failed for user=%d group=%d: %v", + user.ID, group.ID, err, + ) + } else { + override = dbOverride + } + } + + if override != nil { + // override=0 → 该用户在该分组免检(但 user 级仍会在下面检查)。 + if *override > 0 { + count, incErr := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) + if incErr != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (override) failed for user=%d group=%d: %v", + user.ID, group.ID, incErr, + ) + // fail-open + } else if count > *override { + return ErrGroupRPMExceeded + } + } + // override 命中后跳过 group.rpm_limit(override 替代 group),但不 return——继续检查 user 级。 + } else if group.RPMLimit > 0 { + // 无 override,检查 group.rpm_limit。 + count, err := s.userRPMCache.IncrementUserGroupRPM(ctx, user.ID, group.ID) + if err != nil { + logger.LegacyPrintf( + "service.billing_cache", + "Warning: rpm increment (group) failed for user=%d group=%d: %v", + user.ID, group.ID, err, + ) + // fail-open + } else if count > group.RPMLimit { + return ErrGroupRPMExceeded + } + } + } + + // ── 第二层:用户级全局硬上限(始终生效) ── + if user.RPMLimit > 0 { + count, err := s.userRPMCache.IncrementUserRPM(ctx, user.ID) + if err != nil { + logger.LegacyPrintf( + 
"service.billing_cache", + "Warning: rpm increment (user) failed for user=%d: %v", + user.ID, err, + ) + return nil // fail-open + } + if count > user.RPMLimit { + return ErrUserRPMExceeded + } + } + return nil } diff --git a/backend/internal/service/billing_cache_service_rpm_test.go b/backend/internal/service/billing_cache_service_rpm_test.go new file mode 100644 index 00000000..de66136f --- /dev/null +++ b/backend/internal/service/billing_cache_service_rpm_test.go @@ -0,0 +1,253 @@ +//go:build unit + +package service + +import ( + "context" + "errors" + "sync/atomic" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +// userRPMCacheStub 记录每种计数器被调用的次数,并可注入返回值与错误。 +type userRPMCacheStub struct { + userGroupCalls int32 + userCalls int32 + + userGroupCounts []int // 依次返回的计数值 + userGroupErr error + userCounts []int + userErr error +} + +func (s *userRPMCacheStub) IncrementUserGroupRPM(_ context.Context, _, _ int64) (int, error) { + idx := int(atomic.AddInt32(&s.userGroupCalls, 1)) - 1 + if s.userGroupErr != nil { + return 0, s.userGroupErr + } + if idx < len(s.userGroupCounts) { + return s.userGroupCounts[idx], nil + } + return 1, nil +} + +func (s *userRPMCacheStub) IncrementUserRPM(_ context.Context, _ int64) (int, error) { + idx := int(atomic.AddInt32(&s.userCalls, 1)) - 1 + if s.userErr != nil { + return 0, s.userErr + } + if idx < len(s.userCounts) { + return s.userCounts[idx], nil + } + return 1, nil +} + +func (s *userRPMCacheStub) GetUserGroupRPM(_ context.Context, _, _ int64) (int, error) { + return 0, nil +} + +func (s *userRPMCacheStub) GetUserRPM(_ context.Context, _ int64) (int, error) { + return 0, nil +} + +// rpmOverrideRepoStub 专用于 checkRPM 分支测试,只实现必要方法。 +type rpmOverrideRepoStub struct { + UserGroupRateRepository + + override *int + err error + calls int32 +} + +func (s *rpmOverrideRepoStub) GetRPMOverrideByUserAndGroup(_ context.Context, _, _ int64) (*int, error) { + atomic.AddInt32(&s.calls, 
1) + if s.err != nil { + return nil, s.err + } + return s.override, nil +} + +func newBillingServiceForRPM(t *testing.T, cache UserRPMCache, rateRepo UserGroupRateRepository) *BillingCacheService { + t.Helper() + // 用 nil BillingCache 走 "无缓存" 分支,避免 CheckBillingEligibility 副作用。 + // 我们只直接测 checkRPM。 + svc := NewBillingCacheService(nil, nil, nil, nil, cache, rateRepo, &config.Config{}) + t.Cleanup(svc.Stop) + return svc +} + +func TestBillingCacheService_CheckRPM_OverrideTakesPrecedenceOverGroup(t *testing.T) { + override := 2 + // user-group 计数: 1, 2, 3;user 计数: 默认返回 1(远小于 RPMLimit=100,不干扰) + cache := &userRPMCacheStub{userGroupCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: &override} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 100} // 全局上限设高,不干扰 override 测试 + group := &Group{ID: 10, RPMLimit: 100} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) + + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userGroupCalls), "override 命中分支应走 user-group 计数") + // 并行设计:前 2 次 override 未超→继续检查 user;第 3 次 override 超了→直接 return,不检查 user + require.EqualValues(t, 2, atomic.LoadInt32(&cache.userCalls), "override 超限前 user 计数器应被调用") + require.EqualValues(t, 3, atomic.LoadInt32(&repo.calls)) +} + +func TestBillingCacheService_CheckRPM_UserLimitIsGlobalHardCap(t *testing.T) { + override := 100 // override 很高 + // user-group 计数: 默认返回 1(远小于 override);user 计数: 1, 2, 3 + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: &override} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} // 全局硬上限=2,应覆盖 override=100 + group := &Group{ID: 10, RPMLimit: 100} + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), 
user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, "user 全局硬上限应优先于 override") +} + +func TestBillingCacheService_CheckRPM_OverrideZeroSkipsGroupButUserStillApplies(t *testing.T) { + zero := 0 + // user 计数: 依次返回 1..6 + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3, 4, 5, 6}} + repo := &rpmOverrideRepoStub{override: &zero} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 5} + group := &Group{ID: 10, RPMLimit: 100} + + // override=0 跳过分组计数,但 user.RPMLimit=5 仍生效 + for i := 0; i < 5; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group), "request %d should pass", i+1) + } + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded, + "override=0 跳过分组但 user 全局上限仍应生效") + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不应触发分组计数器") + require.EqualValues(t, 6, atomic.LoadInt32(&cache.userCalls), "user 计数器应被调用") +} + +func TestBillingCacheService_CheckRPM_OverrideZeroAndUserZeroIsFullyUnlimited(t *testing.T) { + zero := 0 + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{override: &zero} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} // user 也不限 + group := &Group{ID: 10, RPMLimit: 100} + + for i := 0; i < 50; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + } + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "override=0 不触发分组计数") + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls), "user.RPMLimit=0 也不触发用户计数") +} + +func TestBillingCacheService_CheckRPM_NilOverrideFallsThroughToGroup(t *testing.T) { + // user-group 计数: 5, 6;user 计数: 默认 1(不干扰) + cache := &userRPMCacheStub{userGroupCounts: []int{5, 6}} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 999} // 全局上限很高,group 先超 + group := &Group{ID: 10, RPMLimit: 5} + + 
require.NoError(t, svc.checkRPM(context.Background(), user, group)) // ug=5, user=1, 都没超 + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrGroupRPMExceeded) // ug=6 > 5 + + require.EqualValues(t, 2, atomic.LoadInt32(&cache.userGroupCalls)) + // 并行模式:第 1 次 group 没超 → 继续检查 user;第 2 次 group 超了 → 直接 return,不检查 user + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userCalls), "group 未超时 user 也应检查;group 超时直接返回") +} + +func TestBillingCacheService_CheckRPM_OverrideLookupErrorFallsThroughToGroup(t *testing.T) { + cache := &userRPMCacheStub{userGroupCounts: []int{3}} + repo := &rpmOverrideRepoStub{err: errors.New("db down")} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 10} + + // override 查询失败后应继续尝试 group 分支(不直接拒绝) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 1, atomic.LoadInt32(&repo.calls)) +} + +func TestBillingCacheService_CheckRPM_UserLevelFallbackWhenGroupUnlimited(t *testing.T) { + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} + group := &Group{ID: 10, RPMLimit: 0} // 分组未设限 + + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, group), ErrUserRPMExceeded) + + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls), "group 未设限时不应 INCR user-group 键") + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_NoLimitsConfiguredIsNoop(t *testing.T) { + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 
0} + group := &Group{ID: 10, RPMLimit: 0} + + for i := 0; i < 10; i++ { + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + } + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_RedisErrorFailOpen(t *testing.T) { + cache := &userRPMCacheStub{userGroupErr: errors.New("redis unavailable")} + repo := &rpmOverrideRepoStub{override: nil} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 0} + group := &Group{ID: 10, RPMLimit: 5} + + // Redis 故障时应 fail-open,不拒绝请求 + require.NoError(t, svc.checkRPM(context.Background(), user, group)) + require.EqualValues(t, 1, atomic.LoadInt32(&cache.userGroupCalls)) +} + +func TestBillingCacheService_CheckRPM_NoGroupUsesUserOnly(t *testing.T) { + cache := &userRPMCacheStub{userCounts: []int{1, 2, 3}} + repo := &rpmOverrideRepoStub{} + svc := newBillingServiceForRPM(t, cache, repo) + + user := &User{ID: 1, RPMLimit: 2} + + // 无 group(纯用户级限流场景),不应查询 rpm_override。 + require.NoError(t, svc.checkRPM(context.Background(), user, nil)) + require.NoError(t, svc.checkRPM(context.Background(), user, nil)) + require.ErrorIs(t, svc.checkRPM(context.Background(), user, nil), ErrUserRPMExceeded) + + require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls), "无 group 时不应查询 rpm_override") + require.EqualValues(t, 3, atomic.LoadInt32(&cache.userCalls)) +} + +func TestBillingCacheService_CheckRPM_NilUserIsNoop(t *testing.T) { + cache := &userRPMCacheStub{} + repo := &rpmOverrideRepoStub{} + svc := newBillingServiceForRPM(t, cache, repo) + + require.NoError(t, svc.checkRPM(context.Background(), nil, &Group{ID: 1, RPMLimit: 10})) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userGroupCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&cache.userCalls)) + require.EqualValues(t, 0, atomic.LoadInt32(&repo.calls)) +} diff --git 
a/backend/internal/service/billing_cache_service_singleflight_test.go b/backend/internal/service/billing_cache_service_singleflight_test.go index 0eaf4570..962becf0 100644 --- a/backend/internal/service/billing_cache_service_singleflight_test.go +++ b/backend/internal/service/billing_cache_service_singleflight_test.go @@ -100,7 +100,7 @@ func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) { delay: 80 * time.Millisecond, balance: 12.34, } - svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, userRepo, nil, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) const goroutines = 16 diff --git a/backend/internal/service/billing_cache_service_test.go b/backend/internal/service/billing_cache_service_test.go index 7d7045e2..849e24b8 100644 --- a/backend/internal/service/billing_cache_service_test.go +++ b/backend/internal/service/billing_cache_service_test.go @@ -70,7 +70,7 @@ func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, func TestBillingCacheServiceQueueHighLoad(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) t.Cleanup(svc.Stop) start := time.Now() @@ -92,7 +92,7 @@ func TestBillingCacheServiceQueueHighLoad(t *testing.T) { func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) { cache := &billingCacheWorkerStub{} - svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{}) + svc := NewBillingCacheService(cache, nil, nil, nil, nil, nil, &config.Config{}) svc.Stop() enqueued := svc.enqueueCacheWrite(cacheWriteTask{ diff --git a/backend/internal/service/billing_service.go b/backend/internal/service/billing_service.go index a45203a3..392b3e0b 100644 --- a/backend/internal/service/billing_service.go +++ b/backend/internal/service/billing_service.go @@ -217,6 +217,9 @@ func (s 
*BillingService) initFallbackPricing() { LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier, LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier, } + // GPT-5.5 暂无独立定价,回退到 GPT-5.4 + s.fallbackPrices["gpt-5.5"] = s.fallbackPrices["gpt-5.4"] + s.fallbackPrices["gpt-5.4-mini"] = &ModelPricing{ InputPricePerToken: 7.5e-7, OutputPricePerToken: 4.5e-6, @@ -288,6 +291,8 @@ func (s *BillingService) getFallbackPricing(model string) *ModelPricing { if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") { normalized := normalizeCodexModel(modelLower) switch normalized { + case "gpt-5.5": + return s.fallbackPrices["gpt-5.5"] case "gpt-5.4-mini": return s.fallbackPrices["gpt-5.4-mini"] case "gpt-5.4": @@ -637,7 +642,8 @@ func isOpenAIGPT54Model(model string) bool { if !strings.Contains(trimmed, "gpt-5") && !strings.Contains(trimmed, "codex") { return false } - return normalizeCodexModel(trimmed) == "gpt-5.4" + normalized := normalizeCodexModel(trimmed) + return normalized == "gpt-5.4" || normalized == "gpt-5.5" } // CalculateCostWithConfig 使用配置中的默认倍率计算费用 diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index f31255b6..cf47b76f 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -170,9 +170,10 @@ const ( SettingKeyCustomEndpoints = "custom_endpoints" // 自定义端点列表(JSON 数组) // 默认配置 - SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 - SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 - SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 + SettingKeyDefaultBalance = "default_balance" // 新用户默认余额 + SettingKeyDefaultSubscriptions = "default_subscriptions" // 新用户默认订阅列表(JSON) + SettingKeyDefaultUserRPMLimit = "default_user_rpm_limit" // 新用户默认 RPM 限制(0 = 不限制) // 第三方认证来源默认授予配置 
SettingKeyAuthSourceDefaultEmailBalance = "auth_source_default_email_balance" diff --git a/backend/internal/service/group.go b/backend/internal/service/group.go index 64434ae1..bb4c5aa1 100644 --- a/backend/internal/service/group.go +++ b/backend/internal/service/group.go @@ -59,6 +59,10 @@ type Group struct { DefaultMappedModel string MessagesDispatchModelConfig OpenAIMessagesDispatchModelConfig + // RPMLimit 分组级每分钟请求数上限(0 = 不限制)。 + // 一旦设置即接管该分组用户的限流(覆盖用户级 rpm_limit),可被 user-group rpm_override 进一步覆盖。 + RPMLimit int + CreatedAt time.Time UpdatedAt time.Time diff --git a/backend/internal/service/openai_403_counter.go b/backend/internal/service/openai_403_counter.go new file mode 100644 index 00000000..5ba3e195 --- /dev/null +++ b/backend/internal/service/openai_403_counter.go @@ -0,0 +1,11 @@ +package service + +import "context" + +// OpenAI403CounterCache 追踪 OpenAI 账号连续 403 失败次数。 +type OpenAI403CounterCache interface { + // IncrementOpenAI403Count 原子递增 403 计数并返回当前值。 + IncrementOpenAI403Count(ctx context.Context, accountID int64, windowMinutes int) (int64, error) + // ResetOpenAI403Count 成功后清零计数器。 + ResetOpenAI403Count(ctx context.Context, accountID int64) error +} diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index a68c9b67..560db436 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -6,6 +6,7 @@ import ( ) var codexModelMap = map[string]string{ + "gpt-5.5": "gpt-5.5", "gpt-5.4": "gpt-5.4", "gpt-5.4-mini": "gpt-5.4-mini", "gpt-5.4-none": "gpt-5.4", @@ -207,6 +208,9 @@ func normalizeCodexModel(model string) string { normalized := strings.ToLower(modelID) + if strings.Contains(normalized, "gpt-5.5") || strings.Contains(normalized, "gpt 5.5") { + return "gpt-5.5" + } if strings.Contains(normalized, "gpt-5.4-mini") || strings.Contains(normalized, "gpt 5.4 mini") { return "gpt-5.4-mini" } diff --git 
a/backend/internal/service/openai_gateway_403_reset_test.go b/backend/internal/service/openai_gateway_403_reset_test.go new file mode 100644 index 00000000..c6805464 --- /dev/null +++ b/backend/internal/service/openai_gateway_403_reset_test.go @@ -0,0 +1,39 @@ +package service + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +type openAI403CounterResetStub struct { + resetCalls []int64 +} + +func (s *openAI403CounterResetStub) IncrementOpenAI403Count(context.Context, int64, int) (int64, error) { + return 0, nil +} + +func (s *openAI403CounterResetStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { + s.resetCalls = append(s.resetCalls, accountID) + return nil +} + +func TestOpenAIGatewayServiceRecordUsage_ResetsOpenAI403CounterBeforeZeroUsageReturn(t *testing.T) { + counter := &openAI403CounterResetStub{} + rateLimitSvc := NewRateLimitService(nil, nil, nil, nil, nil) + rateLimitSvc.SetOpenAI403CounterCache(counter) + + svc := &OpenAIGatewayService{ + rateLimitService: rateLimitSvc, + } + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{}, + Account: &Account{ID: 777, Platform: PlatformOpenAI}, + }) + + require.NoError(t, err) + require.Equal(t, []int64{777}, counter.resetCalls) +} diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index 95e1bffa..9665c4c8 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -1098,3 +1098,50 @@ func TestOpenAIGatewayServiceRecordUsage_ImageOnlyUsageStillPersists(t *testing. 
require.NotNil(t, usageRepo.lastLog.BillingMode) require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) } + +func TestOpenAIGatewayServiceRecordUsage_ImageUsesPerImageBillingEvenWithUsageTokens(t *testing.T) { + imagePrice := 0.02 + groupID := int64(12) + + usageRepo := &openAIRecordUsageLogRepoStub{inserted: true} + userRepo := &openAIRecordUsageUserRepoStub{} + subRepo := &openAIRecordUsageSubRepoStub{} + svc := newOpenAIRecordUsageServiceForTest(usageRepo, userRepo, subRepo, nil) + + err := svc.RecordUsage(context.Background(), &OpenAIRecordUsageInput{ + Result: &OpenAIForwardResult{ + RequestID: "resp_image_per_request", + Model: "gpt-image-2", + Usage: OpenAIUsage{ + InputTokens: 1110, + OutputTokens: 1756, + ImageOutputTokens: 1756, + }, + ImageCount: 2, + ImageSize: "1K", + Duration: time.Second, + }, + APIKey: &APIKey{ + ID: 1008, + GroupID: i64p(groupID), + Group: &Group{ + ID: groupID, + RateMultiplier: 1.0, + ImagePrice1K: &imagePrice, + }, + }, + User: &User{ID: 2008}, + Account: &Account{ID: 3008}, + }) + + require.NoError(t, err) + require.NotNil(t, usageRepo.lastLog) + require.NotNil(t, usageRepo.lastLog.BillingMode) + require.Equal(t, string(BillingModeImage), *usageRepo.lastLog.BillingMode) + require.Equal(t, 2, usageRepo.lastLog.ImageCount) + require.InDelta(t, 0.04, usageRepo.lastLog.TotalCost, 1e-12) + require.InDelta(t, 0.04, usageRepo.lastLog.ActualCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.InputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.OutputCost, 1e-12) + require.InDelta(t, 0.0, usageRepo.lastLog.ImageOutputCost, 1e-12) +} diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 534ffeee..06fd14af 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4425,6 +4425,9 @@ type OpenAIRecordUsageInput struct { // RecordUsage records usage and deducts balance func 
(s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error { result := input.Result + if s.rateLimitService != nil && input != nil && input.Account != nil && input.Account.Platform == PlatformOpenAI { + s.rateLimitService.ResetOpenAI403Counter(ctx, input.Account.ID) + } // 跳过所有 token 均为零的用量记录——上游未返回 usage 时不应写入数据库 if result.Usage.InputTokens == 0 && result.Usage.OutputTokens == 0 && @@ -4622,12 +4625,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( serviceTier string, ) (*CostBreakdown, error) { if result != nil && result.ImageCount > 0 { - if hasOpenAIImageUsageTokens(result) { - cost, err := s.calculateOpenAIImageTokenCost(ctx, apiKey, billingModel, multiplier, tokens, serviceTier, result.ImageSize) - if err == nil { - return cost, nil - } - } return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, multiplier), nil } if s.resolver != nil && apiKey.Group != nil { @@ -4646,32 +4643,6 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost( return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) } -func (s *OpenAIGatewayService) calculateOpenAIImageTokenCost( - ctx context.Context, - apiKey *APIKey, - billingModel string, - multiplier float64, - tokens UsageTokens, - serviceTier string, - sizeTier string, -) (*CostBreakdown, error) { - if s.resolver != nil && apiKey.Group != nil { - gid := apiKey.Group.ID - return s.billingService.CalculateCostUnified(CostInput{ - Ctx: ctx, - Model: billingModel, - GroupID: &gid, - Tokens: tokens, - RequestCount: 1, - SizeTier: sizeTier, - RateMultiplier: multiplier, - ServiceTier: serviceTier, - Resolver: s.resolver, - }) - } - return s.billingService.CalculateCostWithServiceTier(billingModel, tokens, multiplier, serviceTier) -} - func (s *OpenAIGatewayService) calculateOpenAIImageCost( ctx context.Context, billingModel string, @@ -4679,7 +4650,8 @@ func (s *OpenAIGatewayService) calculateOpenAIImageCost( result 
*OpenAIForwardResult, multiplier float64, ) *CostBreakdown { - if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil { + if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved != nil && + (resolved.Mode == BillingModePerRequest || resolved.Mode == BillingModeImage) { gid := apiKey.Group.ID cost, err := s.billingService.CalculateCostUnified(CostInput{ Ctx: ctx, @@ -4720,17 +4692,6 @@ func (s *OpenAIGatewayService) resolveOpenAIChannelPricing(ctx context.Context, return nil } -func hasOpenAIImageUsageTokens(result *OpenAIForwardResult) bool { - if result == nil { - return false - } - return result.Usage.InputTokens > 0 || - result.Usage.OutputTokens > 0 || - result.Usage.CacheCreationInputTokens > 0 || - result.Usage.CacheReadInputTokens > 0 || - result.Usage.ImageOutputTokens > 0 -} - // ParseCodexRateLimitHeaders extracts Codex usage limits from response headers. // Exported for use in ratelimit_service when handling OpenAI 429 responses. 
func ParseCodexRateLimitHeaders(headers http.Header) *OpenAICodexUsageSnapshot { diff --git a/backend/internal/service/openai_images.go b/backend/internal/service/openai_images.go index 7935376b..4badcb1c 100644 --- a/backend/internal/service/openai_images.go +++ b/backend/internal/service/openai_images.go @@ -5,27 +5,22 @@ import ( "bytes" "context" "crypto/sha256" - "crypto/sha3" "encoding/base64" "encoding/hex" "encoding/json" - "errors" "fmt" "io" "mime" "mime/multipart" "net/http" "net/textproto" - "sort" "strconv" "strings" "time" "github.com/Wei-Shaw/sub2api/internal/pkg/logger" - "github.com/Wei-Shaw/sub2api/internal/pkg/proxyurl" "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" "github.com/gin-gonic/gin" - "github.com/google/uuid" "github.com/imroc/req/v3" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -38,18 +33,12 @@ const ( openAIImagesGenerationsURL = "https://api.openai.com/v1/images/generations" openAIImagesEditsURL = "https://api.openai.com/v1/images/edits" - openAIChatGPTStartURL = "https://chatgpt.com/" - openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files" - openAIChatGPTConversationInitURL = "https://chatgpt.com/backend-api/conversation/init" - openAIChatGPTConversationURL = "https://chatgpt.com/backend-api/f/conversation" - openAIChatGPTConversationPrepareURL = "https://chatgpt.com/backend-api/f/conversation/prepare" - openAIChatGPTChatRequirementsURL = "https://chatgpt.com/backend-api/sentinel/chat-requirements" - - openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" - openAIImageRequirementsDiff = "0fffff" - openAIImageLifecycleTimeout = 2 * time.Minute - openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download - openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part + openAIChatGPTStartURL = "https://chatgpt.com/" + openAIChatGPTFilesURL = "https://chatgpt.com/backend-api/files" + 
openAIImageBackendUserAgent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36" + openAIImageMaxDownloadBytes = 20 << 20 // 20MB per image download + openAIImageMaxUploadPartSize = 20 << 20 // 20MB per multipart upload part + openAIImagesResponsesMainModel = "gpt-5.4-mini" ) type OpenAIImagesCapability string @@ -81,10 +70,21 @@ type OpenAIImagesRequest struct { ExplicitSize bool SizeTier string ResponseFormat string + Quality string + Background string + OutputFormat string + Moderation string + InputFidelity string + Style string + OutputCompression *int + PartialImages *int HasMask bool HasNativeOptions bool RequiredCapability OpenAIImagesCapability + InputImageURLs []string + MaskImageURL string Uploads []OpenAIImagesUpload + MaskUpload *OpenAIImagesUpload Body []byte bodyHash string } @@ -188,7 +188,54 @@ func parseOpenAIImagesJSONRequest(body []byte, req *OpenAIImagesRequest) error { req.ExplicitSize = req.Size != "" } req.ResponseFormat = strings.ToLower(strings.TrimSpace(gjson.GetBytes(body, "response_format").String())) + req.Quality = strings.TrimSpace(gjson.GetBytes(body, "quality").String()) + req.Background = strings.TrimSpace(gjson.GetBytes(body, "background").String()) + req.OutputFormat = strings.TrimSpace(gjson.GetBytes(body, "output_format").String()) + req.Moderation = strings.TrimSpace(gjson.GetBytes(body, "moderation").String()) + req.InputFidelity = strings.TrimSpace(gjson.GetBytes(body, "input_fidelity").String()) + req.Style = strings.TrimSpace(gjson.GetBytes(body, "style").String()) req.HasMask = gjson.GetBytes(body, "mask").Exists() + if outputCompression := gjson.GetBytes(body, "output_compression"); outputCompression.Exists() { + if outputCompression.Type != gjson.Number { + return fmt.Errorf("invalid output_compression field type") + } + v := int(outputCompression.Int()) + req.OutputCompression = &v + } + if partialImages := gjson.GetBytes(body, "partial_images"); 
partialImages.Exists() { + if partialImages.Type != gjson.Number { + return fmt.Errorf("invalid partial_images field type") + } + v := int(partialImages.Int()) + req.PartialImages = &v + } + if req.IsEdits() { + images := gjson.GetBytes(body, "images") + if images.Exists() { + if !images.IsArray() { + return fmt.Errorf("invalid images field type") + } + for _, item := range images.Array() { + if imageURL := strings.TrimSpace(item.Get("image_url").String()); imageURL != "" { + req.InputImageURLs = append(req.InputImageURLs, imageURL) + continue + } + if item.Get("file_id").Exists() { + return fmt.Errorf("images[].file_id is not supported (use images[].image_url instead)") + } + } + } + if maskImageURL := strings.TrimSpace(gjson.GetBytes(body, "mask.image_url").String()); maskImageURL != "" { + req.MaskImageURL = maskImageURL + req.HasMask = true + } + if gjson.GetBytes(body, "mask.file_id").Exists() { + return fmt.Errorf("mask.file_id is not supported (use mask.image_url instead)") + } + if len(req.InputImageURLs) == 0 { + return fmt.Errorf("images[].image_url is required") + } + } req.HasNativeOptions = hasOpenAINativeImageOptions(func(path string) bool { return gjson.GetBytes(body, path).Exists() }) @@ -231,6 +278,16 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope partContentType := strings.TrimSpace(part.Header.Get("Content-Type")) if name == "mask" && len(data) > 0 { req.HasMask = true + width, height := parseOpenAIImageDimensions(part.Header) + maskUpload := OpenAIImagesUpload{ + FieldName: name, + FileName: fileName, + ContentType: partContentType, + Data: data, + Width: width, + Height: height, + } + req.MaskUpload = &maskUpload } if name == "image" || strings.HasPrefix(name, "image[") { width, height := parseOpenAIImageDimensions(part.Header) @@ -270,6 +327,38 @@ func parseOpenAIImagesMultipartRequest(body []byte, contentType string, req *Ope return fmt.Errorf("n must be a positive integer") } req.N = n + case "quality": + 
req.Quality = value + req.HasNativeOptions = true + case "background": + req.Background = value + req.HasNativeOptions = true + case "output_format": + req.OutputFormat = value + req.HasNativeOptions = true + case "moderation": + req.Moderation = value + req.HasNativeOptions = true + case "input_fidelity": + req.InputFidelity = value + req.HasNativeOptions = true + case "style": + req.Style = value + req.HasNativeOptions = true + case "output_compression": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid output_compression field value") + } + req.OutputCompression = &n + req.HasNativeOptions = true + case "partial_images": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid partial_images field value") + } + req.PartialImages = &n + req.HasNativeOptions = true default: if isOpenAINativeImageOption(name) && value != "" { req.HasNativeOptions = true @@ -359,6 +448,8 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool { "output_format", "output_compression", "moderation", + "input_fidelity", + "partial_images", } { if exists(path) { return true @@ -369,7 +460,7 @@ func hasOpenAINativeImageOptions(exists func(path string) bool) bool { func isOpenAINativeImageOption(name string) bool { switch strings.TrimSpace(strings.ToLower(name)) { - case "background", "quality", "style", "output_format", "output_compression", "moderation": + case "background", "quality", "style", "output_format", "output_compression", "moderation", "input_fidelity", "partial_images": return true default: return false @@ -782,563 +873,6 @@ func extractOpenAIImageCountFromJSONBytes(body []byte) int { return 0 } -func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( - ctx context.Context, - c *gin.Context, - account *Account, - parsed *OpenAIImagesRequest, - channelMappedModel string, -) (*OpenAIForwardResult, error) { - startTime := time.Now() - requestModel := strings.TrimSpace(parsed.Model) - if mapped := 
strings.TrimSpace(channelMappedModel); mapped != "" { - requestModel = mapped - } - if err := validateOpenAIImagesModel(requestModel); err != nil { - return nil, err - } - logger.LegacyPrintf( - "service.openai_gateway", - "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d", - requestModel, - parsed.Endpoint, - account.Type, - len(parsed.Uploads), - ) - - token, _, err := s.GetAccessToken(ctx, account) - if err != nil { - return nil, err - } - client, err := newOpenAIBackendAPIClient(resolveOpenAIProxyURL(account)) - if err != nil { - return nil, err - } - headers, err := s.buildOpenAIBackendAPIHeaders(account, token) - if err != nil { - return nil, err - } - if bootstrapErr := bootstrapOpenAIBackendAPI(ctx, client, headers); bootstrapErr != nil { - logger.LegacyPrintf("service.openai_gateway", "OpenAI image bootstrap failed: %v", bootstrapErr) - } - - chatReqs, err := fetchOpenAIChatRequirements(ctx, client, headers) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - if chatReqs.Arkose.Required { - return nil, s.wrapOpenAIImageBackendError( - ctx, - c, - account, - newOpenAIImageSyntheticStatusError( - http.StatusForbidden, - "chat-requirements requires unsupported challenge (arkose)", - openAIChatGPTChatRequirementsURL, - ), - ) - } - - parentMessageID := uuid.NewString() - proofToken := generateOpenAIProofToken(chatReqs.ProofOfWork.Required, chatReqs.ProofOfWork.Seed, chatReqs.ProofOfWork.Difficulty, headers.Get("User-Agent")) - _ = initializeOpenAIImageConversation(ctx, client, headers) - conduitToken, err := prepareOpenAIImageConversation(ctx, client, headers, parsed.Prompt, parentMessageID, chatReqs.Token, proofToken) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - - uploads, err := uploadOpenAIImageFiles(ctx, client, headers, parsed.Uploads) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - - convReq 
:= buildOpenAIImageConversationRequest(parsed, parentMessageID, uploads) - if parsedContent, err := json.Marshal(convReq); err == nil { - setOpsUpstreamRequestBody(c, parsedContent) - } - convHeaders := cloneHTTPHeader(headers) - convHeaders.Set("Accept", "text/event-stream") - convHeaders.Set("Content-Type", "application/json") - convHeaders.Set("openai-sentinel-chat-requirements-token", chatReqs.Token) - if conduitToken != "" { - convHeaders.Set("x-conduit-token", conduitToken) - } - if proofToken != "" { - convHeaders.Set("openai-sentinel-proof-token", proofToken) - } - - resp, err := client.R(). - SetContext(ctx). - DisableAutoReadResponse(). - SetHeaders(headerToMap(convHeaders)). - SetBodyJsonMarshal(convReq). - Post(openAIChatGPTConversationURL) - if err != nil { - return nil, fmt.Errorf("openai image conversation request failed: %w", err) - } - defer func() { - if resp != nil && resp.Body != nil { - _ = resp.Body.Close() - } - }() - if resp.StatusCode >= 400 { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, handleOpenAIImageBackendError(resp)) - } - - conversationID, pointerInfos, usage, firstTokenMs, err := readOpenAIImageConversationStream(resp, startTime) - if err != nil { - return nil, err - } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, nil) - logger.LegacyPrintf( - "service.openai_gateway", - "[OpenAI] Image extraction stream conversation_id=%s total_assets=%d file_service_assets=%d direct_assets=%d", - conversationID, - len(pointerInfos), - countOpenAIFileServicePointerInfos(pointerInfos), - countOpenAIDirectImageAssets(pointerInfos), - ) - lifecycleCtx, releaseLifecycleCtx := detachOpenAIImageLifecycleContext(ctx, openAIImageLifecycleTimeout) - defer releaseLifecycleCtx() - if conversationID != "" && !hasOpenAIFileServicePointerInfos(pointerInfos) { - polledPointers, pollErr := pollOpenAIImageConversation(lifecycleCtx, client, headers, conversationID) - if pollErr != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, 
c, account, pollErr) - } - pointerInfos = mergeOpenAIImagePointerInfos(pointerInfos, polledPointers) - } - pointerInfos = preferOpenAIFileServicePointerInfos(pointerInfos) - if len(pointerInfos) == 0 { - logger.LegacyPrintf("service.openai_gateway", "[OpenAI] Image extraction yielded no assets conversation_id=%s", conversationID) - return nil, fmt.Errorf("openai image conversation returned no downloadable images") - } - - responseBody, imageCount, err := buildOpenAIImageResponse(lifecycleCtx, client, headers, conversationID, pointerInfos) - if err != nil { - return nil, s.wrapOpenAIImageBackendError(ctx, c, account, err) - } - - c.Data(http.StatusOK, "application/json; charset=utf-8", responseBody) - return &OpenAIForwardResult{ - RequestID: resp.Header.Get("x-request-id"), - Usage: usage, - Model: requestModel, - UpstreamModel: requestModel, - Stream: false, - Duration: time.Since(startTime), - FirstTokenMs: firstTokenMs, - ImageCount: imageCount, - ImageSize: parsed.SizeTier, - }, nil -} - -func resolveOpenAIProxyURL(account *Account) string { - if account != nil && account.ProxyID != nil && account.Proxy != nil { - return account.Proxy.URL() - } - return "" -} - -func newOpenAIBackendAPIClient(proxyURL string) (*req.Client, error) { - client := req.C(). - SetTimeout(180 * time.Second). 
- ImpersonateChrome() - trimmed, _, err := proxyurl.Parse(proxyURL) - if err != nil { - return nil, err - } - if trimmed != "" { - client.SetProxyURL(trimmed) - } - return client, nil -} - -func (s *OpenAIGatewayService) buildOpenAIBackendAPIHeaders(account *Account, token string) (http.Header, error) { - deviceID, sessionID := s.ensureOpenAIImageSessionCredentials(context.Background(), account) - headers := make(http.Header) - headers.Set("Authorization", "Bearer "+token) - headers.Set("Accept", "application/json") - headers.Set("Origin", "https://chatgpt.com") - headers.Set("Referer", "https://chatgpt.com/") - headers.Set("Sec-Fetch-Dest", "empty") - headers.Set("Sec-Fetch-Mode", "cors") - headers.Set("Sec-Fetch-Site", "same-origin") - headers.Set("User-Agent", openAIImageBackendUserAgent) - if customUA := strings.TrimSpace(account.GetOpenAIUserAgent()); customUA != "" { - headers.Set("User-Agent", customUA) - } - if chatgptAccountID := strings.TrimSpace(account.GetChatGPTAccountID()); chatgptAccountID != "" { - headers.Set("chatgpt-account-id", chatgptAccountID) - } - if deviceID != "" { - headers.Set("oai-device-id", deviceID) - headers.Set("Cookie", "oai-did="+deviceID) - } - if sessionID != "" { - headers.Set("oai-session-id", sessionID) - } - return headers, nil -} - -func (s *OpenAIGatewayService) ensureOpenAIImageSessionCredentials(ctx context.Context, account *Account) (string, string) { - if account == nil { - return "", "" - } - deviceID := account.GetOpenAIDeviceID() - sessionID := account.GetOpenAISessionID() - if deviceID != "" && sessionID != "" { - return deviceID, sessionID - } - - updates := map[string]any{} - if deviceID == "" { - deviceID = uuid.NewString() - updates["openai_device_id"] = deviceID - } - if sessionID == "" { - sessionID = uuid.NewString() - updates["openai_session_id"] = sessionID - } - if account.Extra == nil { - account.Extra = map[string]any{} - } - for key, value := range updates { - account.Extra[key] = value - } - if 
len(updates) == 0 || s == nil || s.accountRepo == nil { - return deviceID, sessionID - } - - updateCtx, cancel := context.WithTimeout(ctx, 5*time.Second) - defer cancel() - if err := s.accountRepo.UpdateExtra(updateCtx, account.ID, updates); err != nil { - logger.LegacyPrintf("service.openai_gateway", "persist openai image session creds failed: account=%d err=%v", account.ID, err) - } - return deviceID, sessionID -} - -func bootstrapOpenAIBackendAPI(ctx context.Context, client *req.Client, headers http.Header) error { - resp, err := client.R(). - SetContext(ctx). - DisableAutoReadResponse(). - SetHeaders(headerToMap(headers)). - Get(openAIChatGPTStartURL) - if err != nil { - return err - } - if resp != nil && resp.Body != nil { - _, _ = io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() - } - return nil -} - -func initializeOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header) error { - payload := map[string]any{ - "gizmo_id": nil, - "requested_default_model": nil, - "conversation_id": nil, - "timezone_offset_min": openAITimezoneOffsetMinutes(), - "system_hints": []string{"picture_v2"}, - } - resp, err := client.R(). - SetContext(ctx). - SetHeaders(headerToMap(headers)). - SetBodyJsonMarshal(payload). 
- Post(openAIChatGPTConversationInitURL) - if err != nil { - return err - } - if !resp.IsSuccessState() { - return newOpenAIImageStatusError(resp, "conversation init failed") - } - return nil -} - -type openAIChatRequirements struct { - Token string `json:"token"` - Turnstile struct { - Required bool `json:"required"` - } `json:"turnstile"` - Arkose struct { - Required bool `json:"required"` - } `json:"arkose"` - ProofOfWork struct { - Required bool `json:"required"` - Seed string `json:"seed"` - Difficulty string `json:"difficulty"` - } `json:"proofofwork"` -} - -func fetchOpenAIChatRequirements(ctx context.Context, client *req.Client, headers http.Header) (*openAIChatRequirements, error) { - var lastErr error - for _, payload := range []map[string]any{ - {"p": nil}, - {"p": generateOpenAIRequirementsToken(headers.Get("User-Agent"))}, - } { - var result openAIChatRequirements - resp, err := client.R(). - SetContext(ctx). - SetHeaders(headerToMap(headers)). - SetBodyJsonMarshal(payload). - SetSuccessResult(&result). 
- Post(openAIChatGPTChatRequirementsURL) - if err != nil { - lastErr = err - continue - } - if resp.IsSuccessState() && strings.TrimSpace(result.Token) != "" { - return &result, nil - } - lastErr = newOpenAIImageStatusError(resp, "chat-requirements failed") - } - if lastErr == nil { - lastErr = fmt.Errorf("chat-requirements failed") - } - return nil, lastErr -} - -func prepareOpenAIImageConversation( - ctx context.Context, - client *req.Client, - headers http.Header, - prompt string, - parentMessageID string, - chatToken string, - proofToken string, -) (string, error) { - messageID := uuid.NewString() - payload := map[string]any{ - "action": "next", - "client_prepare_state": "success", - "fork_from_shared_post": false, - "parent_message_id": parentMessageID, - "model": "auto", - "timezone_offset_min": openAITimezoneOffsetMinutes(), - "timezone": openAITimezoneName(), - "conversation_mode": map[string]any{"kind": "primary_assistant"}, - "system_hints": []string{"picture_v2"}, - "supports_buffering": true, - "supported_encodings": []string{"v1"}, - "partial_query": map[string]any{ - "id": messageID, - "author": map[string]any{"role": "user"}, - "content": map[string]any{ - "content_type": "text", - "parts": []string{coalesceOpenAIFileName(prompt, "Generate an image.")}, - }, - }, - "client_contextual_info": map[string]any{ - "app_name": "chatgpt.com", - }, - } - prepareHeaders := cloneHTTPHeader(headers) - prepareHeaders.Set("Accept", "*/*") - prepareHeaders.Set("Content-Type", "application/json") - if strings.TrimSpace(chatToken) != "" { - prepareHeaders.Set("openai-sentinel-chat-requirements-token", strings.TrimSpace(chatToken)) - } - if strings.TrimSpace(proofToken) != "" { - prepareHeaders.Set("openai-sentinel-proof-token", strings.TrimSpace(proofToken)) - } - var result struct { - ConduitToken string `json:"conduit_token"` - } - resp, err := client.R(). - SetContext(ctx). - SetHeaders(headerToMap(prepareHeaders)). - SetBodyJsonMarshal(payload). 
- SetSuccessResult(&result). - Post(openAIChatGPTConversationPrepareURL) - if err != nil { - return "", err - } - if !resp.IsSuccessState() { - return "", newOpenAIImageStatusError(resp, "conversation prepare failed") - } - return strings.TrimSpace(result.ConduitToken), nil -} - -type openAIUploadedImage struct { - FileID string - FileName string - FileSize int - MimeType string - Width int - Height int -} - -func uploadOpenAIImageFiles(ctx context.Context, client *req.Client, headers http.Header, uploads []OpenAIImagesUpload) ([]openAIUploadedImage, error) { - if len(uploads) == 0 { - return nil, nil - } - results := make([]openAIUploadedImage, 0, len(uploads)) - for i := range uploads { - item := uploads[i] - fileName := coalesceOpenAIFileName(item.FileName, "image.png") - payload := map[string]any{ - "file_name": fileName, - "file_size": len(item.Data), - "use_case": "multimodal", - } - var created struct { - FileID string `json:"file_id"` - UploadURL string `json:"upload_url"` - } - resp, err := client.R(). - SetContext(ctx). - SetHeaders(headerToMap(headers)). - SetBodyJsonMarshal(payload). - SetSuccessResult(&created). - Post(openAIChatGPTFilesURL) - if err != nil { - return nil, err - } - if !resp.IsSuccessState() || strings.TrimSpace(created.FileID) == "" || strings.TrimSpace(created.UploadURL) == "" { - return nil, newOpenAIImageStatusError(resp, "create upload slot failed") - } - - uploadHeaders := map[string]string{ - "Content-Type": coalesceOpenAIFileName(item.ContentType, "application/octet-stream"), - "Origin": "https://chatgpt.com", - "x-ms-blob-type": "BlockBlob", - "x-ms-version": "2020-04-08", - "User-Agent": headers.Get("User-Agent"), - } - putResp, err := client.R(). - SetContext(ctx). - SetHeaders(uploadHeaders). - SetBody(item.Data). - DisableAutoReadResponse(). 
- Put(created.UploadURL) - if err != nil { - return nil, err - } - if putResp.Response != nil && putResp.Body != nil { - _, _ = io.Copy(io.Discard, putResp.Body) - _ = putResp.Body.Close() - } - if putResp.StatusCode < 200 || putResp.StatusCode >= 300 { - return nil, newOpenAIImageStatusError(putResp, "upload image bytes failed") - } - - uploadedResp, err := client.R(). - SetContext(ctx). - SetHeaders(headerToMap(headers)). - SetBodyJsonMarshal(map[string]any{}). - Post(fmt.Sprintf("%s/%s/uploaded", openAIChatGPTFilesURL, created.FileID)) - if err != nil { - return nil, err - } - if !uploadedResp.IsSuccessState() { - return nil, newOpenAIImageStatusError(uploadedResp, "mark upload complete failed") - } - - results = append(results, openAIUploadedImage{ - FileID: created.FileID, - FileName: fileName, - FileSize: len(item.Data), - MimeType: coalesceOpenAIFileName(item.ContentType, "application/octet-stream"), - Width: item.Width, - Height: item.Height, - }) - } - return results, nil -} - -func coalesceOpenAIFileName(value string, fallback string) string { - value = strings.TrimSpace(value) - if value == "" { - return fallback - } - return value -} - -func buildOpenAIImageConversationRequest(parsed *OpenAIImagesRequest, parentMessageID string, uploads []openAIUploadedImage) map[string]any { - parts := []any{coalesceOpenAIFileName(parsed.Prompt, "Generate an image.")} - attachments := make([]map[string]any, 0, len(uploads)) - if len(uploads) > 0 { - parts = make([]any, 0, len(uploads)+1) - for _, upload := range uploads { - parts = append(parts, map[string]any{ - "content_type": "image_asset_pointer", - "asset_pointer": "file-service://" + upload.FileID, - "size_bytes": upload.FileSize, - "width": upload.Width, - "height": upload.Height, - }) - attachment := map[string]any{ - "id": upload.FileID, - "mimeType": upload.MimeType, - "name": upload.FileName, - "size": upload.FileSize, - } - if upload.Width > 0 { - attachment["width"] = upload.Width - } - if upload.Height > 
0 { - attachment["height"] = upload.Height - } - attachments = append(attachments, attachment) - } - parts = append(parts, coalesceOpenAIFileName(parsed.Prompt, "Edit this image.")) - } - - contentType := "text" - if len(uploads) > 0 { - contentType = "multimodal_text" - } - metadata := map[string]any{ - "developer_mode_connector_ids": []any{}, - "selected_github_repos": []any{}, - "selected_all_github_repos": false, - "system_hints": []string{"picture_v2"}, - "serialization_metadata": map[string]any{ - "custom_symbol_offsets": []any{}, - }, - } - message := map[string]any{ - "id": uuid.NewString(), - "author": map[string]any{"role": "user"}, - "content": map[string]any{ - "content_type": contentType, - "parts": parts, - }, - "metadata": metadata, - "create_time": float64(time.Now().UnixMilli()) / 1000, - } - if len(attachments) > 0 { - metadata["attachments"] = attachments - } - - return map[string]any{ - "action": "next", - "client_prepare_state": "sent", - "parent_message_id": parentMessageID, - "model": "auto", - "timezone_offset_min": openAITimezoneOffsetMinutes(), - "timezone": openAITimezoneName(), - "conversation_mode": map[string]any{"kind": "primary_assistant"}, - "enable_message_followups": true, - "system_hints": []string{"picture_v2"}, - "supports_buffering": true, - "supported_encodings": []string{"v1"}, - "paragen_cot_summary_display_override": "allow", - "force_parallel_switch": "auto", - "client_contextual_info": map[string]any{ - "is_dark_mode": false, - "time_since_loaded": 200, - "page_height": 900, - "page_width": 1440, - "pixel_ratio": 1, - "screen_height": 1080, - "screen_width": 1920, - "app_name": "chatgpt.com", - }, - "messages": []any{message}, - } -} - type openAIImagePointerInfo struct { Pointer string DownloadURL string @@ -1347,51 +881,6 @@ type openAIImagePointerInfo struct { Prompt string } -type openAIImageToolMessage struct { - MessageID string - CreateTime float64 - PointerInfos []openAIImagePointerInfo -} - -func 
readOpenAIImageConversationStream(resp *req.Response, startTime time.Time) (string, []openAIImagePointerInfo, OpenAIUsage, *int, error) { - if resp == nil || resp.Body == nil { - return "", nil, OpenAIUsage{}, nil, fmt.Errorf("empty conversation response") - } - reader := bufio.NewReader(resp.Body) - var ( - conversationID string - firstTokenMs *int - usage OpenAIUsage - pointers []openAIImagePointerInfo - ) - - for { - line, err := reader.ReadString('\n') - if strings.TrimSpace(line) != "" && firstTokenMs == nil { - ms := int(time.Since(startTime).Milliseconds()) - firstTokenMs = &ms - } - if data, ok := extractOpenAISSEDataLine(strings.TrimRight(line, "\r\n")); ok && data != "" && data != "[DONE]" { - dataBytes := []byte(data) - if conversationID == "" { - conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "v.conversation_id").String()) - if conversationID == "" { - conversationID = strings.TrimSpace(gjson.GetBytes(dataBytes, "conversation_id").String()) - } - } - mergeOpenAIUsage(&usage, dataBytes) - pointers = mergeOpenAIImagePointerInfos(pointers, collectOpenAIImagePointers(dataBytes)) - } - if err == io.EOF { - break - } - if err != nil { - return "", nil, OpenAIUsage{}, firstTokenMs, err - } - } - return conversationID, pointers, usage, firstTokenMs, nil -} - func collectOpenAIImagePointers(body []byte) []openAIImagePointerInfo { if len(body) == 0 { return nil @@ -1517,222 +1006,6 @@ func mergeOpenAIImagePointerInfo(existing, next openAIImagePointerInfo) openAIIm return merged } -func hasOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) bool { - for _, item := range items { - if strings.HasPrefix(item.Pointer, "file-service://") { - return true - } - } - return false -} - -func countOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) int { - count := 0 - for _, item := range items { - if strings.HasPrefix(item.Pointer, "file-service://") { - count++ - } - } - return count -} - -func countOpenAIDirectImageAssets(items 
[]openAIImagePointerInfo) int { - count := 0 - for _, item := range items { - if strings.TrimSpace(item.DownloadURL) != "" || strings.TrimSpace(item.B64JSON) != "" { - count++ - } - } - return count -} - -func preferOpenAIFileServicePointerInfos(items []openAIImagePointerInfo) []openAIImagePointerInfo { - if !hasOpenAIFileServicePointerInfos(items) { - return items - } - out := make([]openAIImagePointerInfo, 0, len(items)) - for _, item := range items { - if strings.HasPrefix(item.Pointer, "file-service://") { - out = append(out, item) - } - } - return out -} - -func extractOpenAIImageToolMessages(mapping map[string]any) []openAIImageToolMessage { - if len(mapping) == 0 { - return nil - } - out := make([]openAIImageToolMessage, 0, 4) - for messageID, raw := range mapping { - node, _ := raw.(map[string]any) - if node == nil { - continue - } - message, _ := node["message"].(map[string]any) - if message == nil { - continue - } - author, _ := message["author"].(map[string]any) - metadata, _ := message["metadata"].(map[string]any) - content, _ := message["content"].(map[string]any) - if author == nil || metadata == nil || content == nil { - continue - } - if role, _ := author["role"].(string); role != "tool" { - continue - } - if asyncTaskType, _ := metadata["async_task_type"].(string); asyncTaskType != "image_gen" { - continue - } - if contentType, _ := content["content_type"].(string); contentType != "multimodal_text" { - continue - } - prompt := "" - if title, _ := metadata["image_gen_title"].(string); strings.TrimSpace(title) != "" { - prompt = strings.TrimSpace(title) - } - item := openAIImageToolMessage{MessageID: messageID} - if createTime, ok := message["create_time"].(float64); ok { - item.CreateTime = createTime - } - parts, _ := content["parts"].([]any) - for _, part := range parts { - switch value := part.(type) { - case map[string]any: - if assetPointer, _ := value["asset_pointer"].(string); strings.TrimSpace(assetPointer) != "" { - for _, pointer := range 
openAIImagePointerMatches([]byte(assetPointer)) { - item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{ - Pointer: pointer, - Prompt: prompt, - }) - } - } - case string: - for _, pointer := range openAIImagePointerMatches([]byte(value)) { - item.PointerInfos = append(item.PointerInfos, openAIImagePointerInfo{ - Pointer: pointer, - Prompt: prompt, - }) - } - } - } - if len(item.PointerInfos) == 0 { - continue - } - item.PointerInfos = mergeOpenAIImagePointerInfos(nil, item.PointerInfos) - out = append(out, item) - } - sort.Slice(out, func(i, j int) bool { - return out[i].CreateTime < out[j].CreateTime - }) - return out -} - -func pollOpenAIImageConversation(ctx context.Context, client *req.Client, headers http.Header, conversationID string) ([]openAIImagePointerInfo, error) { - conversationID = strings.TrimSpace(conversationID) - if conversationID == "" { - return nil, nil - } - deadline := time.Now().Add(90 * time.Second) - interval := 3 * time.Second - previewWait := 15 * time.Second - var ( - lastErr error - firstToolAt time.Time - ) - for time.Now().Before(deadline) { - resp, err := client.R(). - SetContext(ctx). - SetHeaders(headerToMap(headers)). - DisableAutoReadResponse(). 
- Get(fmt.Sprintf("https://chatgpt.com/backend-api/conversation/%s", conversationID)) - if err != nil { - lastErr = err - } else { - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - body, readErr := io.ReadAll(resp.Body) - _ = resp.Body.Close() - if readErr != nil { - lastErr = readErr - goto waitNextPoll - } - pointers := mergeOpenAIImagePointerInfos(nil, collectOpenAIImagePointers(body)) - var decoded map[string]any - if err := json.Unmarshal(body, &decoded); err == nil { - if mapping, _ := decoded["mapping"].(map[string]any); len(mapping) > 0 { - toolMessages := extractOpenAIImageToolMessages(mapping) - if len(toolMessages) > 0 && firstToolAt.IsZero() { - firstToolAt = time.Now() - } - for _, msg := range toolMessages { - pointers = mergeOpenAIImagePointerInfos(pointers, msg.PointerInfos) - } - } - } - if hasOpenAIFileServicePointerInfos(pointers) { - return preferOpenAIFileServicePointerInfos(pointers), nil - } - if len(pointers) > 0 && !firstToolAt.IsZero() && time.Since(firstToolAt) >= previewWait { - return pointers, nil - } - } else { - statusErr := newOpenAIImageStatusError(resp, "conversation poll failed") - if isOpenAIImageTransientConversationNotFoundError(statusErr) { - lastErr = statusErr - goto waitNextPoll - } - return nil, statusErr - } - } - - waitNextPoll: - timer := time.NewTimer(interval) - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - return nil, ctx.Err() - case <-timer.C: - } - } - return nil, lastErr -} - -func buildOpenAIImageResponse( - ctx context.Context, - client *req.Client, - headers http.Header, - conversationID string, - pointers []openAIImagePointerInfo, -) ([]byte, int, error) { - type responseItem struct { - B64JSON string `json:"b64_json"` - RevisedPrompt string `json:"revised_prompt,omitempty"` - } - items := make([]responseItem, 0, len(pointers)) - for _, pointer := range pointers { - data, err := resolveOpenAIImageBytes(ctx, client, headers, conversationID, pointer) - if err != nil { - return 
nil, 0, err - } - items = append(items, responseItem{ - B64JSON: base64.StdEncoding.EncodeToString(data), - RevisedPrompt: pointer.Prompt, - }) - } - payload := map[string]any{ - "created": time.Now().Unix(), - "data": items, - } - body, err := json.Marshal(payload) - if err != nil { - return nil, 0, err - } - return body, len(items), nil -} - func resolveOpenAIImageBytes( ctx context.Context, client *req.Client, @@ -1852,17 +1125,6 @@ func isLikelyOpenAIImageDownloadURL(raw string) bool { strings.Contains(lower, ".webp") } -func detachOpenAIImageLifecycleContext(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { - base := context.Background() - if ctx != nil { - base = context.WithoutCancel(ctx) - } - if timeout <= 0 { - return base, func() {} - } - return context.WithTimeout(base, timeout) -} - func fetchOpenAIImageDownloadURL( ctx context.Context, client *req.Client, @@ -1957,10 +1219,6 @@ func downloadOpenAIImageBytes(ctx context.Context, client *req.Client, headers h return io.ReadAll(io.LimitReader(resp.Body, openAIImageMaxDownloadBytes)) } -func handleOpenAIImageBackendError(resp *req.Response) error { - return newOpenAIImageStatusError(resp, "backend-api request failed") -} - type openAIImageStatusError struct { StatusCode int Message string @@ -2028,23 +1286,6 @@ func newOpenAIImageStatusError(resp *req.Response, fallback string) error { } } -func newOpenAIImageSyntheticStatusError(statusCode int, message string, requestURL string) *openAIImageStatusError { - message = sanitizeUpstreamErrorMessage(strings.TrimSpace(message)) - if message == "" { - message = "openai image backend request failed" - } - var body []byte - if payload, err := json.Marshal(map[string]string{"detail": message}); err == nil { - body = payload - } - return &openAIImageStatusError{ - StatusCode: statusCode, - Message: message, - ResponseBody: body, - URL: strings.TrimSpace(requestURL), - } -} - func isOpenAIImageTransientConversationNotFoundError(err 
error) bool { statusErr, ok := err.(*openAIImageStatusError) if !ok || statusErr == nil || statusErr.StatusCode != http.StatusNotFound { @@ -2064,58 +1305,6 @@ func isOpenAIImageTransientConversationNotFoundError(err error) bool { return strings.Contains(bodyMsg, "conversation") && strings.Contains(bodyMsg, "not found") } -func (s *OpenAIGatewayService) wrapOpenAIImageBackendError( - ctx context.Context, - c *gin.Context, - account *Account, - err error, -) error { - var statusErr *openAIImageStatusError - if !errors.As(err, &statusErr) || statusErr == nil { - return err - } - - upstreamMsg := sanitizeUpstreamErrorMessage(statusErr.Message) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: statusErr.StatusCode, - UpstreamRequestID: statusErr.RequestID, - UpstreamURL: safeUpstreamURL(statusErr.URL), - Kind: "request_error", - Message: upstreamMsg, - }) - setOpsUpstreamError(c, statusErr.StatusCode, upstreamMsg, "") - - if s.shouldFailoverOpenAIUpstreamResponse(statusErr.StatusCode, upstreamMsg, statusErr.ResponseBody) { - if s.rateLimitService != nil { - s.rateLimitService.HandleUpstreamError(ctx, account, statusErr.StatusCode, statusErr.ResponseHeaders, statusErr.ResponseBody) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: statusErr.StatusCode, - UpstreamRequestID: statusErr.RequestID, - UpstreamURL: safeUpstreamURL(statusErr.URL), - Kind: "failover", - Message: upstreamMsg, - }) - retryableOnSameAccount := account.IsPoolMode() && isPoolModeRetryableStatus(statusErr.StatusCode) - if strings.Contains(strings.ToLower(statusErr.Message), "unsupported challenge") { - retryableOnSameAccount = false - } - return &UpstreamFailoverError{ - StatusCode: statusErr.StatusCode, - ResponseBody: statusErr.ResponseBody, - RetryableOnSameAccount: 
retryableOnSameAccount, - } - } - - return statusErr -} - func cloneHTTPHeader(src http.Header) http.Header { dst := make(http.Header, len(src)) for key, values := range src { @@ -2140,110 +1329,6 @@ func headerToMap(header http.Header) map[string]string { return result } -func openAITimezoneOffsetMinutes() int { - _, offset := time.Now().Zone() - return offset / 60 -} - -func openAITimezoneName() string { - return time.Now().Location().String() -} - -func generateOpenAIRequirementsToken(userAgent string) string { - config := []any{ - "core" + strconv.Itoa(3008), - time.Now().UTC().Format(time.RFC1123), - nil, - 0.123456, - coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent), - nil, - "prod-openai-images", - "en-US", - "en-US,en", - 0, - "navigator.webdriver", - "location", - "document.body", - float64(time.Now().UnixMilli()) / 1000, - uuid.NewString(), - "", - 8, - time.Now().Unix(), - } - answer, solved := generateOpenAIChallengeAnswer(strconv.FormatInt(time.Now().UnixNano(), 10), openAIImageRequirementsDiff, config) - if solved { - return "gAAAAAC" + answer - } - return "" -} - -func generateOpenAIChallengeAnswer(seed string, difficulty string, config []any) (string, bool) { - diffBytes, err := hex.DecodeString(difficulty) - if err != nil { - return "", false - } - p1 := []byte(jsonCompactSlice(config[:3], true)) - p2 := []byte(jsonCompactSlice(config[4:9], false)) - p3 := []byte(jsonCompactSlice(config[10:], false)) - seedBytes := []byte(seed) - - for i := 0; i < 100000; i++ { - payload := fmt.Sprintf("%s%d,%s,%d,%s", p1, i, p2, i>>1, p3) - encoded := base64.StdEncoding.EncodeToString([]byte(payload)) - sum := sha3.Sum512(append(seedBytes, []byte(encoded)...)) - if bytes.Compare(sum[:len(diffBytes)], diffBytes) <= 0 { - return encoded, true - } - } - return "", false -} - -func jsonCompactSlice(values []any, trimSuffixComma bool) string { - raw, _ := json.Marshal(values) - text := string(raw) - if trimSuffixComma { - return 
strings.TrimSuffix(text, "]") - } - return strings.TrimPrefix(text, "[") -} - -func generateOpenAIProofToken(required bool, seed string, difficulty string, userAgent string) string { - if !required || strings.TrimSpace(seed) == "" || strings.TrimSpace(difficulty) == "" { - return "" - } - screen := 3008 - if len(seed)%2 == 0 { - screen = 4010 - } - proofToken := []any{ - screen, - time.Now().UTC().Format(time.RFC1123), - nil, - 0, - coalesceOpenAIFileName(strings.TrimSpace(userAgent), openAIImageBackendUserAgent), - "https://chatgpt.com/", - "dpl=openai-images", - "en", - "en-US", - nil, - "plugins[object PluginArray]", - "_reactListening", - "alert", - } - diffLen := len(difficulty) - for i := 0; i < 100000; i++ { - proofToken[3] = i - raw, _ := json.Marshal(proofToken) - encoded := base64.StdEncoding.EncodeToString(raw) - sum := sha3.Sum512([]byte(seed + encoded)) - if strings.Compare(hex.EncodeToString(sum[:])[:diffLen], difficulty) <= 0 { - return "gAAAAAB" + encoded - } - } - fallbackBase := base64.StdEncoding.EncodeToString([]byte(fmt.Sprintf("%q", seed))) - return "gAAAAABwQ8Lk5FbGpA2NcR9dShT6gYjU7VxZ4D" + fallbackBase -} - func dedupeStrings(values []string) []string { if len(values) == 0 { return nil diff --git a/backend/internal/service/openai_images_responses.go b/backend/internal/service/openai_images_responses.go new file mode 100644 index 00000000..64d995e1 --- /dev/null +++ b/backend/internal/service/openai_images_responses.go @@ -0,0 +1,853 @@ +package service + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/logger" + "github.com/Wei-Shaw/sub2api/internal/util/responseheaders" + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +type openAIResponsesImageResult struct { + Result string + RevisedPrompt string + OutputFormat string + Size string + Background string + Quality string + Model string +} + 
+func openAIResponsesImageResultKey(itemID string, result openAIResponsesImageResult) string { + if strings.TrimSpace(result.Result) != "" { + return strings.TrimSpace(result.OutputFormat) + "|" + strings.TrimSpace(result.Result) + } + return "item:" + strings.TrimSpace(itemID) +} + +func appendOpenAIResponsesImageResultDedup(results *[]openAIResponsesImageResult, seen map[string]struct{}, itemID string, result openAIResponsesImageResult) bool { + if results == nil { + return false + } + key := openAIResponsesImageResultKey(itemID, result) + if key != "" { + if _, exists := seen[key]; exists { + return false + } + seen[key] = struct{}{} + } + *results = append(*results, result) + return true +} + +func mergeOpenAIResponsesImageMeta(dst *openAIResponsesImageResult, src openAIResponsesImageResult) { + if dst == nil { + return + } + if trimmed := strings.TrimSpace(src.OutputFormat); trimmed != "" { + dst.OutputFormat = trimmed + } + if trimmed := strings.TrimSpace(src.Size); trimmed != "" { + dst.Size = trimmed + } + if trimmed := strings.TrimSpace(src.Background); trimmed != "" { + dst.Background = trimmed + } + if trimmed := strings.TrimSpace(src.Quality); trimmed != "" { + dst.Quality = trimmed + } + if trimmed := strings.TrimSpace(src.Model); trimmed != "" { + dst.Model = trimmed + } +} + +func extractOpenAIResponsesImageMetaFromLifecycleEvent(payload []byte) (openAIResponsesImageResult, int64, bool) { + switch gjson.GetBytes(payload, "type").String() { + case "response.created", "response.in_progress", "response.completed": + default: + return openAIResponsesImageResult{}, 0, false + } + + response := gjson.GetBytes(payload, "response") + if !response.Exists() { + return openAIResponsesImageResult{}, 0, false + } + + meta := openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(response.Get("tools.0.output_format").String()), + Size: strings.TrimSpace(response.Get("tools.0.size").String()), + Background: 
strings.TrimSpace(response.Get("tools.0.background").String()), + Quality: strings.TrimSpace(response.Get("tools.0.quality").String()), + Model: strings.TrimSpace(response.Get("tools.0.model").String()), + } + return meta, response.Get("created_at").Int(), true +} + +func buildOpenAIImagesStreamPartialPayload( + eventType string, + b64 string, + partialImageIndex int64, + responseFormat string, + createdAt int64, + meta openAIResponsesImageResult, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"partial_image_index":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = sjson.SetBytes(payload, "partial_image_index", partialImageIndex) + payload, _ = sjson.SetBytes(payload, "b64_json", b64) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(meta.OutputFormat)+";base64,"+b64) + } + if meta.Background != "" { + payload, _ = sjson.SetBytes(payload, "background", meta.Background) + } + if meta.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", meta.OutputFormat) + } + if meta.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", meta.Quality) + } + if meta.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", meta.Size) + } + if meta.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", meta.Model) + } + return payload +} + +func buildOpenAIImagesStreamCompletedPayload( + eventType string, + img openAIResponsesImageResult, + responseFormat string, + createdAt int64, + usageRaw []byte, +) []byte { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + payload := []byte(`{"type":"","created_at":0,"b64_json":""}`) + payload, _ = sjson.SetBytes(payload, "type", eventType) + payload, _ = sjson.SetBytes(payload, "created_at", createdAt) + payload, _ = 
sjson.SetBytes(payload, "b64_json", img.Result) + if strings.EqualFold(strings.TrimSpace(responseFormat), "url") { + payload, _ = sjson.SetBytes(payload, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) + } + if img.Background != "" { + payload, _ = sjson.SetBytes(payload, "background", img.Background) + } + if img.OutputFormat != "" { + payload, _ = sjson.SetBytes(payload, "output_format", img.OutputFormat) + } + if img.Quality != "" { + payload, _ = sjson.SetBytes(payload, "quality", img.Quality) + } + if img.Size != "" { + payload, _ = sjson.SetBytes(payload, "size", img.Size) + } + if img.Model != "" { + payload, _ = sjson.SetBytes(payload, "model", img.Model) + } + if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { + payload, _ = sjson.SetRawBytes(payload, "usage", usageRaw) + } + return payload +} + +func openAIImageOutputMIMEType(outputFormat string) string { + if outputFormat == "" { + return "image/png" + } + if strings.Contains(outputFormat, "/") { + return outputFormat + } + switch strings.ToLower(strings.TrimSpace(outputFormat)) { + case "png": + return "image/png" + case "jpg", "jpeg": + return "image/jpeg" + case "webp": + return "image/webp" + default: + return "image/png" + } +} + +func openAIImageUploadToDataURL(upload OpenAIImagesUpload) (string, error) { + if len(upload.Data) == 0 { + return "", fmt.Errorf("upload %q is empty", strings.TrimSpace(upload.FileName)) + } + contentType := strings.TrimSpace(upload.ContentType) + if contentType == "" { + contentType = http.DetectContentType(upload.Data) + } + return "data:" + contentType + ";base64," + base64.StdEncoding.EncodeToString(upload.Data), nil +} + +func buildOpenAIImagesResponsesRequest(parsed *OpenAIImagesRequest, toolModel string) ([]byte, error) { + if parsed == nil { + return nil, fmt.Errorf("parsed images request is required") + } + prompt := strings.TrimSpace(parsed.Prompt) + if prompt == "" { + return nil, fmt.Errorf("prompt is required") + } + + 
inputImages := make([]string, 0, len(parsed.InputImageURLs)+len(parsed.Uploads)) + for _, imageURL := range parsed.InputImageURLs { + if trimmed := strings.TrimSpace(imageURL); trimmed != "" { + inputImages = append(inputImages, trimmed) + } + } + for _, upload := range parsed.Uploads { + dataURL, err := openAIImageUploadToDataURL(upload) + if err != nil { + return nil, err + } + inputImages = append(inputImages, dataURL) + } + if parsed.IsEdits() && len(inputImages) == 0 { + return nil, fmt.Errorf("image input is required") + } + + req := []byte(`{"instructions":"","stream":true,"reasoning":{"effort":"medium","summary":"auto"},"parallel_tool_calls":true,"include":["reasoning.encrypted_content"],"model":"","store":false,"tool_choice":{"type":"image_generation"}}`) + req, _ = sjson.SetBytes(req, "model", openAIImagesResponsesMainModel) + + input := []byte(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`) + input, _ = sjson.SetBytes(input, "0.content.0.text", prompt) + for index, imageURL := range inputImages { + part := []byte(`{"type":"input_image","image_url":""}`) + part, _ = sjson.SetBytes(part, "image_url", imageURL) + input, _ = sjson.SetRawBytes(input, fmt.Sprintf("0.content.%d", index+1), part) + } + req, _ = sjson.SetRawBytes(req, "input", input) + + action := "generate" + if parsed.IsEdits() { + action = "edit" + } + tool := []byte(`{"type":"image_generation","action":"","model":""}`) + tool, _ = sjson.SetBytes(tool, "action", action) + tool, _ = sjson.SetBytes(tool, "model", strings.TrimSpace(toolModel)) + + for _, field := range []struct { + path string + value string + }{ + {path: "size", value: parsed.Size}, + {path: "quality", value: parsed.Quality}, + {path: "background", value: parsed.Background}, + {path: "output_format", value: parsed.OutputFormat}, + {path: "moderation", value: parsed.Moderation}, + {path: "style", value: parsed.Style}, + } { + if trimmed := strings.TrimSpace(field.value); trimmed != "" { + tool, _ 
= sjson.SetBytes(tool, field.path, trimmed) + } + } + if parsed.OutputCompression != nil { + tool, _ = sjson.SetBytes(tool, "output_compression", *parsed.OutputCompression) + } + if parsed.PartialImages != nil { + tool, _ = sjson.SetBytes(tool, "partial_images", *parsed.PartialImages) + } + + maskImageURL := strings.TrimSpace(parsed.MaskImageURL) + if parsed.MaskUpload != nil { + dataURL, err := openAIImageUploadToDataURL(*parsed.MaskUpload) + if err != nil { + return nil, err + } + maskImageURL = dataURL + } + if maskImageURL != "" { + tool, _ = sjson.SetBytes(tool, "input_image_mask.image_url", maskImageURL) + } + + req, _ = sjson.SetRawBytes(req, "tools", []byte(`[]`)) + req, _ = sjson.SetRawBytes(req, "tools.-1", tool) + return req, nil +} + +func extractOpenAIImagesFromResponsesCompleted(payload []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, error) { + if gjson.GetBytes(payload, "type").String() != "response.completed" { + return nil, 0, nil, openAIResponsesImageResult{}, fmt.Errorf("unexpected event type") + } + + createdAt := gjson.GetBytes(payload, "response.created_at").Int() + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + + var ( + results []openAIResponsesImageResult + firstMeta openAIResponsesImageResult + ) + output := gjson.GetBytes(payload, "response.output") + if output.IsArray() { + for _, item := range output.Array() { + if item.Get("type").String() != "image_generation_call" { + continue + } + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + continue + } + entry := openAIResponsesImageResult{ + Result: result, + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + if len(results) == 0 { + firstMeta = 
entry + } + results = append(results, entry) + } + } + + var usageRaw []byte + if usage := gjson.GetBytes(payload, "response.tool_usage.image_gen"); usage.Exists() && usage.IsObject() { + usageRaw = []byte(usage.Raw) + } + return results, createdAt, usageRaw, firstMeta, nil +} + +func extractOpenAIImageFromResponsesOutputItemDone(payload []byte) (openAIResponsesImageResult, string, bool, error) { + if gjson.GetBytes(payload, "type").String() != "response.output_item.done" { + return openAIResponsesImageResult{}, "", false, fmt.Errorf("unexpected event type") + } + + item := gjson.GetBytes(payload, "item") + if !item.Exists() || item.Get("type").String() != "image_generation_call" { + return openAIResponsesImageResult{}, "", false, nil + } + + result := strings.TrimSpace(item.Get("result").String()) + if result == "" { + return openAIResponsesImageResult{}, "", false, nil + } + + entry := openAIResponsesImageResult{ + Result: result, + RevisedPrompt: strings.TrimSpace(item.Get("revised_prompt").String()), + OutputFormat: strings.TrimSpace(item.Get("output_format").String()), + Size: strings.TrimSpace(item.Get("size").String()), + Background: strings.TrimSpace(item.Get("background").String()), + Quality: strings.TrimSpace(item.Get("quality").String()), + } + return entry, strings.TrimSpace(item.Get("id").String()), true, nil +} + +func collectOpenAIImagesFromResponsesBody(body []byte) ([]openAIResponsesImageResult, int64, []byte, openAIResponsesImageResult, bool, error) { + var ( + fallbackResults []openAIResponsesImageResult + fallbackSeen = make(map[string]struct{}) + createdAt int64 + usageRaw []byte + foundFinal bool + responseMeta openAIResponsesImageResult + ) + + for _, line := range bytes.Split(body, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + data, ok := extractOpenAISSEDataLine(string(line)) + if !ok || data == "" || data == "[DONE]" { + continue + } + payload := []byte(data) + if !gjson.ValidBytes(payload) { + continue + } + if meta, 
eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(payload); ok { + mergeOpenAIResponsesImageMeta(&responseMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + + switch gjson.GetBytes(payload, "type").String() { + case "response.output_item.done": + result, itemID, ok, err := extractOpenAIImageFromResponsesOutputItemDone(payload) + if err != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, err + } + if ok { + mergeOpenAIResponsesImageMeta(&result, responseMeta) + appendOpenAIResponsesImageResultDedup(&fallbackResults, fallbackSeen, itemID, result) + } + case "response.completed": + results, completedAt, completedUsageRaw, firstMeta, err := extractOpenAIImagesFromResponsesCompleted(payload) + if err != nil { + return nil, 0, nil, openAIResponsesImageResult{}, false, err + } + foundFinal = true + if completedAt > 0 { + createdAt = completedAt + } + if len(completedUsageRaw) > 0 { + usageRaw = completedUsageRaw + } + if len(results) > 0 { + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return results, createdAt, usageRaw, firstMeta, true, nil + } + if len(fallbackResults) > 0 { + firstMeta = fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, true, nil + } + } + } + + if len(fallbackResults) > 0 { + firstMeta := fallbackResults[0] + mergeOpenAIResponsesImageMeta(&firstMeta, responseMeta) + return fallbackResults, createdAt, usageRaw, firstMeta, foundFinal, nil + } + return nil, createdAt, usageRaw, openAIResponsesImageResult{}, foundFinal, nil +} + +func buildOpenAIImagesAPIResponse( + results []openAIResponsesImageResult, + createdAt int64, + usageRaw []byte, + firstMeta openAIResponsesImageResult, + responseFormat string, +) ([]byte, error) { + if createdAt <= 0 { + createdAt = time.Now().Unix() + } + out := []byte(`{"created":0,"data":[]}`) + out, _ = sjson.SetBytes(out, "created", createdAt) + + format := 
strings.ToLower(strings.TrimSpace(responseFormat)) + if format == "" { + format = "b64_json" + } + for _, img := range results { + item := []byte(`{}`) + if format == "url" { + item, _ = sjson.SetBytes(item, "url", "data:"+openAIImageOutputMIMEType(img.OutputFormat)+";base64,"+img.Result) + } else { + item, _ = sjson.SetBytes(item, "b64_json", img.Result) + } + if img.RevisedPrompt != "" { + item, _ = sjson.SetBytes(item, "revised_prompt", img.RevisedPrompt) + } + out, _ = sjson.SetRawBytes(out, "data.-1", item) + } + if firstMeta.Background != "" { + out, _ = sjson.SetBytes(out, "background", firstMeta.Background) + } + if firstMeta.OutputFormat != "" { + out, _ = sjson.SetBytes(out, "output_format", firstMeta.OutputFormat) + } + if firstMeta.Quality != "" { + out, _ = sjson.SetBytes(out, "quality", firstMeta.Quality) + } + if firstMeta.Size != "" { + out, _ = sjson.SetBytes(out, "size", firstMeta.Size) + } + if firstMeta.Model != "" { + out, _ = sjson.SetBytes(out, "model", firstMeta.Model) + } + if len(usageRaw) > 0 && gjson.ValidBytes(usageRaw) { + out, _ = sjson.SetRawBytes(out, "usage", usageRaw) + } + return out, nil +} + +func openAIImagesStreamPrefix(parsed *OpenAIImagesRequest) string { + if parsed != nil && parsed.IsEdits() { + return "image_edit" + } + return "image_generation" +} + +func buildOpenAIImagesStreamErrorBody(message string) []byte { + body := []byte(`{"type":"error","error":{"type":"upstream_error","message":""}}`) + if strings.TrimSpace(message) == "" { + message = "upstream request failed" + } + body, _ = sjson.SetBytes(body, "error.message", message) + return body +} + +func (s *OpenAIGatewayService) writeOpenAIImagesStreamEvent(c *gin.Context, flusher http.Flusher, eventName string, payload []byte) error { + if strings.TrimSpace(eventName) != "" { + if _, err := fmt.Fprintf(c.Writer, "event: %s\n", eventName); err != nil { + return err + } + } + if _, err := fmt.Fprintf(c.Writer, "data: %s\n\n", payload); err != nil { + return err + } + 
flusher.Flush() + return nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesOAuthNonStreamingResponse( + resp *http.Response, + c *gin.Context, + responseFormat string, + fallbackModel string, +) (OpenAIUsage, int, error) { + body, err := ReadUpstreamResponseBody(resp.Body, s.cfg, c, openAITooLargeError) + if err != nil { + return OpenAIUsage{}, 0, err + } + + var usage OpenAIUsage + for _, line := range bytes.Split(body, []byte("\n")) { + line = bytes.TrimRight(line, "\r") + data, ok := extractOpenAISSEDataLine(string(line)) + if !ok || data == "" || data == "[DONE]" { + continue + } + dataBytes := []byte(data) + s.parseSSEUsageBytes(dataBytes, &usage) + } + results, createdAt, usageRaw, firstMeta, _, err := collectOpenAIImagesFromResponsesBody(body) + if err != nil { + return OpenAIUsage{}, 0, err + } + if len(results) == 0 { + return OpenAIUsage{}, 0, fmt.Errorf("upstream did not return image output") + } + if strings.TrimSpace(firstMeta.Model) == "" { + firstMeta.Model = strings.TrimSpace(fallbackModel) + } + + responseBody, err := buildOpenAIImagesAPIResponse(results, createdAt, usageRaw, firstMeta, responseFormat) + if err != nil { + return OpenAIUsage{}, 0, err + } + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Data(resp.StatusCode, "application/json; charset=utf-8", responseBody) + return usage, len(results), nil +} + +func (s *OpenAIGatewayService) handleOpenAIImagesOAuthStreamingResponse( + resp *http.Response, + c *gin.Context, + startTime time.Time, + responseFormat string, + streamPrefix string, + fallbackModel string, +) (OpenAIUsage, int, *int, error) { + responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.responseHeaderFilter) + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Status(resp.StatusCode) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + return OpenAIUsage{}, 0, nil, 
fmt.Errorf("streaming is not supported by response writer") + } + + format := strings.ToLower(strings.TrimSpace(responseFormat)) + if format == "" { + format = "b64_json" + } + + reader := bufio.NewReader(resp.Body) + usage := OpenAIUsage{} + imageCount := 0 + var firstTokenMs *int + emitted := make(map[string]struct{}) + pendingResults := make([]openAIResponsesImageResult, 0, 1) + pendingSeen := make(map[string]struct{}) + streamMeta := openAIResponsesImageResult{Model: strings.TrimSpace(fallbackModel)} + var createdAt int64 + + for { + line, err := reader.ReadBytes('\n') + if len(line) > 0 { + trimmedLine := strings.TrimRight(string(line), "\r\n") + data, ok := extractOpenAISSEDataLine(trimmedLine) + if ok && data != "" && data != "[DONE]" { + if firstTokenMs == nil { + ms := int(time.Since(startTime).Milliseconds()) + firstTokenMs = &ms + } + dataBytes := []byte(data) + s.parseSSEUsageBytes(dataBytes, &usage) + if gjson.ValidBytes(dataBytes) { + if meta, eventCreatedAt, ok := extractOpenAIResponsesImageMetaFromLifecycleEvent(dataBytes); ok { + mergeOpenAIResponsesImageMeta(&streamMeta, meta) + if eventCreatedAt > 0 { + createdAt = eventCreatedAt + } + } + switch gjson.GetBytes(dataBytes, "type").String() { + case "response.image_generation_call.partial_image": + b64 := strings.TrimSpace(gjson.GetBytes(dataBytes, "partial_image_b64").String()) + if b64 != "" { + eventName := streamPrefix + ".partial_image" + partialMeta := streamMeta + mergeOpenAIResponsesImageMeta(&partialMeta, openAIResponsesImageResult{ + OutputFormat: strings.TrimSpace(gjson.GetBytes(dataBytes, "output_format").String()), + Background: strings.TrimSpace(gjson.GetBytes(dataBytes, "background").String()), + }) + payload := buildOpenAIImagesStreamPartialPayload( + eventName, + b64, + gjson.GetBytes(dataBytes, "partial_image_index").Int(), + format, + createdAt, + partialMeta, + ) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return 
OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + } + case "response.output_item.done": + img, itemID, ok, extractErr := extractOpenAIImageFromResponsesOutputItemDone(dataBytes) + if extractErr != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, extractErr + } + if !ok { + break + } + mergeOpenAIResponsesImageMeta(&streamMeta, img) + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey(itemID, img) + if _, exists := emitted[key]; exists { + break + } + if _, exists := pendingSeen[key]; exists { + break + } + pendingSeen[key] = struct{}{} + pendingResults = append(pendingResults, img) + case "response.completed": + results, _, usageRaw, firstMeta, extractErr := extractOpenAIImagesFromResponsesCompleted(dataBytes) + if extractErr != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(extractErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, extractErr + } + mergeOpenAIResponsesImageMeta(&streamMeta, firstMeta) + finalResults := make([]openAIResponsesImageResult, 0, len(results)+len(pendingResults)) + finalSeen := make(map[string]struct{}) + for _, img := range results { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + appendOpenAIResponsesImageResultDedup(&finalResults, finalSeen, "", img) + } + if len(finalResults) == 0 { + err = fmt.Errorf("upstream did not return image output") + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, err + } + eventName := streamPrefix + ".completed" + for _, img := range finalResults { + key := openAIResponsesImageResultKey("", img) + if _, exists := 
emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, usageRaw) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + emitted[key] = struct{}{} + } + imageCount = len(emitted) + return usage, imageCount, firstTokenMs, nil + } + } + } + } + if err == io.EOF { + break + } + if err != nil { + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(err.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, err + } + } + + if imageCount > 0 { + return usage, imageCount, firstTokenMs, nil + } + if len(pendingResults) > 0 { + eventName := streamPrefix + ".completed" + for _, img := range pendingResults { + mergeOpenAIResponsesImageMeta(&img, streamMeta) + key := openAIResponsesImageResultKey("", img) + if _, exists := emitted[key]; exists { + continue + } + payload := buildOpenAIImagesStreamCompletedPayload(eventName, img, format, createdAt, nil) + if writeErr := s.writeOpenAIImagesStreamEvent(c, flusher, eventName, payload); writeErr != nil { + return OpenAIUsage{}, imageCount, firstTokenMs, writeErr + } + emitted[key] = struct{}{} + } + imageCount = len(emitted) + return usage, imageCount, firstTokenMs, nil + } + + streamErr := fmt.Errorf("stream disconnected before image generation completed") + _ = s.writeOpenAIImagesStreamEvent(c, flusher, "error", buildOpenAIImagesStreamErrorBody(streamErr.Error())) + return OpenAIUsage{}, imageCount, firstTokenMs, streamErr +} + +func (s *OpenAIGatewayService) forwardOpenAIImagesOAuth( + ctx context.Context, + c *gin.Context, + account *Account, + parsed *OpenAIImagesRequest, + channelMappedModel string, +) (*OpenAIForwardResult, error) { + startTime := time.Now() + requestModel := strings.TrimSpace(parsed.Model) + if mapped := strings.TrimSpace(channelMappedModel); mapped != "" { + requestModel = mapped + } + if 
requestModel == "" { + requestModel = "gpt-image-2" + } + if err := validateOpenAIImagesModel(requestModel); err != nil { + return nil, err + } + logger.LegacyPrintf( + "service.openai_gateway", + "[OpenAI] Images request routing request_model=%s endpoint=%s account_type=%s uploads=%d", + requestModel, + parsed.Endpoint, + account.Type, + len(parsed.Uploads), + ) + if parsed.N > 1 { + logger.LegacyPrintf( + "service.openai_gateway", + "[Warning] Codex /responses image tool requested n=%d; falling back to n=1 request_model=%s endpoint=%s", + parsed.N, + requestModel, + parsed.Endpoint, + ) + } + + token, _, err := s.GetAccessToken(ctx, account) + if err != nil { + return nil, err + } + + responsesBody, err := buildOpenAIImagesResponsesRequest(parsed, requestModel) + if err != nil { + return nil, err + } + setOpsUpstreamRequestBody(c, responsesBody) + + upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, responsesBody, token, true, parsed.StickySessionSeed(), false) + if err != nil { + return nil, err + } + upstreamReq.Header.Set("Content-Type", "application/json") + upstreamReq.Header.Set("Accept", "text/event-stream") + + proxyURL := "" + if account.ProxyID != nil && account.Proxy != nil { + proxyURL = account.Proxy.URL() + } + upstreamStart := time.Now() + resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) + if err != nil { + safeErr := sanitizeUpstreamErrorMessage(err.Error()) + setOpsUpstreamError(c, 0, safeErr, "") + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: 0, + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "request_error", + Message: safeErr, + }) + return nil, fmt.Errorf("upstream request failed: %s", safeErr) + } + if resp.StatusCode >= 400 { + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 
2<<20)) + _ = resp.Body.Close() + resp.Body = io.NopCloser(bytes.NewReader(respBody)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(respBody)) + upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) + if s.shouldFailoverOpenAIUpstreamResponse(resp.StatusCode, upstreamMsg, respBody) { + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Platform: account.Platform, + AccountID: account.ID, + AccountName: account.Name, + UpstreamStatusCode: resp.StatusCode, + UpstreamRequestID: resp.Header.Get("x-request-id"), + UpstreamURL: safeUpstreamURL(upstreamReq.URL.String()), + Kind: "failover", + Message: upstreamMsg, + }) + s.handleFailoverSideEffects(ctx, resp, account) + return nil, &UpstreamFailoverError{ + StatusCode: resp.StatusCode, + ResponseBody: respBody, + RetryableOnSameAccount: account.IsPoolMode() && isPoolModeRetryableStatus(resp.StatusCode), + } + } + return s.handleErrorResponse(ctx, resp, c, account, responsesBody) + } + defer func() { _ = resp.Body.Close() }() + + var ( + usage OpenAIUsage + imageCount int + firstTokenMs *int + ) + if parsed.Stream { + usage, imageCount, firstTokenMs, err = s.handleOpenAIImagesOAuthStreamingResponse(resp, c, startTime, parsed.ResponseFormat, openAIImagesStreamPrefix(parsed), requestModel) + if err != nil { + return nil, err + } + } else { + usage, imageCount, err = s.handleOpenAIImagesOAuthNonStreamingResponse(resp, c, parsed.ResponseFormat, requestModel) + if err != nil { + return nil, err + } + } + if imageCount <= 0 { + imageCount = parsed.N + } + return &OpenAIForwardResult{ + RequestID: resp.Header.Get("x-request-id"), + Usage: usage, + Model: requestModel, + UpstreamModel: requestModel, + Stream: parsed.Stream, + ResponseHeaders: resp.Header.Clone(), + Duration: time.Since(startTime), + FirstTokenMs: firstTokenMs, + ImageCount: imageCount, + ImageSize: parsed.SizeTier, + }, nil +} diff --git a/backend/internal/service/openai_images_test.go b/backend/internal/service/openai_images_test.go index 
6aa1d5e5..200547d4 100644 --- a/backend/internal/service/openai_images_test.go +++ b/backend/internal/service/openai_images_test.go @@ -3,13 +3,17 @@ package service import ( "bytes" "context" + "io" "mime/multipart" "net/http" "net/http/httptest" + "net/textproto" + "strings" "testing" "github.com/gin-gonic/gin" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" ) func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSON(t *testing.T) { @@ -70,6 +74,58 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEdit(t *testing.T require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) } +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_MultipartEditWithMaskAndNativeOptions(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace foreground")) + require.NoError(t, writer.WriteField("output_format", "png")) + require.NoError(t, writer.WriteField("input_fidelity", "high")) + require.NoError(t, writer.WriteField("output_compression", "80")) + require.NoError(t, writer.WriteField("partial_images", "2")) + + imageHeader := make(textproto.MIMEHeader) + imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`) + imageHeader.Set("Content-Type", "image/png") + imagePart, err := writer.CreatePart(imageHeader) + require.NoError(t, err) + _, err = imagePart.Write([]byte("source-image-bytes")) + require.NoError(t, err) + + maskHeader := make(textproto.MIMEHeader) + maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`) + maskHeader.Set("Content-Type", "image/png") + maskPart, err := writer.CreatePart(maskHeader) + require.NoError(t, err) + _, err = maskPart.Write([]byte("mask-image-bytes")) + require.NoError(t, err) + + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, 
"/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Len(t, parsed.Uploads, 1) + require.NotNil(t, parsed.MaskUpload) + require.True(t, parsed.HasMask) + require.Equal(t, "png", parsed.OutputFormat) + require.Equal(t, "high", parsed.InputFidelity) + require.NotNil(t, parsed.OutputCompression) + require.Equal(t, 80, *parsed.OutputCompression) + require.NotNil(t, parsed.PartialImages) + require.Equal(t, 2, *parsed.PartialImages) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + func TestOpenAIGatewayServiceParseOpenAIImagesRequest_PromptOnlyDefaultsRemainBasic(t *testing.T) { gin.SetMode(gin.TestMode) body := []byte(`{"prompt":"draw a cat"}`) @@ -121,6 +177,40 @@ func TestOpenAIGatewayServiceParseOpenAIImagesRequest_RejectsNonImageModel(t *te require.ErrorContains(t, err, `images endpoint requires an image model, got "gpt-5.4"`) } +func TestOpenAIGatewayServiceParseOpenAIImagesRequest_JSONEditURLs(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{ + "model":"gpt-image-2", + "prompt":"replace the background", + "images":[{"image_url":"https://example.com/source.png"}], + "mask":{"image_url":"https://example.com/mask.png"}, + "input_fidelity":"high", + "output_compression":90, + "partial_images":2, + "response_format":"url" + }`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + require.NotNil(t, parsed) + require.Equal(t, 
[]string{"https://example.com/source.png"}, parsed.InputImageURLs) + require.Equal(t, "https://example.com/mask.png", parsed.MaskImageURL) + require.Equal(t, "high", parsed.InputFidelity) + require.NotNil(t, parsed.OutputCompression) + require.Equal(t, 90, *parsed.OutputCompression) + require.NotNil(t, parsed.PartialImages) + require.Equal(t, 2, *parsed.PartialImages) + require.True(t, parsed.HasMask) + require.Equal(t, OpenAIImagesCapabilityNative, parsed.RequiredCapability) +} + func TestCollectOpenAIImagePointers_RecognizesDirectAssets(t *testing.T) { items := collectOpenAIImagePointers([]byte(`{ "revised_prompt": "cat astronaut", @@ -157,3 +247,472 @@ func TestResolveOpenAIImageBytes_PrefersInlineBase64(t *testing.T) { require.NoError(t, err) require.Equal(t, []byte("ABC"), data) } + +func TestAccountSupportsOpenAIImageCapability_OAuthSupportsNative(t *testing.T) { + account := &Account{ + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityBasic)) + require.True(t, account.SupportsOpenAIImageCapability(OpenAIImagesCapabilityNative)) +} + +type openAIImageTestSSEEvent struct { + Name string + Data string +} + +func parseOpenAIImageTestSSEEvents(body string) []openAIImageTestSSEEvent { + chunks := strings.Split(body, "\n\n") + events := make([]openAIImageTestSSEEvent, 0, len(chunks)) + for _, chunk := range chunks { + chunk = strings.TrimSpace(chunk) + if chunk == "" { + continue + } + var event openAIImageTestSSEEvent + for _, line := range strings.Split(chunk, "\n") { + switch { + case strings.HasPrefix(line, "event: "): + event.Name = strings.TrimSpace(strings.TrimPrefix(line, "event: ")) + case strings.HasPrefix(line, "data: "): + event.Data = strings.TrimSpace(strings.TrimPrefix(line, "data: ")) + } + } + if event.Name != "" || event.Data != "" { + events = append(events, event) + } + } + return events +} + +func findOpenAIImageTestSSEEvent(events []openAIImageTestSSEEvent, 
name string) (openAIImageTestSSEEvent, bool) { + for _, event := range events { + if event.Name == name { + return event, true + } + } + return openAIImageTestSSEEvent{}, false +} + +func TestOpenAIGatewayServiceForwardImages_OAuthUsesResponsesAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","size":"1024x1024","quality":"high","n":2}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 42}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_123"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000000,\"usage\":{\"input_tokens\":11,\"output_tokens\":22,\"input_tokens_details\":{\"cached_tokens\":3},\"output_tokens_details\":{\"image_tokens\":7}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 1, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + "chatgpt_account_id": "acct-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "gpt-image-2", result.Model) + require.Equal(t, 
"gpt-image-2", result.UpstreamModel) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, 11, result.Usage.InputTokens) + require.Equal(t, 22, result.Usage.OutputTokens) + require.Equal(t, 7, result.Usage.ImageOutputTokens) + + require.NotNil(t, upstream.lastReq) + require.Equal(t, chatgptCodexURL, upstream.lastReq.URL.String()) + require.Equal(t, "chatgpt.com", upstream.lastReq.Host) + require.Equal(t, "application/json", upstream.lastReq.Header.Get("Content-Type")) + require.Equal(t, "text/event-stream", upstream.lastReq.Header.Get("Accept")) + require.Equal(t, "acct-123", upstream.lastReq.Header.Get("chatgpt-account-id")) + require.Equal(t, "responses=experimental", upstream.lastReq.Header.Get("OpenAI-Beta")) + + require.Equal(t, openAIImagesResponsesMainModel, gjson.GetBytes(upstream.lastBody, "model").String()) + require.True(t, gjson.GetBytes(upstream.lastBody, "stream").Bool()) + require.Equal(t, "image_generation", gjson.GetBytes(upstream.lastBody, "tools.0.type").String()) + require.Equal(t, "generate", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) + require.Equal(t, "1024x1024", gjson.GetBytes(upstream.lastBody, "tools.0.size").String()) + require.Equal(t, "high", gjson.GetBytes(upstream.lastBody, "tools.0.quality").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.n").Exists()) + require.Equal(t, "draw a cat", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + + require.Equal(t, http.StatusOK, rec.Code) + require.Equal(t, "gpt-image-2", gjson.Get(rec.Body.String(), "model").String()) + require.Equal(t, "aGVsbG8=", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "draw a cat", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingTransformsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + body 
:= []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000001,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"png\",\"background\":\"auto\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000001,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"auto\",\"output_format\":\"png\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"output_format\":\"png\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 2, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + 
require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_generation.partial_image") + require.True(t, ok) + require.Equal(t, "image_generation.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "png", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000001), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.Equal(t, "png", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "auto", gjson.Get(completed.Data, "background").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, 
"revised_prompt").Exists()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthEditsMultipartUsesResponsesAPI(t *testing.T) { + gin.SetMode(gin.TestMode) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + require.NoError(t, writer.WriteField("model", "gpt-image-2")) + require.NoError(t, writer.WriteField("prompt", "replace background with aurora")) + require.NoError(t, writer.WriteField("input_fidelity", "high")) + require.NoError(t, writer.WriteField("output_format", "webp")) + require.NoError(t, writer.WriteField("quality", "high")) + + imageHeader := make(textproto.MIMEHeader) + imageHeader.Set("Content-Disposition", `form-data; name="image"; filename="source.png"`) + imageHeader.Set("Content-Type", "image/png") + imagePart, err := writer.CreatePart(imageHeader) + require.NoError(t, err) + _, err = imagePart.Write([]byte("png-image-content")) + require.NoError(t, err) + + maskHeader := make(textproto.MIMEHeader) + maskHeader.Set("Content-Disposition", `form-data; name="mask"; filename="mask.png"`) + maskHeader.Set("Content-Type", "image/png") + maskPart, err := writer.CreatePart(maskHeader) + require.NoError(t, err) + _, err = maskPart.Write([]byte("png-mask-content")) + require.NoError(t, err) + + require.NoError(t, writer.Close()) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body.Bytes())) + req.Header.Set("Content-Type", writer.FormDataContentType()) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + c.Set("api_key", &APIKey{ID: 100}) + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body.Bytes()) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_edit_123"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: 
{\"type\":\"response.completed\",\"response\":{\"created_at\":1710000002,\"usage\":{\"input_tokens\":13,\"output_tokens\":21,\"output_tokens_details\":{\"image_tokens\":8}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\",\"quality\":\"high\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 3, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body.Bytes(), parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "gpt-image-2", gjson.GetBytes(upstream.lastBody, "tools.0.model").String()) + require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.False(t, gjson.GetBytes(upstream.lastBody, "tools.0.input_fidelity").Exists()) + require.Equal(t, "webp", gjson.GetBytes(upstream.lastBody, "tools.0.output_format").String()) + require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String(), "data:image/png;base64,")) + require.True(t, strings.HasPrefix(gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String(), "data:image/png;base64,")) + require.Equal(t, "replace background with aurora", gjson.GetBytes(upstream.lastBody, "input.0.content.0.text").String()) + require.Equal(t, "ZWRpdGVk", gjson.Get(rec.Body.String(), "data.0.b64_json").String()) + require.Equal(t, "replace background with aurora", gjson.Get(rec.Body.String(), "data.0.revised_prompt").String()) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthEditsStreamingTransformsEvents(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{ + 
"model":"gpt-image-2", + "prompt":"replace background with aurora", + "images":[{"image_url":"https://example.com/source.png"}], + "mask":{"image_url":"https://example.com/mask.png"}, + "stream":true, + "response_format":"url" + }`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/edits", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000003,\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}]}}\n\n" + + "data: {\"type\":\"response.image_generation_call.partial_image\",\"partial_image_b64\":\"cGFydGlhbA==\",\"partial_image_index\":0,\"output_format\":\"webp\",\"background\":\"transparent\"}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000003,\"usage\":{\"input_tokens\":7,\"output_tokens\":10,\"output_tokens_details\":{\"image_tokens\":5}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"tools\":[{\"type\":\"image_generation\",\"model\":\"gpt-image-2\",\"background\":\"transparent\",\"output_format\":\"webp\",\"quality\":\"high\",\"size\":\"1024x1024\"}],\"output\":[{\"type\":\"image_generation_call\",\"result\":\"ZWRpdGVk\",\"revised_prompt\":\"replace background with aurora\",\"output_format\":\"webp\"}]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 4, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + 
Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, 1, result.ImageCount) + require.Equal(t, "edit", gjson.GetBytes(upstream.lastBody, "tools.0.action").String()) + require.Equal(t, "https://example.com/source.png", gjson.GetBytes(upstream.lastBody, "input.0.content.1.image_url").String()) + require.Equal(t, "https://example.com/mask.png", gjson.GetBytes(upstream.lastBody, "tools.0.input_image_mask.image_url").String()) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + partial, ok := findOpenAIImageTestSSEEvent(events, "image_edit.partial_image") + require.True(t, ok) + require.Equal(t, "image_edit.partial_image", gjson.Get(partial.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(partial.Data, "created_at").Int()) + require.Equal(t, "cGFydGlhbA==", gjson.Get(partial.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,cGFydGlhbA==", gjson.Get(partial.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(partial.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(partial.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(partial.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(partial.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(partial.Data, "background").String()) + + completed, ok := findOpenAIImageTestSSEEvent(events, "image_edit.completed") + require.True(t, ok) + require.Equal(t, "image_edit.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000003), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZWRpdGVk", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/webp;base64,ZWRpdGVk", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", 
gjson.Get(completed.Data, "model").String()) + require.Equal(t, "webp", gjson.Get(completed.Data, "output_format").String()) + require.Equal(t, "high", gjson.Get(completed.Data, "quality").String()) + require.Equal(t, "1024x1024", gjson.Get(completed.Data, "size").String()) + require.Equal(t, "transparent", gjson.Get(completed.Data, "background").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.False(t, gjson.Get(completed.Data, "revised_prompt").Exists()) +} + +func TestBuildOpenAIImagesResponsesRequest_DowngradesMultipleImagesToSingle(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesGenerationsEndpoint, + Model: "gpt-image-2", + Prompt: "draw a cat", + N: 2, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.n").Exists()) + require.Equal(t, "gpt-image-2", gjson.GetBytes(body, "tools.0.model").String()) + require.Equal(t, "draw a cat", gjson.GetBytes(body, "input.0.content.0.text").String()) +} + +func TestBuildOpenAIImagesResponsesRequest_StripsInputFidelity(t *testing.T) { + parsed := &OpenAIImagesRequest{ + Endpoint: openAIImagesEditsEndpoint, + Model: "gpt-image-2", + Prompt: "replace background", + InputFidelity: "high", + InputImageURLs: []string{ + "https://example.com/source.png", + }, + } + + body, err := buildOpenAIImagesResponsesRequest(parsed, "gpt-image-2") + require.NoError(t, err) + require.NotNil(t, body) + require.False(t, gjson.GetBytes(body, "tools.0.input_fidelity").Exists()) + require.Equal(t, "edit", gjson.GetBytes(body, "tools.0.action").String()) +} + +func TestCollectOpenAIImagesFromResponsesBody_FallsBackToOutputItemDone(t *testing.T) { + body := []byte( + "data: {\"type\":\"response.created\",\"response\":{\"created_at\":1710000004}}\n\n" + + "data: 
{\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"aGVsbG8=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\",\"quality\":\"high\"}}\n\n" + + "data: {\"type\":\"response.completed\",\"response\":{\"created_at\":1710000004,\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + ) + + results, createdAt, usageRaw, firstMeta, foundFinal, err := collectOpenAIImagesFromResponsesBody(body) + require.NoError(t, err) + require.True(t, foundFinal) + require.Equal(t, int64(1710000004), createdAt) + require.Len(t, results, 1) + require.Equal(t, "aGVsbG8=", results[0].Result) + require.Equal(t, "draw a cat", results[0].RevisedPrompt) + require.Equal(t, "png", firstMeta.OutputFormat) + require.JSONEq(t, `{"images":1}`, string(usageRaw)) +} + +func TestOpenAIGatewayServiceForwardImages_OAuthStreamingHandlesOutputItemDoneFallback(t *testing.T) { + gin.SetMode(gin.TestMode) + body := []byte(`{"model":"gpt-image-2","prompt":"draw a cat","stream":true,"response_format":"url"}`) + + req := httptest.NewRequest(http.MethodPost, "/v1/images/generations", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = req + + svc := &OpenAIGatewayService{} + parsed, err := svc.ParseOpenAIImagesRequest(c, body) + require.NoError(t, err) + + upstream := &httpUpstreamRecorder{ + resp: &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{ + "Content-Type": []string{"text/event-stream"}, + "X-Request-Id": []string{"req_img_stream_output_item_done"}, + }, + Body: io.NopCloser(strings.NewReader( + "data: {\"type\":\"response.output_item.done\",\"item\":{\"id\":\"ig_123\",\"type\":\"image_generation_call\",\"result\":\"ZmluYWw=\",\"revised_prompt\":\"draw a cat\",\"output_format\":\"png\"}}\n\n" + + "data: 
{\"type\":\"response.completed\",\"response\":{\"created_at\":1710000005,\"usage\":{\"input_tokens\":5,\"output_tokens\":9,\"output_tokens_details\":{\"image_tokens\":4}},\"tool_usage\":{\"image_gen\":{\"images\":1}},\"output\":[]}}\n\n" + + "data: [DONE]\n\n", + )), + }, + } + svc.httpUpstream = upstream + + account := &Account{ + ID: 5, + Name: "openai-oauth", + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "token-123", + }, + } + + result, err := svc.ForwardImages(context.Background(), c, account, body, parsed, "") + require.NoError(t, err) + require.NotNil(t, result) + require.True(t, result.Stream) + require.Equal(t, 1, result.ImageCount) + events := parseOpenAIImageTestSSEEvents(rec.Body.String()) + completed, ok := findOpenAIImageTestSSEEvent(events, "image_generation.completed") + require.True(t, ok) + require.Equal(t, "image_generation.completed", gjson.Get(completed.Data, "type").String()) + require.Equal(t, int64(1710000005), gjson.Get(completed.Data, "created_at").Int()) + require.Equal(t, "ZmluYWw=", gjson.Get(completed.Data, "b64_json").String()) + require.Equal(t, "data:image/png;base64,ZmluYWw=", gjson.Get(completed.Data, "url").String()) + require.Equal(t, "gpt-image-2", gjson.Get(completed.Data, "model").String()) + require.JSONEq(t, `{"images":1}`, gjson.Get(completed.Data, "usage").Raw) + require.NotContains(t, rec.Body.String(), "event: error") +} diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 106ec9f7..91a02901 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -794,6 +794,13 @@ func (s *PricingService) matchOpenAIModel(model string) *LiteLLMModelPricing { } } + // GPT-5.5 回退到 GPT-5.4 定价 + if strings.HasPrefix(model, "gpt-5.5") { + logger.With(zap.String("component", "service.pricing")). 
+ Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4(static)")) + return openAIGPT54FallbackPricing + } + if strings.HasPrefix(model, "gpt-5.4-mini") { logger.With(zap.String("component", "service.pricing")). Info(fmt.Sprintf("[Pricing] OpenAI fallback matched %s -> %s", model, "gpt-5.4-mini(static)")) diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 53581574..4730303f 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -1,8 +1,10 @@ package service import ( + "bytes" "context" "encoding/json" + "fmt" "log/slog" "net/http" "strconv" @@ -23,6 +25,7 @@ type RateLimitService struct { geminiQuotaService *GeminiQuotaService tempUnschedCache TempUnschedCache timeoutCounterCache TimeoutCounterCache + openAI403CounterCache OpenAI403CounterCache settingService *SettingService tokenCacheInvalidator TokenCacheInvalidator usageCacheMu sync.RWMutex @@ -52,6 +55,12 @@ type geminiUsageTotalsBatchProvider interface { const geminiPrecheckCacheTTL = time.Minute +const ( + openAI403CooldownMinutesDefault = 10 + openAI403DisableThreshold = 3 + openAI403CounterWindowMinutes = 180 +) + // NewRateLimitService 创建RateLimitService实例 func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService { return &RateLimitService{ @@ -69,6 +78,11 @@ func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) { s.timeoutCounterCache = cache } +// SetOpenAI403CounterCache 设置 OpenAI 403 连续失败计数器(可选依赖) +func (s *RateLimitService) SetOpenAI403CounterCache(cache OpenAI403CounterCache) { + s.openAI403CounterCache = cache +} + // SetSettingService 设置系统设置服务(可选依赖) func (s *RateLimitService) SetSettingService(settingService *SettingService) { s.settingService = settingService @@ -655,6 +669,30 @@ func (s 
*RateLimitService) handleAuthError(ctx context.Context, account *Account slog.Warn("account_disabled_auth_error", "account_id", account.ID, "error", errorMsg) } +func buildForbiddenErrorMessage(prefix string, upstreamMsg string, responseBody []byte, fallback string) string { + prefix = strings.TrimSpace(prefix) + if prefix != "" && !strings.HasSuffix(prefix, " ") { + prefix += " " + } + + if msg := strings.TrimSpace(upstreamMsg); msg != "" { + return prefix + msg + } + + rawBody := bytes.TrimSpace(responseBody) + if len(rawBody) > 0 { + if json.Valid(rawBody) { + var compact bytes.Buffer + if err := json.Compact(&compact, rawBody); err == nil { + return prefix + truncateForLog(compact.Bytes(), 512) + } + } + return prefix + truncateForLog(rawBody, 512) + } + + return prefix + fallback +} + // handle403 处理 403 Forbidden 错误 // Antigravity 平台区分 validation/violation/generic 三种类型,均 SetError 永久禁用; // 其他平台保持原有 SetError 行为。 @@ -662,15 +700,64 @@ func (s *RateLimitService) handle403(ctx context.Context, account *Account, upst if account.Platform == PlatformAntigravity { return s.handleAntigravity403(ctx, account, upstreamMsg, responseBody) } - // 非 Antigravity 平台:保持原有行为 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg + if account.Platform == PlatformOpenAI { + return s.handleOpenAI403(ctx, account, upstreamMsg, responseBody) } + // 非 Antigravity 平台:保持原有行为 + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) s.handleAuthError(ctx, account, msg) return true } +func (s *RateLimitService) handleOpenAI403(ctx context.Context, account *Account, upstreamMsg string, responseBody []byte) (shouldDisable bool) { + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) + + if 
s.openAI403CounterCache == nil { + s.handleAuthError(ctx, account, msg) + return true + } + + count, err := s.openAI403CounterCache.IncrementOpenAI403Count(ctx, account.ID, openAI403CounterWindowMinutes) + if err != nil { + slog.Warn("openai_403_increment_failed", "account_id", account.ID, "error", err) + s.handleAuthError(ctx, account, msg) + return true + } + + if count >= openAI403DisableThreshold { + msg = fmt.Sprintf("%s | consecutive_403=%d/%d", msg, count, openAI403DisableThreshold) + s.handleAuthError(ctx, account, msg) + return true + } + + until := time.Now().Add(time.Duration(openAI403CooldownMinutesDefault) * time.Minute) + reason := fmt.Sprintf("OpenAI 403 temporary cooldown (%d/%d): %s", count, openAI403DisableThreshold, msg) + if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil { + slog.Warn("openai_403_set_temp_unschedulable_failed", "account_id", account.ID, "error", err) + s.handleAuthError(ctx, account, msg) + return true + } + + slog.Warn( + "openai_403_temp_unschedulable", + "account_id", account.ID, + "until", until, + "count", count, + "threshold", openAI403DisableThreshold, + ) + return true +} + // handleAntigravity403 处理 Antigravity 平台的 403 错误 // validation(需要验证)→ 永久 SetError(需人工去 Google 验证后恢复) // violation(违规封号)→ 永久 SetError(需人工处理) @@ -681,10 +768,12 @@ func (s *RateLimitService) handleAntigravity403(ctx context.Context, account *Ac switch fbType { case forbiddenTypeValidation: // VALIDATION_REQUIRED: 永久禁用,需人工去 Google 验证后手动恢复 - msg := "Validation required (403): account needs Google verification" - if upstreamMsg != "" { - msg = "Validation required (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Validation required (403):", + upstreamMsg, + responseBody, + "account needs Google verification", + ) if validationURL := extractValidationURL(string(responseBody)); validationURL != "" { msg += " | validation_url: " + validationURL } @@ -693,19 +782,23 @@ func (s *RateLimitService) 
handleAntigravity403(ctx context.Context, account *Ac case forbiddenTypeViolation: // 违规封号: 永久禁用,需人工处理 - msg := "Account violation (403): terms of service violation" - if upstreamMsg != "" { - msg = "Account violation (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Account violation (403):", + upstreamMsg, + responseBody, + "terms of service violation", + ) s.handleAuthError(ctx, account, msg) return true default: // 通用 403: 保持原有行为 - msg := "Access forbidden (403): account may be suspended or lack permissions" - if upstreamMsg != "" { - msg = "Access forbidden (403): " + upstreamMsg - } + msg := buildForbiddenErrorMessage( + "Access forbidden (403):", + upstreamMsg, + responseBody, + "account may be suspended or lack permissions", + ) s.handleAuthError(ctx, account, msg) return true } @@ -1221,9 +1314,19 @@ func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) slog.Warn("temp_unsched_cache_delete_failed", "account_id", accountID, "error", err) } } + s.ResetOpenAI403Counter(ctx, accountID) return nil } +func (s *RateLimitService) ResetOpenAI403Counter(ctx context.Context, accountID int64) { + if s == nil || s.openAI403CounterCache == nil || accountID <= 0 { + return + } + if err := s.openAI403CounterCache.ResetOpenAI403Count(ctx, accountID); err != nil { + slog.Warn("openai_403_reset_failed", "account_id", accountID, "error", err) + } +} + // RecoverAccountState 按需恢复账号的可恢复运行时状态。 func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID int64, options AccountRecoveryOptions) (*SuccessfulTestRecoveryResult, error) { account, err := s.accountRepo.GetByID(ctx, accountID) @@ -1250,6 +1353,9 @@ func (s *RateLimitService) RecoverAccountState(ctx context.Context, accountID in } result.ClearedRateLimit = true } + if result.ClearedError || result.ClearedRateLimit { + s.ResetOpenAI403Counter(ctx, accountID) + } return result, nil } diff --git a/backend/internal/service/ratelimit_service_401_test.go 
b/backend/internal/service/ratelimit_service_401_test.go index 9e5e2b0e..73b7849f 100644 --- a/backend/internal/service/ratelimit_service_401_test.go +++ b/backend/internal/service/ratelimit_service_401_test.go @@ -20,6 +20,7 @@ type rateLimitAccountRepoStub struct { updateCredentialsCalls int lastCredentials map[string]any lastErrorMsg string + lastTempReason string } func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error { @@ -30,6 +31,7 @@ func (r *rateLimitAccountRepoStub) SetError(ctx context.Context, id int64, error func (r *rateLimitAccountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error { r.tempCalls++ + r.lastTempReason = reason return nil } @@ -44,6 +46,29 @@ type tokenCacheInvalidatorRecorder struct { err error } +type openAI403CounterCacheStub struct { + counts []int64 + resetCalls []int64 + err error +} + +func (s *openAI403CounterCacheStub) IncrementOpenAI403Count(_ context.Context, _ int64, _ int) (int64, error) { + if s.err != nil { + return 0, s.err + } + if len(s.counts) == 0 { + return 1, nil + } + count := s.counts[0] + s.counts = s.counts[1:] + return count, nil +} + +func (s *openAI403CounterCacheStub) ResetOpenAI403Count(_ context.Context, accountID int64) error { + s.resetCalls = append(s.resetCalls, accountID) + return nil +} + func (r *tokenCacheInvalidatorRecorder) InvalidateToken(ctx context.Context, account *Account) error { r.accounts = append(r.accounts, account) return r.err diff --git a/backend/internal/service/ratelimit_service_403_test.go b/backend/internal/service/ratelimit_service_403_test.go new file mode 100644 index 00000000..2fd11b71 --- /dev/null +++ b/backend/internal/service/ratelimit_service_403_test.go @@ -0,0 +1,64 @@ +//go:build unit + +package service + +import ( + "context" + "net/http" + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" +) + +func 
TestRateLimitService_HandleUpstreamError_OpenAI403FirstHitTempUnschedulable(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + counter := &openAI403CounterCacheStub{counts: []int64{1}} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetOpenAI403CounterCache(counter) + account := &Account{ + ID: 301, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"temporary edge rejection"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 0, repo.setErrorCalls) + require.Equal(t, 1, repo.tempCalls) + require.Contains(t, repo.lastTempReason, "temporary edge rejection") + require.Contains(t, repo.lastTempReason, "(1/3)") +} + +func TestRateLimitService_HandleUpstreamError_OpenAI403ThresholdDisables(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + counter := &openAI403CounterCacheStub{counts: []int64{3}} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + service.SetOpenAI403CounterCache(counter) + account := &Account{ + ID: 302, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + http.StatusForbidden, + http.Header{}, + []byte(`{"error":{"message":"workspace forbidden by policy"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Equal(t, 0, repo.tempCalls) + require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy") + require.Contains(t, repo.lastErrorMsg, "consecutive_403=3/3") +} diff --git a/backend/internal/service/ratelimit_service_openai_test.go b/backend/internal/service/ratelimit_service_openai_test.go index 89c754c8..619bb773 100644 --- a/backend/internal/service/ratelimit_service_openai_test.go +++ b/backend/internal/service/ratelimit_service_openai_test.go @@ -7,6 +7,9 @@ import ( 
"net/http" "testing" "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/stretchr/testify/require" ) func TestCalculateOpenAI429ResetTime_7dExhausted(t *testing.T) { @@ -259,6 +262,53 @@ func TestNormalizedCodexLimits_OnlyPrimaryData(t *testing.T) { } } +func TestRateLimitService_HandleUpstreamError_403PreservesOriginalUpstreamMessage(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 201, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 403, + http.Header{}, + []byte(`{"error":{"message":"workspace forbidden by policy","type":"invalid_request_error"}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Contains(t, repo.lastErrorMsg, "workspace forbidden by policy") + require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions") +} + +func TestRateLimitService_HandleUpstreamError_403FallsBackToRawBody(t *testing.T) { + repo := &rateLimitAccountRepoStub{} + service := NewRateLimitService(repo, nil, &config.Config{}, nil, nil) + account := &Account{ + ID: 202, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + } + + shouldDisable := service.HandleUpstreamError( + context.Background(), + account, + 403, + http.Header{}, + []byte(`{"error":{"type":"access_denied","details":{"reason":"ip_blocked"}}}`), + ) + + require.True(t, shouldDisable) + require.Equal(t, 1, repo.setErrorCalls) + require.Contains(t, repo.lastErrorMsg, `"access_denied"`) + require.Contains(t, repo.lastErrorMsg, `"ip_blocked"`) + require.NotContains(t, repo.lastErrorMsg, "account may be suspended or lack permissions") +} + func TestNormalizedCodexLimits_OnlySecondaryData(t *testing.T) { // Test when only secondary has data, no window_minutes sUsed := 60.0 diff --git a/backend/internal/service/setting_service.go 
b/backend/internal/service/setting_service.go index 757c4025..c79d8949 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -1167,6 +1167,7 @@ func (s *SettingService) buildSystemSettingsUpdates(ctx context.Context, setting // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64) + updates[SettingKeyDefaultUserRPMLimit] = strconv.Itoa(settings.DefaultUserRPMLimit) defaultSubsJSON, err := json.Marshal(settings.DefaultSubscriptions) if err != nil { return nil, fmt.Errorf("marshal default subscriptions: %w", err) @@ -1538,6 +1539,18 @@ func (s *SettingService) GetDefaultBalance(ctx context.Context) float64 { return s.cfg.Default.UserBalance } +// GetDefaultUserRPMLimit 获取新用户默认 RPM 限制(0 = 不限制)。未配置则返回 0。 +func (s *SettingService) GetDefaultUserRPMLimit(ctx context.Context) int { + value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultUserRPMLimit) + if err != nil || value == "" { + return 0 + } + if v, err := strconv.Atoi(value); err == nil && v >= 0 { + return v + } + return 0 +} + // GetDefaultSubscriptions 获取新用户默认订阅配置列表。 func (s *SettingService) GetDefaultSubscriptions(ctx context.Context) []DefaultSubscriptionSetting { value, err := s.settingRepo.GetValue(ctx, SettingKeyDefaultSubscriptions) @@ -1706,6 +1719,7 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error { SettingKeyOIDCConnectUserInfoUsernamePath: "", SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency), SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64), + SettingKeyDefaultUserRPMLimit: "0", SettingKeyDefaultSubscriptions: "[]", SettingKeyAuthSourceDefaultEmailBalance: "0", SettingKeyAuthSourceDefaultEmailConcurrency: "5", @@ -1822,6 +1836,10 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin 
result.DefaultConcurrency = s.cfg.Default.UserConcurrency } + if rpm, err := strconv.Atoi(settings[SettingKeyDefaultUserRPMLimit]); err == nil && rpm >= 0 { + result.DefaultUserRPMLimit = rpm + } + // 解析浮点数类型 if balance, err := strconv.ParseFloat(settings[SettingKeyDefaultBalance], 64); err == nil { result.DefaultBalance = balance diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index 986579d1..ddd4fff6 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -106,6 +106,7 @@ type SystemSettings struct { DefaultConcurrency int DefaultBalance float64 + DefaultUserRPMLimit int DefaultSubscriptions []DefaultSubscriptionSetting // Model fallback configuration diff --git a/backend/internal/service/user.go b/backend/internal/service/user.go index 9dc13381..f9833611 100644 --- a/backend/internal/service/user.go +++ b/backend/internal/service/user.go @@ -49,6 +49,15 @@ type User struct { BalanceNotifyExtraEmails []NotifyEmailEntry TotalRecharged float64 + // RPMLimit 用户级每分钟请求数上限(0 = 不限制)。仅在所用分组未设置 rpm_limit + // 且该 (用户, 分组) 无 rpm_override 时作为全局兜底生效,计数键 rpm:u:{userID}:{min}。 + RPMLimit int + + // UserGroupRPMOverride 来自 auth cache snapshot 的 (user, group) RPM 覆盖值。 + // nil = 该 API Key 对应的 (user, group) 无 override;非 nil 时 checkRPM 直接使用, + // 避免每请求查 DB。字段不持久化到数据库。 + UserGroupRPMOverride *int + APIKeys []APIKey Subscriptions []UserSubscription } diff --git a/backend/internal/service/user_group_rate.go b/backend/internal/service/user_group_rate.go index 3d221a25..f069eb7e 100644 --- a/backend/internal/service/user_group_rate.go +++ b/backend/internal/service/user_group_rate.go @@ -2,14 +2,16 @@ package service import "context" -// UserGroupRateEntry 分组下用户专属倍率条目 +// UserGroupRateEntry 分组下用户专属倍率/RPM 条目。 +// RateMultiplier 与 RPMOverride 均为指针以支持"未设置"语义(NULL)。 type UserGroupRateEntry struct { - UserID int64 `json:"user_id"` - UserName string `json:"user_name"` - UserEmail string 
`json:"user_email"` - UserNotes string `json:"user_notes"` - UserStatus string `json:"user_status"` - RateMultiplier float64 `json:"rate_multiplier"` + UserID int64 `json:"user_id"` + UserName string `json:"user_name"` + UserEmail string `json:"user_email"` + UserNotes string `json:"user_notes"` + UserStatus string `json:"user_status"` + RateMultiplier *float64 `json:"rate_multiplier,omitempty"` + RPMOverride *int `json:"rpm_override,omitempty"` } // GroupRateMultiplierInput 批量设置分组倍率的输入条目 @@ -18,30 +20,44 @@ type GroupRateMultiplierInput struct { RateMultiplier float64 `json:"rate_multiplier"` } -// UserGroupRateRepository 用户专属分组倍率仓储接口 -// 允许管理员为特定用户设置分组的专属计费倍率,覆盖分组默认倍率 +// GroupRPMOverrideInput 批量设置分组 RPM override 的输入条目。 +// RPMOverride 为 *int 以支持清除(nil)语义。 +type GroupRPMOverrideInput struct { + UserID int64 `json:"user_id"` + RPMOverride *int `json:"rpm_override"` +} + +// UserGroupRateRepository 用户专属分组倍率/RPM 仓储接口。 +// 允许管理员为特定用户设置分组的专属计费倍率与 RPM 上限,覆盖分组默认值。 type UserGroupRateRepository interface { - // GetByUserID 获取用户的所有专属分组倍率 - // 返回 map[groupID]rateMultiplier + // GetByUserID 获取用户所有专属分组 rate_multiplier(仅返回非 NULL 的条目) GetByUserID(ctx context.Context, userID int64) (map[int64]float64, error) - // GetByUserAndGroup 获取用户在特定分组的专属倍率 - // 如果未设置专属倍率,返回 nil + // GetByUserAndGroup 获取用户在特定分组的专属 rate_multiplier(NULL 返回 nil) GetByUserAndGroup(ctx context.Context, userID, groupID int64) (*float64, error) - // GetByGroupID 获取指定分组下所有用户的专属倍率 + // GetRPMOverrideByUserAndGroup 获取用户在特定分组的 rpm_override(NULL 返回 nil) + GetRPMOverrideByUserAndGroup(ctx context.Context, userID, groupID int64) (*int, error) + + // GetByGroupID 获取指定分组下所有用户的专属配置(rate 与 rpm_override 任一非 NULL 即返回) GetByGroupID(ctx context.Context, groupID int64) ([]UserGroupRateEntry, error) - // SyncUserGroupRates 同步用户的分组专属倍率 - // rates: map[groupID]*rateMultiplier,nil 表示删除该分组的专属倍率 + // SyncUserGroupRates 同步用户的分组专属倍率;nil 表示清空该分组的 rate_multiplier SyncUserGroupRates(ctx context.Context, userID int64, rates 
map[int64]*float64) error - // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组数据) + // SyncGroupRateMultipliers 批量同步分组的用户专属倍率(替换整组 rate 部分) SyncGroupRateMultipliers(ctx context.Context, groupID int64, entries []GroupRateMultiplierInput) error - // DeleteByGroupID 删除指定分组的所有用户专属倍率(分组删除时调用) + // SyncGroupRPMOverrides 批量同步分组的用户专属 RPM(替换整组 rpm_override 部分)。 + // 条目中 RPMOverride 为 nil 时清空对应行的 rpm_override;非 nil 时 upsert。 + SyncGroupRPMOverrides(ctx context.Context, groupID int64, entries []GroupRPMOverrideInput) error + + // ClearGroupRPMOverrides 清空指定分组的所有 rpm_override(整组 rpm 部分归 NULL) + ClearGroupRPMOverrides(ctx context.Context, groupID int64) error + + // DeleteByGroupID 删除指定分组的所有用户专属条目(分组删除时调用) DeleteByGroupID(ctx context.Context, groupID int64) error - // DeleteByUserID 删除指定用户的所有专属倍率(用户删除时调用) + // DeleteByUserID 删除指定用户的所有专属条目(用户删除时调用) DeleteByUserID(ctx context.Context, userID int64) error } diff --git a/backend/internal/service/user_rpm_cache.go b/backend/internal/service/user_rpm_cache.go new file mode 100644 index 00000000..b8857311 --- /dev/null +++ b/backend/internal/service/user_rpm_cache.go @@ -0,0 +1,25 @@ +package service + +import "context" + +// UserRPMCache 用户/分组级 RPM 计数器接口。 +// +// 与账号级 RPMCache 的区别: +// - RPMCache —— 按外部 AI provider 账号聚合(key: rpm:{accountID}:{min})。 +// - UserRPMCache —— 按用户或 (用户, 分组) 聚合,杜绝"同一用户创建多个 API Key 绕过 RPM"的路径。 +// key 形如 rpm:ug:{userID}:{groupID}:{min} 或 rpm:u:{userID}:{min}。 +type UserRPMCache interface { + // IncrementUserGroupRPM 原子递增 (user, group) 级分钟计数并返回最新值。 + // 用于分组 rpm_limit 与 user-group rpm_override 两种命中分支。 + IncrementUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error) + + // IncrementUserRPM 原子递增用户级分钟计数并返回最新值。 + // 用于用户全局 rpm_limit 兜底分支(分组未设且无 override 时)。 + IncrementUserRPM(ctx context.Context, userID int64) (count int, err error) + + // GetUserGroupRPM 获取 (user, group) 当前分钟已用 RPM(只读,不递增)。 + GetUserGroupRPM(ctx context.Context, userID, groupID int64) (count int, err error) + + // GetUserRPM 
获取用户当前分钟已用 RPM(只读,不递增)。 + GetUserRPM(ctx context.Context, userID int64) (count int, err error) +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index ab2802fd..86bfc327 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -39,6 +39,11 @@ func ProvideEmailQueueService(emailService *EmailService) *EmailQueueService { return NewEmailQueueService(emailService, 3) } +// ProvideOAuthRefreshAPI creates OAuthRefreshAPI with the default lock TTL. +func ProvideOAuthRefreshAPI(accountRepo AccountRepository, tokenCache GeminiTokenCache) *OAuthRefreshAPI { + return NewOAuthRefreshAPI(accountRepo, tokenCache) +} + // ProvideTokenRefreshService creates and starts TokenRefreshService func ProvideTokenRefreshService( accountRepo AccountRepository, @@ -210,11 +215,13 @@ func ProvideRateLimitService( geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache, timeoutCounterCache TimeoutCounterCache, + openAI403CounterCache OpenAI403CounterCache, settingService *SettingService, tokenCacheInvalidator TokenCacheInvalidator, ) *RateLimitService { svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache) svc.SetTimeoutCounterCache(timeoutCounterCache) + svc.SetOpenAI403CounterCache(openAI403CounterCache) svc.SetSettingService(settingService) svc.SetTokenCacheInvalidator(tokenCacheInvalidator) return svc @@ -384,6 +391,19 @@ func ProvideSettingService(settingRepo SettingRepository, groupRepo GroupReposit return svc } +// ProvideBillingCacheService wires BillingCacheService with its RPM dependencies. 
+func ProvideBillingCacheService( + cache BillingCache, + userRepo UserRepository, + subRepo UserSubscriptionRepository, + apiKeyRepo APIKeyRepository, + rpmCache UserRPMCache, + rateRepo UserGroupRateRepository, + cfg *config.Config, +) *BillingCacheService { + return NewBillingCacheService(cache, userRepo, subRepo, apiKeyRepo, rpmCache, rateRepo, cfg) +} + // ProviderSet is the Wire provider set for all services var ProviderSet = wire.NewSet( // Core services @@ -400,7 +420,7 @@ var ProviderSet = wire.NewSet( NewDashboardService, ProvidePricingService, NewBillingService, - NewBillingCacheService, + ProvideBillingCacheService, NewAnnouncementService, NewAdminService, NewGatewayService, @@ -412,7 +432,7 @@ var ProviderSet = wire.NewSet( NewCompositeTokenCacheInvalidator, wire.Bind(new(TokenCacheInvalidator), new(*CompositeTokenCacheInvalidator)), NewAntigravityOAuthService, - NewOAuthRefreshAPI, + ProvideOAuthRefreshAPI, ProvideGeminiTokenProvider, NewGeminiMessagesCompatService, ProvideAntigravityTokenProvider, diff --git a/backend/migrations/125_add_group_rpm_limit.sql b/backend/migrations/125_add_group_rpm_limit.sql new file mode 100644 index 00000000..fbde1b20 --- /dev/null +++ b/backend/migrations/125_add_group_rpm_limit.sql @@ -0,0 +1,7 @@ +-- Add per-group Requests-Per-Minute limit. +-- rpm_limit: 分组统一 RPM 上限(0 = 不限制)。 +-- 一旦配置即接管该用户在该分组的限流,覆盖用户级 users.rpm_limit。 +-- 计数键:rpm:ug:{user_id}:{group_id}:{minute}。 +ALTER TABLE groups ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0; + +COMMENT ON COLUMN groups.rpm_limit IS '分组 RPM 上限;0 表示不限制;设置后接管该分组用户的限流(覆盖用户级 rpm_limit)。'; diff --git a/backend/migrations/126_add_user_rpm_limit.sql b/backend/migrations/126_add_user_rpm_limit.sql new file mode 100644 index 00000000..64a8b977 --- /dev/null +++ b/backend/migrations/126_add_user_rpm_limit.sql @@ -0,0 +1,7 @@ +-- Add per-user Requests-Per-Minute cap. 
+-- rpm_limit: 用户全局 RPM 兜底(0 = 不限制)。 +-- 仅当所访问分组未设置 rpm_limit 且无 user-group rpm_override 时作为兜底生效。 +-- 计数键:rpm:u:{user_id}:{minute}。 +ALTER TABLE users ADD COLUMN IF NOT EXISTS rpm_limit integer NOT NULL DEFAULT 0; + +COMMENT ON COLUMN users.rpm_limit IS '用户级 RPM 兜底上限;0 表示不限制;仅当分组未设置 rpm_limit 时生效。'; diff --git a/backend/migrations/127_add_user_group_rpm_override.sql b/backend/migrations/127_add_user_group_rpm_override.sql new file mode 100644 index 00000000..1d674258 --- /dev/null +++ b/backend/migrations/127_add_user_group_rpm_override.sql @@ -0,0 +1,16 @@ +-- 在已有的"用户专属分组倍率表"上扩展 rpm_override 列;同时放宽 rate_multiplier 为可空, +-- 使一行记录可以只覆盖 rate、只覆盖 rpm,或同时覆盖两者。 +-- 语义: +-- - rate_multiplier NULL → 该用户在此分组使用 groups.rate_multiplier 默认值 +-- - rate_multiplier 非 NULL → 覆盖分组默认计费倍率 +-- - rpm_override NULL → 该用户在此分组使用 groups.rpm_limit 默认值 +-- - rpm_override 非 NULL → 覆盖分组默认 RPM(0 = 不限制) +-- 用户级 users.rpm_limit 仍独立生效(跨分组总配额)。 +ALTER TABLE user_group_rate_multipliers + ADD COLUMN IF NOT EXISTS rpm_override integer NULL; + +ALTER TABLE user_group_rate_multipliers + ALTER COLUMN rate_multiplier DROP NOT NULL; + +COMMENT ON COLUMN user_group_rate_multipliers.rate_multiplier IS '专属计费倍率;NULL 表示沿用分组默认倍率。'; +COMMENT ON COLUMN user_group_rate_multipliers.rpm_override IS '专属 RPM 上限;NULL 表示沿用分组默认;0 表示该用户在此分组不受 RPM 限制。'; diff --git a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 8739d5cb..6b94b799 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -164,7 +164,8 @@ export interface GroupRateMultiplierEntry { user_email: string user_notes: string user_status: string - rate_multiplier: number + rate_multiplier?: number | null + rpm_override?: number | null } /** @@ -205,9 +206,7 @@ export async function clearGroupRateMultipliers(id: number): Promise<{ message: /** * Batch set rate multipliers for users in a group - * @param id - Group ID - * @param entries - Array of { user_id, rate_multiplier } - * @returns Success confirmation + 
* Only touches rate_multiplier column; preserves rpm_override on existing rows. */ export async function batchSetGroupRateMultipliers( id: number, @@ -220,6 +219,60 @@ return data } +/** + * RPM override entry for a user in a group + */ +export interface GroupRPMOverrideEntry { + user_id: number + user_name: string + user_email: string + user_notes: string + user_status: string + rpm_override: number +} + +/** + * Get RPM overrides for users in a group (subset of rate-multipliers endpoint). + */ +export async function getGroupRPMOverrides(id: number): Promise<GroupRPMOverrideEntry[]> { + const { data } = await apiClient.get<GroupRateMultiplierEntry[]>( + `/admin/groups/${id}/rate-multipliers` + ) + return data + .filter(e => e.rpm_override != null) + .map(e => ({ + user_id: e.user_id, + user_name: e.user_name, + user_email: e.user_email, + user_notes: e.user_notes, + user_status: e.user_status, + rpm_override: e.rpm_override as number + })) +} + +/** + * Batch set RPM overrides for users in a group. + * Only touches rpm_override column; preserves rate_multiplier on existing rows. + */ +export async function batchSetGroupRPMOverrides( + id: number, + entries: Array<{ user_id: number; rpm_override: number }> +): Promise<{ message: string }> { + const { data } = await apiClient.put<{ message: string }>( + `/admin/groups/${id}/rpm-overrides`, + { entries } + ) + return data +} + +/** + * Clear all RPM overrides for a group (preserves rate_multiplier). + */ +export async function clearGroupRPMOverrides(id: number): Promise<{ message: string }> { + const { data } = await apiClient.delete<{ message: string }>(`/admin/groups/${id}/rpm-overrides`) + return data +} + /** * Get usage summary (today + cumulative cost) for all groups * @param timezone - IANA timezone string (e.g.
"Asia/Shanghai") @@ -262,6 +315,9 @@ export const groupsAPI = { getGroupRateMultipliers, clearGroupRateMultipliers, batchSetGroupRateMultipliers, + getGroupRPMOverrides, + clearGroupRPMOverrides, + batchSetGroupRPMOverrides, updateSortOrder, getUsageSummary, getCapacitySummary diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index a9208950..b9f24663 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -309,6 +309,7 @@ export interface SystemSettings { // Default settings default_balance: number; default_concurrency: number; + default_user_rpm_limit: number; default_subscriptions: DefaultSubscriptionSetting[]; auth_source_default_email_balance?: number; auth_source_default_email_concurrency?: number; @@ -489,6 +490,7 @@ export interface UpdateSettingsRequest { totp_enabled?: boolean; // TOTP 双因素认证 default_balance?: number; default_concurrency?: number; + default_user_rpm_limit?: number; default_subscriptions?: DefaultSubscriptionSetting[]; auth_source_default_email_balance?: number; auth_source_default_email_concurrency?: number; diff --git a/frontend/src/components/admin/group/GroupRPMOverridesModal.vue b/frontend/src/components/admin/group/GroupRPMOverridesModal.vue new file mode 100644 index 00000000..a4b4e536 --- /dev/null +++ b/frontend/src/components/admin/group/GroupRPMOverridesModal.vue @@ -0,0 +1,434 @@ + + + + + diff --git a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue index 41b2e63c..d68f3aa5 100644 --- a/frontend/src/components/admin/group/GroupRateMultipliersModal.vue +++ b/frontend/src/components/admin/group/GroupRateMultipliersModal.vue @@ -168,7 +168,8 @@ step="0.001" min="0.001" autocomplete="off" - :value="entry.rate_multiplier" + :value="entry.rate_multiplier ?? ''" + :placeholder="String(props.group?.rate_multiplier ?? 
1)" class="hide-spinner w-20 rounded border border-gray-200 bg-white px-2 py-1 text-center text-sm font-medium transition-colors focus:border-primary-500 focus:outline-none focus:ring-1 focus:ring-primary-500/20 dark:border-dark-500 dark:bg-dark-700 dark:focus:border-primary-500" @change="updateLocalRate(entry.user_id, ($event.target as HTMLInputElement).value)" /> @@ -294,19 +295,17 @@ const showFinalRate = computed(() => { }) // 计算最终倍率预览 -const computeFinalRate = (rate: number) => { - if (!batchFactor.value) return rate - return parseFloat((rate * batchFactor.value).toFixed(6)) +const computeFinalRate = (rate: number | null | undefined) => { + const base = rate ?? props.group?.rate_multiplier ?? 1 + if (!batchFactor.value) return base + return parseFloat((base * batchFactor.value).toFixed(6)) } // 检测是否有未保存的修改 const isDirty = computed(() => { if (localEntries.value.length !== serverEntries.value.length) return true - const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier])) - return localEntries.value.some(e => { - const serverRate = serverMap.get(e.user_id) - return serverRate === undefined || serverRate !== e.rate_multiplier - }) + const serverMap = new Map(serverEntries.value.map(e => [e.user_id, e.rate_multiplier ?? null])) + return localEntries.value.some(e => serverMap.get(e.user_id) !== (e.rate_multiplier ?? 
null)) }) const paginatedLocalEntries = computed(() => { @@ -322,7 +321,9 @@ const loadEntries = async () => { if (!props.group) return loading.value = true try { - serverEntries.value = await adminAPI.groups.getGroupRateMultipliers(props.group.id) + const raw = await adminAPI.groups.getGroupRateMultipliers(props.group.id) + // 仅显示已设置 rate_multiplier 的条目;rpm_override 在另一个弹窗管理,保留不动 + serverEntries.value = raw.filter(e => e.rate_multiplier != null) localEntries.value = cloneEntries(serverEntries.value) adjustPage() } catch (error) { @@ -394,7 +395,8 @@ const handleAddLocal = () => { user_email: user.email, user_notes: user.notes || '', user_status: user.status || 'active', - rate_multiplier: newRate.value + rate_multiplier: newRate.value, + rpm_override: null } if (idx >= 0) { localEntries.value[idx] = entry @@ -409,12 +411,15 @@ const handleAddLocal = () => { // 本地修改倍率 const updateLocalRate = (userId: number, value: string) => { + const entry = localEntries.value.find(e => e.user_id === userId) + if (!entry) return + if (value.trim() === '') { + entry.rate_multiplier = null + return + } const num = parseFloat(value) if (isNaN(num)) return - const entry = localEntries.value.find(e => e.user_id === userId) - if (entry) { - entry.rate_multiplier = num - } + entry.rate_multiplier = num } // 本地删除 @@ -427,7 +432,9 @@ const removeLocal = (userId: number) => { const applyBatchFactor = () => { if (!batchFactor.value || batchFactor.value <= 0) return for (const entry of localEntries.value) { - entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6)) + if (entry.rate_multiplier != null) { + entry.rate_multiplier = parseFloat((entry.rate_multiplier * batchFactor.value).toFixed(6)) + } } batchFactor.value = null } @@ -444,15 +451,17 @@ const handleCancel = () => { adjustPage() } -// 保存:一次性提交所有数据 +// 保存:一次性提交所有数据(只提交 rate_multiplier;rpm_override 由独立弹窗管理) const handleSave = async () => { if (!props.group) return saving.value = true try { - const 
entries = localEntries.value.map(e => ({ - user_id: e.user_id, - rate_multiplier: e.rate_multiplier - })) + const entries = localEntries.value + .filter(e => e.rate_multiplier != null) + .map(e => ({ + user_id: e.user_id, + rate_multiplier: e.rate_multiplier as number + })) await adminAPI.groups.batchSetGroupRateMultipliers(props.group.id, entries) appStore.showSuccess(t('admin.groups.rateSaved')) emit('success') diff --git a/frontend/src/components/admin/user/UserCreateModal.vue b/frontend/src/components/admin/user/UserCreateModal.vue index 0e44d81e..2966a23b 100644 --- a/frontend/src/components/admin/user/UserCreateModal.vue +++ b/frontend/src/components/admin/user/UserCreateModal.vue @@ -35,6 +35,18 @@ +
+ + +

{{ t('admin.users.form.rpmLimitHint') }}

+