diff --git a/backend/cmd/server/main.go b/backend/cmd/server/main.go
index c9dc57bb..f8a7d313 100644
--- a/backend/cmd/server/main.go
+++ b/backend/cmd/server/main.go
@@ -8,6 +8,7 @@ import (
 	"errors"
 	"flag"
 	"log"
+	"log/slog"
 	"net/http"
 	"os"
 	"os/signal"
@@ -44,7 +45,25 @@ func init() {
 	}
 }
 
+// initLogger configures the default slog handler based on gin.Mode().
+// In non-release mode, Debug level logs are enabled.
+func initLogger() {
+	var level slog.Level
+	if gin.Mode() == gin.ReleaseMode {
+		level = slog.LevelInfo
+	} else {
+		level = slog.LevelDebug
+	}
+	handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
+		Level: level,
+	})
+	slog.SetDefault(slog.New(handler))
+}
+
 func main() {
+	// Initialize slog logger based on gin mode
+	initLogger()
+
 	// Parse command line flags
 	setupMode := flag.Bool("setup", false, "Run setup wizard in CLI mode")
 	showVersion := flag.Bool("version", false, "Show version information")
diff --git a/backend/cmd/server/wire.go b/backend/cmd/server/wire.go
index 0a5f9744..5ef04a66 100644
--- a/backend/cmd/server/wire.go
+++ b/backend/cmd/server/wire.go
@@ -70,6 +70,7 @@ func provideCleanup(
 	schedulerSnapshot *service.SchedulerSnapshotService,
 	tokenRefresh *service.TokenRefreshService,
 	accountExpiry *service.AccountExpiryService,
+	usageCleanup *service.UsageCleanupService,
 	pricing *service.PricingService,
 	emailQueue *service.EmailQueueService,
 	billingCache *service.BillingCacheService,
@@ -123,6 +124,12 @@ func provideCleanup(
 			}
 			return nil
 		}},
+		{"UsageCleanupService", func() error {
+			if usageCleanup != nil {
+				usageCleanup.Stop()
+			}
+			return nil
+		}},
 		{"TokenRefreshService", func() error {
 			tokenRefresh.Stop()
 			return nil
diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go
index 31e47332..a12d3790 100644
--- a/backend/cmd/server/wire_gen.go
+++ b/backend/cmd/server/wire_gen.go
@@ -105,21 +105,22 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
 	compositeTokenCacheInvalidator := service.NewCompositeTokenCacheInvalidator(geminiTokenCache)
 	rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService, compositeTokenCacheInvalidator)
-	claudeUsageFetcher := repository.NewClaudeUsageFetcher()
+	httpUpstream := repository.NewHTTPUpstream(configConfig)
+	claudeUsageFetcher := repository.NewClaudeUsageFetcher(httpUpstream)
 	antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
 	usageCache := service.NewUsageCache()
-	accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache)
+	identityCache := repository.NewIdentityCache(redisClient)
+	accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService, antigravityQuotaFetcher, usageCache, identityCache)
 	geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
 	gatewayCache := repository.NewGatewayCache(redisClient)
 	antigravityTokenProvider := service.NewAntigravityTokenProvider(accountRepository, geminiTokenCache, antigravityOAuthService)
-	httpUpstream := repository.NewHTTPUpstream(configConfig)
 	antigravityGatewayService := service.NewAntigravityGatewayService(accountRepository, gatewayCache, antigravityTokenProvider, rateLimitService, httpUpstream, settingService)
 	accountTestService := service.NewAccountTestService(accountRepository, geminiTokenProvider, antigravityGatewayService, httpUpstream, configConfig)
 	concurrencyCache := repository.ProvideConcurrencyCache(redisClient, configConfig)
 	concurrencyService := service.ProvideConcurrencyService(concurrencyCache, accountRepository, configConfig)
 	crsSyncService := service.NewCRSSyncService(accountRepository, proxyRepository, oAuthService, openAIOAuthService, geminiOAuthService, configConfig)
 	sessionLimitCache := repository.ProvideSessionLimitCache(redisClient, configConfig)
-	accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache)
+	accountHandler := admin.NewAccountHandler(adminService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, rateLimitService, accountUsageService, accountTestService, concurrencyService, crsSyncService, sessionLimitCache, compositeTokenCacheInvalidator)
 	oAuthHandler := admin.NewOAuthHandler(oAuthService)
 	openAIOAuthHandler := admin.NewOpenAIOAuthHandler(openAIOAuthService, adminService)
 	geminiOAuthHandler := admin.NewGeminiOAuthHandler(geminiOAuthService)
@@ -137,7 +138,6 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 		return nil, err
 	}
 	billingService := service.NewBillingService(configConfig, pricingService)
-	identityCache := repository.NewIdentityCache(redisClient)
 	identityService := service.NewIdentityService(identityCache)
 	deferredService := service.ProvideDeferredService(accountRepository, timingWheelService)
 	claudeTokenProvider := service.NewClaudeTokenProvider(accountRepository, geminiTokenCache, oAuthService)
@@ -154,7 +154,9 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	updateService := service.ProvideUpdateService(updateCache, gitHubReleaseClient, serviceBuildInfo)
 	systemHandler := handler.ProvideSystemHandler(updateService)
 	adminSubscriptionHandler := admin.NewSubscriptionHandler(subscriptionService)
-	adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService)
+	usageCleanupRepository := repository.NewUsageCleanupRepository(client, db)
+	usageCleanupService := service.ProvideUsageCleanupService(usageCleanupRepository, timingWheelService, dashboardAggregationService, configConfig)
+	adminUsageHandler := admin.NewUsageHandler(usageService, apiKeyService, adminService, usageCleanupService)
 	userAttributeDefinitionRepository := repository.NewUserAttributeDefinitionRepository(client)
 	userAttributeValueRepository := repository.NewUserAttributeValueRepository(client)
 	userAttributeService := service.NewUserAttributeService(userAttributeDefinitionRepository, userAttributeValueRepository)
@@ -176,7 +178,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
 	opsScheduledReportService := service.ProvideOpsScheduledReportService(opsService, userService, emailService, redisClient, configConfig)
 	tokenRefreshService := service.ProvideTokenRefreshService(accountRepository, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService, compositeTokenCacheInvalidator, configConfig)
 	accountExpiryService := service.ProvideAccountExpiryService(accountRepository)
-	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
+	v := provideCleanup(client, redisClient, opsMetricsCollector, opsAggregationService, opsAlertEvaluatorService, opsCleanupService, opsScheduledReportService, schedulerSnapshotService, tokenRefreshService, accountExpiryService, usageCleanupService, pricingService, emailQueueService, billingCacheService, oAuthService, openAIOAuthService, geminiOAuthService, antigravityOAuthService)
 	application := &Application{
 		Server: httpServer,
 		Cleanup: v,
@@ -209,6 +211,7 @@ func provideCleanup(
 	schedulerSnapshot *service.SchedulerSnapshotService,
 	tokenRefresh *service.TokenRefreshService,
 	accountExpiry *service.AccountExpiryService,
+	usageCleanup *service.UsageCleanupService,
 	pricing *service.PricingService,
 	emailQueue *service.EmailQueueService,
 	billingCache *service.BillingCacheService,
@@ -261,6 +264,12 @@ func provideCleanup(
 			}
 			return nil
 		}},
+		{"UsageCleanupService", func() error {
+			if usageCleanup != nil {
+				usageCleanup.Stop()
+			}
+			return nil
+		}},
 		{"TokenRefreshService", func() error {
 			tokenRefresh.Stop()
 			return nil
diff --git a/backend/ent/client.go b/backend/ent/client.go
index 35cf644f..f6c13e84 100644
--- a/backend/ent/client.go
+++ b/backend/ent/client.go
@@ -24,6 +24,7 @@ import (
 	"github.com/Wei-Shaw/sub2api/ent/proxy"
 	"github.com/Wei-Shaw/sub2api/ent/redeemcode"
 	"github.com/Wei-Shaw/sub2api/ent/setting"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
 	"github.com/Wei-Shaw/sub2api/ent/usagelog"
 	"github.com/Wei-Shaw/sub2api/ent/user"
 	"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -57,6 +58,8 @@ type Client struct {
 	RedeemCode *RedeemCodeClient
 	// Setting is the client for interacting with the Setting builders.
 	Setting *SettingClient
+	// UsageCleanupTask is the client for interacting with the UsageCleanupTask builders.
+	UsageCleanupTask *UsageCleanupTaskClient
 	// UsageLog is the client for interacting with the UsageLog builders.
 	UsageLog *UsageLogClient
 	// User is the client for interacting with the User builders.
@@ -89,6 +92,7 @@ func (c *Client) init() {
 	c.Proxy = NewProxyClient(c.config)
 	c.RedeemCode = NewRedeemCodeClient(c.config)
 	c.Setting = NewSettingClient(c.config)
+	c.UsageCleanupTask = NewUsageCleanupTaskClient(c.config)
 	c.UsageLog = NewUsageLogClient(c.config)
 	c.User = NewUserClient(c.config)
 	c.UserAllowedGroup = NewUserAllowedGroupClient(c.config)
@@ -196,6 +200,7 @@ func (c *Client) Tx(ctx context.Context) (*Tx, error) {
 		Proxy: NewProxyClient(cfg),
 		RedeemCode: NewRedeemCodeClient(cfg),
 		Setting: NewSettingClient(cfg),
+		UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
 		UsageLog: NewUsageLogClient(cfg),
 		User: NewUserClient(cfg),
 		UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -230,6 +235,7 @@ func (c *Client) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
 		Proxy: NewProxyClient(cfg),
 		RedeemCode: NewRedeemCodeClient(cfg),
 		Setting: NewSettingClient(cfg),
+		UsageCleanupTask: NewUsageCleanupTaskClient(cfg),
 		UsageLog: NewUsageLogClient(cfg),
 		User: NewUserClient(cfg),
 		UserAllowedGroup: NewUserAllowedGroupClient(cfg),
@@ -266,8 +272,9 @@ func (c *Client) Close() error {
 func (c *Client) Use(hooks ...Hook) {
 	for _, n := range []interface{ Use(...Hook) }{
 		c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
-		c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
-		c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
+		c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
+		c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+		c.UserSubscription,
 	} {
 		n.Use(hooks...)
 	}
@@ -278,8 +285,9 @@ func (c *Client) Use(hooks ...Hook) {
 func (c *Client) Intercept(interceptors ...Interceptor) {
 	for _, n := range []interface{ Intercept(...Interceptor) }{
 		c.APIKey, c.Account, c.AccountGroup, c.Group, c.PromoCode, c.PromoCodeUsage,
-		c.Proxy, c.RedeemCode, c.Setting, c.UsageLog, c.User, c.UserAllowedGroup,
-		c.UserAttributeDefinition, c.UserAttributeValue, c.UserSubscription,
+		c.Proxy, c.RedeemCode, c.Setting, c.UsageCleanupTask, c.UsageLog, c.User,
+		c.UserAllowedGroup, c.UserAttributeDefinition, c.UserAttributeValue,
+		c.UserSubscription,
 	} {
 		n.Intercept(interceptors...)
 	}
@@ -306,6 +314,8 @@ func (c *Client) Mutate(ctx context.Context, m Mutation) (Value, error) {
 		return c.RedeemCode.mutate(ctx, m)
 	case *SettingMutation:
 		return c.Setting.mutate(ctx, m)
+	case *UsageCleanupTaskMutation:
+		return c.UsageCleanupTask.mutate(ctx, m)
 	case *UsageLogMutation:
 		return c.UsageLog.mutate(ctx, m)
 	case *UserMutation:
@@ -1847,6 +1857,139 @@ func (c *SettingClient) mutate(ctx context.Context, m *SettingMutation) (Value,
 	}
 }
 
+// UsageCleanupTaskClient is a client for the UsageCleanupTask schema.
+type UsageCleanupTaskClient struct {
+	config
+}
+
+// NewUsageCleanupTaskClient returns a client for the UsageCleanupTask from the given config.
+func NewUsageCleanupTaskClient(c config) *UsageCleanupTaskClient {
+	return &UsageCleanupTaskClient{config: c}
+}
+
+// Use adds a list of mutation hooks to the hooks stack.
+// A call to `Use(f, g, h)` equals to `usagecleanuptask.Hooks(f(g(h())))`.
+func (c *UsageCleanupTaskClient) Use(hooks ...Hook) {
+	c.hooks.UsageCleanupTask = append(c.hooks.UsageCleanupTask, hooks...)
+}
+
+// Intercept adds a list of query interceptors to the interceptors stack.
+// A call to `Intercept(f, g, h)` equals to `usagecleanuptask.Intercept(f(g(h())))`.
+func (c *UsageCleanupTaskClient) Intercept(interceptors ...Interceptor) {
+	c.inters.UsageCleanupTask = append(c.inters.UsageCleanupTask, interceptors...)
+}
+
+// Create returns a builder for creating a UsageCleanupTask entity.
+func (c *UsageCleanupTaskClient) Create() *UsageCleanupTaskCreate {
+	mutation := newUsageCleanupTaskMutation(c.config, OpCreate)
+	return &UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// CreateBulk returns a builder for creating a bulk of UsageCleanupTask entities.
+func (c *UsageCleanupTaskClient) CreateBulk(builders ...*UsageCleanupTaskCreate) *UsageCleanupTaskCreateBulk {
+	return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
+}
+
+// MapCreateBulk creates a bulk creation builder from the given slice. For each item in the slice, the function creates
+// a builder and applies setFunc on it.
+func (c *UsageCleanupTaskClient) MapCreateBulk(slice any, setFunc func(*UsageCleanupTaskCreate, int)) *UsageCleanupTaskCreateBulk {
+	rv := reflect.ValueOf(slice)
+	if rv.Kind() != reflect.Slice {
+		return &UsageCleanupTaskCreateBulk{err: fmt.Errorf("calling to UsageCleanupTaskClient.MapCreateBulk with wrong type %T, need slice", slice)}
+	}
+	builders := make([]*UsageCleanupTaskCreate, rv.Len())
+	for i := 0; i < rv.Len(); i++ {
+		builders[i] = c.Create()
+		setFunc(builders[i], i)
+	}
+	return &UsageCleanupTaskCreateBulk{config: c.config, builders: builders}
+}
+
+// Update returns an update builder for UsageCleanupTask.
+func (c *UsageCleanupTaskClient) Update() *UsageCleanupTaskUpdate {
+	mutation := newUsageCleanupTaskMutation(c.config, OpUpdate)
+	return &UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOne returns an update builder for the given entity.
+func (c *UsageCleanupTaskClient) UpdateOne(_m *UsageCleanupTask) *UsageCleanupTaskUpdateOne {
+	mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTask(_m))
+	return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// UpdateOneID returns an update builder for the given id.
+func (c *UsageCleanupTaskClient) UpdateOneID(id int64) *UsageCleanupTaskUpdateOne {
+	mutation := newUsageCleanupTaskMutation(c.config, OpUpdateOne, withUsageCleanupTaskID(id))
+	return &UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// Delete returns a delete builder for UsageCleanupTask.
+func (c *UsageCleanupTaskClient) Delete() *UsageCleanupTaskDelete {
+	mutation := newUsageCleanupTaskMutation(c.config, OpDelete)
+	return &UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: mutation}
+}
+
+// DeleteOne returns a builder for deleting the given entity.
+func (c *UsageCleanupTaskClient) DeleteOne(_m *UsageCleanupTask) *UsageCleanupTaskDeleteOne {
+	return c.DeleteOneID(_m.ID)
+}
+
+// DeleteOneID returns a builder for deleting the given entity by its id.
+func (c *UsageCleanupTaskClient) DeleteOneID(id int64) *UsageCleanupTaskDeleteOne {
+	builder := c.Delete().Where(usagecleanuptask.ID(id))
+	builder.mutation.id = &id
+	builder.mutation.op = OpDeleteOne
+	return &UsageCleanupTaskDeleteOne{builder}
+}
+
+// Query returns a query builder for UsageCleanupTask.
+func (c *UsageCleanupTaskClient) Query() *UsageCleanupTaskQuery {
+	return &UsageCleanupTaskQuery{
+		config: c.config,
+		ctx: &QueryContext{Type: TypeUsageCleanupTask},
+		inters: c.Interceptors(),
+	}
+}
+
+// Get returns a UsageCleanupTask entity by its id.
+func (c *UsageCleanupTaskClient) Get(ctx context.Context, id int64) (*UsageCleanupTask, error) {
+	return c.Query().Where(usagecleanuptask.ID(id)).Only(ctx)
+}
+
+// GetX is like Get, but panics if an error occurs.
+func (c *UsageCleanupTaskClient) GetX(ctx context.Context, id int64) *UsageCleanupTask {
+	obj, err := c.Get(ctx, id)
+	if err != nil {
+		panic(err)
+	}
+	return obj
+}
+
+// Hooks returns the client hooks.
+func (c *UsageCleanupTaskClient) Hooks() []Hook {
+	return c.hooks.UsageCleanupTask
+}
+
+// Interceptors returns the client interceptors.
+func (c *UsageCleanupTaskClient) Interceptors() []Interceptor {
+	return c.inters.UsageCleanupTask
+}
+
+func (c *UsageCleanupTaskClient) mutate(ctx context.Context, m *UsageCleanupTaskMutation) (Value, error) {
+	switch m.Op() {
+	case OpCreate:
+		return (&UsageCleanupTaskCreate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+	case OpUpdate:
+		return (&UsageCleanupTaskUpdate{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+	case OpUpdateOne:
+		return (&UsageCleanupTaskUpdateOne{config: c.config, hooks: c.Hooks(), mutation: m}).Save(ctx)
+	case OpDelete, OpDeleteOne:
+		return (&UsageCleanupTaskDelete{config: c.config, hooks: c.Hooks(), mutation: m}).Exec(ctx)
+	default:
+		return nil, fmt.Errorf("ent: unknown UsageCleanupTask mutation op: %q", m.Op())
+	}
+}
+
 // UsageLogClient is a client for the UsageLog schema.
 type UsageLogClient struct {
 	config
@@ -2974,13 +3117,13 @@ func (c *UserSubscriptionClient) mutate(ctx context.Context, m *UserSubscription
 type (
 	hooks struct {
 		APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
-		RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
-		UserAttributeValue, UserSubscription []ent.Hook
+		RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+		UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Hook
 	}
 	inters struct {
 		APIKey, Account, AccountGroup, Group, PromoCode, PromoCodeUsage, Proxy,
-		RedeemCode, Setting, UsageLog, User, UserAllowedGroup, UserAttributeDefinition,
-		UserAttributeValue, UserSubscription []ent.Interceptor
+		RedeemCode, Setting, UsageCleanupTask, UsageLog, User, UserAllowedGroup,
+		UserAttributeDefinition, UserAttributeValue, UserSubscription []ent.Interceptor
 	}
 )
diff --git a/backend/ent/ent.go b/backend/ent/ent.go
index 410375a7..4bcc2642 100644
--- a/backend/ent/ent.go
+++ b/backend/ent/ent.go
@@ -21,6 +21,7 @@ import (
 	"github.com/Wei-Shaw/sub2api/ent/proxy"
 	"github.com/Wei-Shaw/sub2api/ent/redeemcode"
 	"github.com/Wei-Shaw/sub2api/ent/setting"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
 	"github.com/Wei-Shaw/sub2api/ent/usagelog"
 	"github.com/Wei-Shaw/sub2api/ent/user"
 	"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -96,6 +97,7 @@ func checkColumn(t, c string) error {
 		proxy.Table: proxy.ValidColumn,
 		redeemcode.Table: redeemcode.ValidColumn,
 		setting.Table: setting.ValidColumn,
+		usagecleanuptask.Table: usagecleanuptask.ValidColumn,
 		usagelog.Table: usagelog.ValidColumn,
 		user.Table: user.ValidColumn,
 		userallowedgroup.Table: userallowedgroup.ValidColumn,
diff --git a/backend/ent/hook/hook.go b/backend/ent/hook/hook.go
index 532b0d2c..edd84f5e 100644
--- a/backend/ent/hook/hook.go
+++ b/backend/ent/hook/hook.go
@@ -117,6 +117,18 @@ func (f SettingFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, err
 	return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.SettingMutation", m)
 }
 
+// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary
+// function as UsageCleanupTask mutator.
+type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskMutation) (ent.Value, error)
+
+// Mutate calls f(ctx, m).
+func (f UsageCleanupTaskFunc) Mutate(ctx context.Context, m ent.Mutation) (ent.Value, error) {
+	if mv, ok := m.(*ent.UsageCleanupTaskMutation); ok {
+		return f(ctx, mv)
+	}
+	return nil, fmt.Errorf("unexpected mutation type %T. expect *ent.UsageCleanupTaskMutation", m)
+}
+
 // The UsageLogFunc type is an adapter to allow the use of ordinary
 // function as UsageLog mutator.
 type UsageLogFunc func(context.Context, *ent.UsageLogMutation) (ent.Value, error)
diff --git a/backend/ent/intercept/intercept.go b/backend/ent/intercept/intercept.go
index 765d39b4..f18c0624 100644
--- a/backend/ent/intercept/intercept.go
+++ b/backend/ent/intercept/intercept.go
@@ -18,6 +18,7 @@ import (
 	"github.com/Wei-Shaw/sub2api/ent/proxy"
 	"github.com/Wei-Shaw/sub2api/ent/redeemcode"
 	"github.com/Wei-Shaw/sub2api/ent/setting"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
 	"github.com/Wei-Shaw/sub2api/ent/usagelog"
 	"github.com/Wei-Shaw/sub2api/ent/user"
 	"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -325,6 +326,33 @@ func (f TraverseSetting) Traverse(ctx context.Context, q ent.Query) error {
 	return fmt.Errorf("unexpected query type %T. expect *ent.SettingQuery", q)
 }
 
+// The UsageCleanupTaskFunc type is an adapter to allow the use of ordinary function as a Querier.
+type UsageCleanupTaskFunc func(context.Context, *ent.UsageCleanupTaskQuery) (ent.Value, error)
+
+// Query calls f(ctx, q).
+func (f UsageCleanupTaskFunc) Query(ctx context.Context, q ent.Query) (ent.Value, error) {
+	if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
+		return f(ctx, q)
+	}
+	return nil, fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
+}
+
+// The TraverseUsageCleanupTask type is an adapter to allow the use of ordinary function as Traverser.
+type TraverseUsageCleanupTask func(context.Context, *ent.UsageCleanupTaskQuery) error
+
+// Intercept is a dummy implementation of Intercept that returns the next Querier in the pipeline.
+func (f TraverseUsageCleanupTask) Intercept(next ent.Querier) ent.Querier {
+	return next
+}
+
+// Traverse calls f(ctx, q).
+func (f TraverseUsageCleanupTask) Traverse(ctx context.Context, q ent.Query) error {
+	if q, ok := q.(*ent.UsageCleanupTaskQuery); ok {
+		return f(ctx, q)
+	}
+	return fmt.Errorf("unexpected query type %T. expect *ent.UsageCleanupTaskQuery", q)
+}
+
 // The UsageLogFunc type is an adapter to allow the use of ordinary function as a Querier.
 type UsageLogFunc func(context.Context, *ent.UsageLogQuery) (ent.Value, error)
@@ -508,6 +536,8 @@ func NewQuery(q ent.Query) (Query, error) {
 		return &query[*ent.RedeemCodeQuery, predicate.RedeemCode, redeemcode.OrderOption]{typ: ent.TypeRedeemCode, tq: q}, nil
 	case *ent.SettingQuery:
 		return &query[*ent.SettingQuery, predicate.Setting, setting.OrderOption]{typ: ent.TypeSetting, tq: q}, nil
+	case *ent.UsageCleanupTaskQuery:
+		return &query[*ent.UsageCleanupTaskQuery, predicate.UsageCleanupTask, usagecleanuptask.OrderOption]{typ: ent.TypeUsageCleanupTask, tq: q}, nil
 	case *ent.UsageLogQuery:
 		return &query[*ent.UsageLogQuery, predicate.UsageLog, usagelog.OrderOption]{typ: ent.TypeUsageLog, tq: q}, nil
 	case *ent.UserQuery:
diff --git a/backend/ent/migrate/schema.go b/backend/ent/migrate/schema.go
index b377804f..d1f05186 100644
--- a/backend/ent/migrate/schema.go
+++ b/backend/ent/migrate/schema.go
@@ -434,6 +434,44 @@ var (
 		Columns: SettingsColumns,
 		PrimaryKey: []*schema.Column{SettingsColumns[0]},
 	}
+	// UsageCleanupTasksColumns holds the columns for the "usage_cleanup_tasks" table.
+	UsageCleanupTasksColumns = []*schema.Column{
+		{Name: "id", Type: field.TypeInt64, Increment: true},
+		{Name: "created_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+		{Name: "updated_at", Type: field.TypeTime, SchemaType: map[string]string{"postgres": "timestamptz"}},
+		{Name: "status", Type: field.TypeString, Size: 20},
+		{Name: "filters", Type: field.TypeJSON},
+		{Name: "created_by", Type: field.TypeInt64},
+		{Name: "deleted_rows", Type: field.TypeInt64, Default: 0},
+		{Name: "error_message", Type: field.TypeString, Nullable: true},
+		{Name: "canceled_by", Type: field.TypeInt64, Nullable: true},
+		{Name: "canceled_at", Type: field.TypeTime, Nullable: true},
+		{Name: "started_at", Type: field.TypeTime, Nullable: true},
+		{Name: "finished_at", Type: field.TypeTime, Nullable: true},
+	}
+	// UsageCleanupTasksTable holds the schema information for the "usage_cleanup_tasks" table.
+	UsageCleanupTasksTable = &schema.Table{
+		Name: "usage_cleanup_tasks",
+		Columns: UsageCleanupTasksColumns,
+		PrimaryKey: []*schema.Column{UsageCleanupTasksColumns[0]},
+		Indexes: []*schema.Index{
+			{
+				Name: "usagecleanuptask_status_created_at",
+				Unique: false,
+				Columns: []*schema.Column{UsageCleanupTasksColumns[3], UsageCleanupTasksColumns[1]},
+			},
+			{
+				Name: "usagecleanuptask_created_at",
+				Unique: false,
+				Columns: []*schema.Column{UsageCleanupTasksColumns[1]},
+			},
+			{
+				Name: "usagecleanuptask_canceled_at",
+				Unique: false,
+				Columns: []*schema.Column{UsageCleanupTasksColumns[9]},
+			},
+		},
+	}
 	// UsageLogsColumns holds the columns for the "usage_logs" table.
 	UsageLogsColumns = []*schema.Column{
 		{Name: "id", Type: field.TypeInt64, Increment: true},
@@ -805,6 +843,7 @@ var (
 		ProxiesTable,
 		RedeemCodesTable,
 		SettingsTable,
+		UsageCleanupTasksTable,
 		UsageLogsTable,
 		UsersTable,
 		UserAllowedGroupsTable,
@@ -851,6 +890,9 @@ func init() {
 	SettingsTable.Annotation = &entsql.Annotation{
 		Table: "settings",
 	}
+	UsageCleanupTasksTable.Annotation = &entsql.Annotation{
+		Table: "usage_cleanup_tasks",
+	}
 	UsageLogsTable.ForeignKeys[0].RefTable = APIKeysTable
 	UsageLogsTable.ForeignKeys[1].RefTable = AccountsTable
 	UsageLogsTable.ForeignKeys[2].RefTable = GroupsTable
diff --git a/backend/ent/mutation.go b/backend/ent/mutation.go
index cd2fe8e0..9b330616 100644
--- a/backend/ent/mutation.go
+++ b/backend/ent/mutation.go
@@ -4,6 +4,7 @@ package ent
 
 import (
 	"context"
+	"encoding/json"
 	"errors"
 	"fmt"
 	"sync"
@@ -21,6 +22,7 @@ import (
 	"github.com/Wei-Shaw/sub2api/ent/proxy"
 	"github.com/Wei-Shaw/sub2api/ent/redeemcode"
 	"github.com/Wei-Shaw/sub2api/ent/setting"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
 	"github.com/Wei-Shaw/sub2api/ent/usagelog"
 	"github.com/Wei-Shaw/sub2api/ent/user"
 	"github.com/Wei-Shaw/sub2api/ent/userallowedgroup"
@@ -47,6 +49,7 @@ const (
 	TypeProxy = "Proxy"
 	TypeRedeemCode = "RedeemCode"
 	TypeSetting = "Setting"
+	TypeUsageCleanupTask = "UsageCleanupTask"
 	TypeUsageLog = "UsageLog"
 	TypeUser = "User"
 	TypeUserAllowedGroup = "UserAllowedGroup"
@@ -10370,6 +10373,1089 @@ func (m *SettingMutation) ResetEdge(name string) error {
 	return fmt.Errorf("unknown Setting edge %s", name)
 }
 
+// UsageCleanupTaskMutation represents an operation that mutates the UsageCleanupTask nodes in the graph.
+type UsageCleanupTaskMutation struct {
+	config
+	op Op
+	typ string
+	id *int64
+	created_at *time.Time
+	updated_at *time.Time
+	status *string
+	filters *json.RawMessage
+	appendfilters json.RawMessage
+	created_by *int64
+	addcreated_by *int64
+	deleted_rows *int64
+	adddeleted_rows *int64
+	error_message *string
+	canceled_by *int64
+	addcanceled_by *int64
+	canceled_at *time.Time
+	started_at *time.Time
+	finished_at *time.Time
+	clearedFields map[string]struct{}
+	done bool
+	oldValue func(context.Context) (*UsageCleanupTask, error)
+	predicates []predicate.UsageCleanupTask
+}
+
+var _ ent.Mutation = (*UsageCleanupTaskMutation)(nil)
+
+// usagecleanuptaskOption allows management of the mutation configuration using functional options.
+type usagecleanuptaskOption func(*UsageCleanupTaskMutation)
+
+// newUsageCleanupTaskMutation creates new mutation for the UsageCleanupTask entity.
+func newUsageCleanupTaskMutation(c config, op Op, opts ...usagecleanuptaskOption) *UsageCleanupTaskMutation {
+	m := &UsageCleanupTaskMutation{
+		config: c,
+		op: op,
+		typ: TypeUsageCleanupTask,
+		clearedFields: make(map[string]struct{}),
+	}
+	for _, opt := range opts {
+		opt(m)
+	}
+	return m
+}
+
+// withUsageCleanupTaskID sets the ID field of the mutation.
+func withUsageCleanupTaskID(id int64) usagecleanuptaskOption {
+	return func(m *UsageCleanupTaskMutation) {
+		var (
+			err error
+			once sync.Once
+			value *UsageCleanupTask
+		)
+		m.oldValue = func(ctx context.Context) (*UsageCleanupTask, error) {
+			once.Do(func() {
+				if m.done {
+					err = errors.New("querying old values post mutation is not allowed")
+				} else {
+					value, err = m.Client().UsageCleanupTask.Get(ctx, id)
+				}
+			})
+			return value, err
+		}
+		m.id = &id
+	}
+}
+
+// withUsageCleanupTask sets the old UsageCleanupTask of the mutation.
+func withUsageCleanupTask(node *UsageCleanupTask) usagecleanuptaskOption {
+	return func(m *UsageCleanupTaskMutation) {
+		m.oldValue = func(context.Context) (*UsageCleanupTask, error) {
+			return node, nil
+		}
+		m.id = &node.ID
+	}
+}
+
+// Client returns a new `ent.Client` from the mutation. If the mutation was
+// executed in a transaction (ent.Tx), a transactional client is returned.
+func (m UsageCleanupTaskMutation) Client() *Client {
+	client := &Client{config: m.config}
+	client.init()
+	return client
+}
+
+// Tx returns an `ent.Tx` for mutations that were executed in transactions;
+// it returns an error otherwise.
+func (m UsageCleanupTaskMutation) Tx() (*Tx, error) {
+	if _, ok := m.driver.(*txDriver); !ok {
+		return nil, errors.New("ent: mutation is not running in a transaction")
+	}
+	tx := &Tx{config: m.config}
+	tx.init()
+	return tx, nil
+}
+
+// ID returns the ID value in the mutation. Note that the ID is only available
+// if it was provided to the builder or after it was returned from the database.
+func (m *UsageCleanupTaskMutation) ID() (id int64, exists bool) {
+	if m.id == nil {
+		return
+	}
+	return *m.id, true
+}
+
+// IDs queries the database and returns the entity ids that match the mutation's predicate.
+// That means, if the mutation is applied within a transaction with an isolation level such
+// as sql.LevelSerializable, the returned ids match the ids of the rows that will be updated
+// or updated by the mutation.
+func (m *UsageCleanupTaskMutation) IDs(ctx context.Context) ([]int64, error) {
+	switch {
+	case m.op.Is(OpUpdateOne | OpDeleteOne):
+		id, exists := m.ID()
+		if exists {
+			return []int64{id}, nil
+		}
+		fallthrough
+	case m.op.Is(OpUpdate | OpDelete):
+		return m.Client().UsageCleanupTask.Query().Where(m.predicates...).IDs(ctx)
+	default:
+		return nil, fmt.Errorf("IDs is not allowed on %s operations", m.op)
+	}
+}
+
+// SetCreatedAt sets the "created_at" field.
+func (m *UsageCleanupTaskMutation) SetCreatedAt(t time.Time) {
+	m.created_at = &t
+}
+
+// CreatedAt returns the value of the "created_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) CreatedAt() (r time.Time, exists bool) {
+	v := m.created_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCreatedAt returns the old "created_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCreatedAt(ctx context.Context) (v time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCreatedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCreatedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCreatedAt: %w", err)
+	}
+	return oldValue.CreatedAt, nil
+}
+
+// ResetCreatedAt resets all changes to the "created_at" field.
+func (m *UsageCleanupTaskMutation) ResetCreatedAt() {
+	m.created_at = nil
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (m *UsageCleanupTaskMutation) SetUpdatedAt(t time.Time) {
+	m.updated_at = &t
+}
+
+// UpdatedAt returns the value of the "updated_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) UpdatedAt() (r time.Time, exists bool) {
+	v := m.updated_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldUpdatedAt returns the old "updated_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldUpdatedAt(ctx context.Context) (v time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldUpdatedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldUpdatedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldUpdatedAt: %w", err)
+	}
+	return oldValue.UpdatedAt, nil
+}
+
+// ResetUpdatedAt resets all changes to the "updated_at" field.
+func (m *UsageCleanupTaskMutation) ResetUpdatedAt() {
+	m.updated_at = nil
+}
+
+// SetStatus sets the "status" field.
+func (m *UsageCleanupTaskMutation) SetStatus(s string) {
+	m.status = &s
+}
+
+// Status returns the value of the "status" field in the mutation.
+func (m *UsageCleanupTaskMutation) Status() (r string, exists bool) {
+	v := m.status
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldStatus returns the old "status" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldStatus(ctx context.Context) (v string, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldStatus is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldStatus requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldStatus: %w", err)
+	}
+	return oldValue.Status, nil
+}
+
+// ResetStatus resets all changes to the "status" field.
+func (m *UsageCleanupTaskMutation) ResetStatus() {
+	m.status = nil
+}
+
+// SetFilters sets the "filters" field.
+func (m *UsageCleanupTaskMutation) SetFilters(jm json.RawMessage) {
+	m.filters = &jm
+	m.appendfilters = nil
+}
+
+// Filters returns the value of the "filters" field in the mutation.
+func (m *UsageCleanupTaskMutation) Filters() (r json.RawMessage, exists bool) {
+	v := m.filters
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldFilters returns the old "filters" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldFilters(ctx context.Context) (v json.RawMessage, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldFilters is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldFilters requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldFilters: %w", err)
+	}
+	return oldValue.Filters, nil
+}
+
+// AppendFilters adds jm to the "filters" field.
+func (m *UsageCleanupTaskMutation) AppendFilters(jm json.RawMessage) {
+	m.appendfilters = append(m.appendfilters, jm...)
+}
+
+// AppendedFilters returns the list of values that were appended to the "filters" field in this mutation.
+func (m *UsageCleanupTaskMutation) AppendedFilters() (json.RawMessage, bool) {
+	if len(m.appendfilters) == 0 {
+		return nil, false
+	}
+	return m.appendfilters, true
+}
+
+// ResetFilters resets all changes to the "filters" field.
+func (m *UsageCleanupTaskMutation) ResetFilters() {
+	m.filters = nil
+	m.appendfilters = nil
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (m *UsageCleanupTaskMutation) SetCreatedBy(i int64) {
+	m.created_by = &i
+	m.addcreated_by = nil
+}
+
+// CreatedBy returns the value of the "created_by" field in the mutation.
+func (m *UsageCleanupTaskMutation) CreatedBy() (r int64, exists bool) {
+	v := m.created_by
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCreatedBy returns the old "created_by" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCreatedBy(ctx context.Context) (v int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCreatedBy is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCreatedBy requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCreatedBy: %w", err)
+	}
+	return oldValue.CreatedBy, nil
+}
+
+// AddCreatedBy adds i to the "created_by" field.
+func (m *UsageCleanupTaskMutation) AddCreatedBy(i int64) {
+	if m.addcreated_by != nil {
+		*m.addcreated_by += i
+	} else {
+		m.addcreated_by = &i
+	}
+}
+
+// AddedCreatedBy returns the value that was added to the "created_by" field in this mutation.
+func (m *UsageCleanupTaskMutation) AddedCreatedBy() (r int64, exists bool) {
+	v := m.addcreated_by
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetCreatedBy resets all changes to the "created_by" field.
+func (m *UsageCleanupTaskMutation) ResetCreatedBy() {
+	m.created_by = nil
+	m.addcreated_by = nil
+}
+
+// SetDeletedRows sets the "deleted_rows" field.
+func (m *UsageCleanupTaskMutation) SetDeletedRows(i int64) {
+	m.deleted_rows = &i
+	m.adddeleted_rows = nil
+}
+
+// DeletedRows returns the value of the "deleted_rows" field in the mutation.
+func (m *UsageCleanupTaskMutation) DeletedRows() (r int64, exists bool) {
+	v := m.deleted_rows
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldDeletedRows returns the old "deleted_rows" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldDeletedRows(ctx context.Context) (v int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldDeletedRows is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldDeletedRows requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldDeletedRows: %w", err)
+	}
+	return oldValue.DeletedRows, nil
+}
+
+// AddDeletedRows adds i to the "deleted_rows" field.
+func (m *UsageCleanupTaskMutation) AddDeletedRows(i int64) {
+	if m.adddeleted_rows != nil {
+		*m.adddeleted_rows += i
+	} else {
+		m.adddeleted_rows = &i
+	}
+}
+
+// AddedDeletedRows returns the value that was added to the "deleted_rows" field in this mutation.
+func (m *UsageCleanupTaskMutation) AddedDeletedRows() (r int64, exists bool) {
+	v := m.adddeleted_rows
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ResetDeletedRows resets all changes to the "deleted_rows" field.
+func (m *UsageCleanupTaskMutation) ResetDeletedRows() {
+	m.deleted_rows = nil
+	m.adddeleted_rows = nil
+}
+
+// SetErrorMessage sets the "error_message" field.
+func (m *UsageCleanupTaskMutation) SetErrorMessage(s string) {
+	m.error_message = &s
+}
+
+// ErrorMessage returns the value of the "error_message" field in the mutation.
+func (m *UsageCleanupTaskMutation) ErrorMessage() (r string, exists bool) {
+	v := m.error_message
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldErrorMessage returns the old "error_message" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldErrorMessage(ctx context.Context) (v *string, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldErrorMessage is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldErrorMessage requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldErrorMessage: %w", err)
+	}
+	return oldValue.ErrorMessage, nil
+}
+
+// ClearErrorMessage clears the value of the "error_message" field.
+func (m *UsageCleanupTaskMutation) ClearErrorMessage() {
+	m.error_message = nil
+	m.clearedFields[usagecleanuptask.FieldErrorMessage] = struct{}{}
+}
+
+// ErrorMessageCleared returns if the "error_message" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) ErrorMessageCleared() bool {
+	_, ok := m.clearedFields[usagecleanuptask.FieldErrorMessage]
+	return ok
+}
+
+// ResetErrorMessage resets all changes to the "error_message" field.
+func (m *UsageCleanupTaskMutation) ResetErrorMessage() {
+	m.error_message = nil
+	delete(m.clearedFields, usagecleanuptask.FieldErrorMessage)
+}
+
+// SetCanceledBy sets the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) SetCanceledBy(i int64) {
+	m.canceled_by = &i
+	m.addcanceled_by = nil
+}
+
+// CanceledBy returns the value of the "canceled_by" field in the mutation.
+func (m *UsageCleanupTaskMutation) CanceledBy() (r int64, exists bool) {
+	v := m.canceled_by
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCanceledBy returns the old "canceled_by" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCanceledBy(ctx context.Context) (v *int64, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCanceledBy is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCanceledBy requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCanceledBy: %w", err)
+	}
+	return oldValue.CanceledBy, nil
+}
+
+// AddCanceledBy adds i to the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) AddCanceledBy(i int64) {
+	if m.addcanceled_by != nil {
+		*m.addcanceled_by += i
+	} else {
+		m.addcanceled_by = &i
+	}
+}
+
+// AddedCanceledBy returns the value that was added to the "canceled_by" field in this mutation.
+func (m *UsageCleanupTaskMutation) AddedCanceledBy() (r int64, exists bool) {
+	v := m.addcanceled_by
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// ClearCanceledBy clears the value of the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) ClearCanceledBy() {
+	m.canceled_by = nil
+	m.addcanceled_by = nil
+	m.clearedFields[usagecleanuptask.FieldCanceledBy] = struct{}{}
+}
+
+// CanceledByCleared returns if the "canceled_by" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) CanceledByCleared() bool {
+	_, ok := m.clearedFields[usagecleanuptask.FieldCanceledBy]
+	return ok
+}
+
+// ResetCanceledBy resets all changes to the "canceled_by" field.
+func (m *UsageCleanupTaskMutation) ResetCanceledBy() {
+	m.canceled_by = nil
+	m.addcanceled_by = nil
+	delete(m.clearedFields, usagecleanuptask.FieldCanceledBy)
+}
+
+// SetCanceledAt sets the "canceled_at" field.
+func (m *UsageCleanupTaskMutation) SetCanceledAt(t time.Time) {
+	m.canceled_at = &t
+}
+
+// CanceledAt returns the value of the "canceled_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) CanceledAt() (r time.Time, exists bool) {
+	v := m.canceled_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldCanceledAt returns the old "canceled_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldCanceledAt(ctx context.Context) (v *time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldCanceledAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldCanceledAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldCanceledAt: %w", err)
+	}
+	return oldValue.CanceledAt, nil
+}
+
+// ClearCanceledAt clears the value of the "canceled_at" field.
+func (m *UsageCleanupTaskMutation) ClearCanceledAt() {
+	m.canceled_at = nil
+	m.clearedFields[usagecleanuptask.FieldCanceledAt] = struct{}{}
+}
+
+// CanceledAtCleared returns if the "canceled_at" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) CanceledAtCleared() bool {
+	_, ok := m.clearedFields[usagecleanuptask.FieldCanceledAt]
+	return ok
+}
+
+// ResetCanceledAt resets all changes to the "canceled_at" field.
+func (m *UsageCleanupTaskMutation) ResetCanceledAt() {
+	m.canceled_at = nil
+	delete(m.clearedFields, usagecleanuptask.FieldCanceledAt)
+}
+
+// SetStartedAt sets the "started_at" field.
+func (m *UsageCleanupTaskMutation) SetStartedAt(t time.Time) {
+	m.started_at = &t
+}
+
+// StartedAt returns the value of the "started_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) StartedAt() (r time.Time, exists bool) {
+	v := m.started_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldStartedAt returns the old "started_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldStartedAt(ctx context.Context) (v *time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldStartedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldStartedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldStartedAt: %w", err)
+	}
+	return oldValue.StartedAt, nil
+}
+
+// ClearStartedAt clears the value of the "started_at" field.
+func (m *UsageCleanupTaskMutation) ClearStartedAt() {
+	m.started_at = nil
+	m.clearedFields[usagecleanuptask.FieldStartedAt] = struct{}{}
+}
+
+// StartedAtCleared returns if the "started_at" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) StartedAtCleared() bool {
+	_, ok := m.clearedFields[usagecleanuptask.FieldStartedAt]
+	return ok
+}
+
+// ResetStartedAt resets all changes to the "started_at" field.
+func (m *UsageCleanupTaskMutation) ResetStartedAt() {
+	m.started_at = nil
+	delete(m.clearedFields, usagecleanuptask.FieldStartedAt)
+}
+
+// SetFinishedAt sets the "finished_at" field.
+func (m *UsageCleanupTaskMutation) SetFinishedAt(t time.Time) {
+	m.finished_at = &t
+}
+
+// FinishedAt returns the value of the "finished_at" field in the mutation.
+func (m *UsageCleanupTaskMutation) FinishedAt() (r time.Time, exists bool) {
+	v := m.finished_at
+	if v == nil {
+		return
+	}
+	return *v, true
+}
+
+// OldFinishedAt returns the old "finished_at" field's value of the UsageCleanupTask entity.
+// If the UsageCleanupTask object wasn't provided to the builder, the object is fetched from the database.
+// An error is returned if the mutation operation is not UpdateOne, or the database query fails.
+func (m *UsageCleanupTaskMutation) OldFinishedAt(ctx context.Context) (v *time.Time, err error) {
+	if !m.op.Is(OpUpdateOne) {
+		return v, errors.New("OldFinishedAt is only allowed on UpdateOne operations")
+	}
+	if m.id == nil || m.oldValue == nil {
+		return v, errors.New("OldFinishedAt requires an ID field in the mutation")
+	}
+	oldValue, err := m.oldValue(ctx)
+	if err != nil {
+		return v, fmt.Errorf("querying old value for OldFinishedAt: %w", err)
+	}
+	return oldValue.FinishedAt, nil
+}
+
+// ClearFinishedAt clears the value of the "finished_at" field.
+func (m *UsageCleanupTaskMutation) ClearFinishedAt() {
+	m.finished_at = nil
+	m.clearedFields[usagecleanuptask.FieldFinishedAt] = struct{}{}
+}
+
+// FinishedAtCleared returns if the "finished_at" field was cleared in this mutation.
+func (m *UsageCleanupTaskMutation) FinishedAtCleared() bool {
+	_, ok := m.clearedFields[usagecleanuptask.FieldFinishedAt]
+	return ok
+}
+
+// ResetFinishedAt resets all changes to the "finished_at" field.
+func (m *UsageCleanupTaskMutation) ResetFinishedAt() {
+	m.finished_at = nil
+	delete(m.clearedFields, usagecleanuptask.FieldFinishedAt)
+}
+
+// Where appends a list predicates to the UsageCleanupTaskMutation builder.
+func (m *UsageCleanupTaskMutation) Where(ps ...predicate.UsageCleanupTask) {
+	m.predicates = append(m.predicates, ps...)
+}
+
+// WhereP appends storage-level predicates to the UsageCleanupTaskMutation builder. Using this method,
+// users can use type-assertion to append predicates that do not depend on any generated package.
+func (m *UsageCleanupTaskMutation) WhereP(ps ...func(*sql.Selector)) {
+	p := make([]predicate.UsageCleanupTask, len(ps))
+	for i := range ps {
+		p[i] = ps[i]
+	}
+	m.Where(p...)
+}
+
+// Op returns the operation name.
+func (m *UsageCleanupTaskMutation) Op() Op {
+	return m.op
+}
+
+// SetOp allows setting the mutation operation.
+func (m *UsageCleanupTaskMutation) SetOp(op Op) {
+	m.op = op
+}
+
+// Type returns the node type of this mutation (UsageCleanupTask).
+func (m *UsageCleanupTaskMutation) Type() string {
+	return m.typ
+}
+
+// Fields returns all fields that were changed during this mutation. Note that in
+// order to get all numeric fields that were incremented/decremented, call
+// AddedFields().
+func (m *UsageCleanupTaskMutation) Fields() []string {
+	fields := make([]string, 0, 11)
+	if m.created_at != nil {
+		fields = append(fields, usagecleanuptask.FieldCreatedAt)
+	}
+	if m.updated_at != nil {
+		fields = append(fields, usagecleanuptask.FieldUpdatedAt)
+	}
+	if m.status != nil {
+		fields = append(fields, usagecleanuptask.FieldStatus)
+	}
+	if m.filters != nil {
+		fields = append(fields, usagecleanuptask.FieldFilters)
+	}
+	if m.created_by != nil {
+		fields = append(fields, usagecleanuptask.FieldCreatedBy)
+	}
+	if m.deleted_rows != nil {
+		fields = append(fields, usagecleanuptask.FieldDeletedRows)
+	}
+	if m.error_message != nil {
+		fields = append(fields, usagecleanuptask.FieldErrorMessage)
+	}
+	if m.canceled_by != nil {
+		fields = append(fields, usagecleanuptask.FieldCanceledBy)
+	}
+	if m.canceled_at != nil {
+		fields = append(fields, usagecleanuptask.FieldCanceledAt)
+	}
+	if m.started_at != nil {
+		fields = append(fields, usagecleanuptask.FieldStartedAt)
+	}
+	if m.finished_at != nil {
+		fields = append(fields, usagecleanuptask.FieldFinishedAt)
+	}
+	return fields
+}
+
+// Field returns the value of a field with the given name. The second boolean
+// return value indicates that this field was not set, or was not defined in the
+// schema.
+func (m *UsageCleanupTaskMutation) Field(name string) (ent.Value, bool) {
+	switch name {
+	case usagecleanuptask.FieldCreatedAt:
+		return m.CreatedAt()
+	case usagecleanuptask.FieldUpdatedAt:
+		return m.UpdatedAt()
+	case usagecleanuptask.FieldStatus:
+		return m.Status()
+	case usagecleanuptask.FieldFilters:
+		return m.Filters()
+	case usagecleanuptask.FieldCreatedBy:
+		return m.CreatedBy()
+	case usagecleanuptask.FieldDeletedRows:
+		return m.DeletedRows()
+	case usagecleanuptask.FieldErrorMessage:
+		return m.ErrorMessage()
+	case usagecleanuptask.FieldCanceledBy:
+		return m.CanceledBy()
+	case usagecleanuptask.FieldCanceledAt:
+		return m.CanceledAt()
+	case usagecleanuptask.FieldStartedAt:
+		return m.StartedAt()
+	case usagecleanuptask.FieldFinishedAt:
+		return m.FinishedAt()
+	}
+	return nil, false
+}
+
+// OldField returns the old value of the field from the database. An error is
+// returned if the mutation operation is not UpdateOne, or the query to the
+// database failed.
+func (m *UsageCleanupTaskMutation) OldField(ctx context.Context, name string) (ent.Value, error) {
+	switch name {
+	case usagecleanuptask.FieldCreatedAt:
+		return m.OldCreatedAt(ctx)
+	case usagecleanuptask.FieldUpdatedAt:
+		return m.OldUpdatedAt(ctx)
+	case usagecleanuptask.FieldStatus:
+		return m.OldStatus(ctx)
+	case usagecleanuptask.FieldFilters:
+		return m.OldFilters(ctx)
+	case usagecleanuptask.FieldCreatedBy:
+		return m.OldCreatedBy(ctx)
+	case usagecleanuptask.FieldDeletedRows:
+		return m.OldDeletedRows(ctx)
+	case usagecleanuptask.FieldErrorMessage:
+		return m.OldErrorMessage(ctx)
+	case usagecleanuptask.FieldCanceledBy:
+		return m.OldCanceledBy(ctx)
+	case usagecleanuptask.FieldCanceledAt:
+		return m.OldCanceledAt(ctx)
+	case usagecleanuptask.FieldStartedAt:
+		return m.OldStartedAt(ctx)
+	case usagecleanuptask.FieldFinishedAt:
+		return m.OldFinishedAt(ctx)
+	}
+	return nil, fmt.Errorf("unknown UsageCleanupTask field %s", name)
+}
+
+// SetField sets the value of a field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *UsageCleanupTaskMutation) SetField(name string, value ent.Value) error {
+	switch name {
+	case usagecleanuptask.FieldCreatedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCreatedAt(v)
+		return nil
+	case usagecleanuptask.FieldUpdatedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetUpdatedAt(v)
+		return nil
+	case usagecleanuptask.FieldStatus:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetStatus(v)
+		return nil
+	case usagecleanuptask.FieldFilters:
+		v, ok := value.(json.RawMessage)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetFilters(v)
+		return nil
+	case usagecleanuptask.FieldCreatedBy:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCreatedBy(v)
+		return nil
+	case usagecleanuptask.FieldDeletedRows:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetDeletedRows(v)
+		return nil
+	case usagecleanuptask.FieldErrorMessage:
+		v, ok := value.(string)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetErrorMessage(v)
+		return nil
+	case usagecleanuptask.FieldCanceledBy:
+		v, ok := value.(int64)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCanceledBy(v)
+		return nil
+	case usagecleanuptask.FieldCanceledAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetCanceledAt(v)
+		return nil
+	case usagecleanuptask.FieldStartedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetStartedAt(v)
+		return nil
+	case usagecleanuptask.FieldFinishedAt:
+		v, ok := value.(time.Time)
+		if !ok {
+			return fmt.Errorf("unexpected type %T for field %s", value, name)
+		}
+		m.SetFinishedAt(v)
+		return nil
+	}
+	return fmt.Errorf("unknown UsageCleanupTask field %s", name)
+}
+
+// AddedFields returns all numeric fields that were incremented/decremented during
+// this mutation.
+func (m *UsageCleanupTaskMutation) AddedFields() []string {
+	var fields []string
+	if m.addcreated_by != nil {
+		fields = append(fields, usagecleanuptask.FieldCreatedBy)
+	}
+	if m.adddeleted_rows != nil {
+		fields = append(fields, usagecleanuptask.FieldDeletedRows)
+	}
+	if m.addcanceled_by != nil {
+		fields = append(fields, usagecleanuptask.FieldCanceledBy)
+	}
+	return fields
+}
+
+// AddedField returns the numeric value that was incremented/decremented on a field
+// with the given name. The second boolean return value indicates that this field
+// was not set, or was not defined in the schema.
+func (m *UsageCleanupTaskMutation) AddedField(name string) (ent.Value, bool) {
+	switch name {
+	case usagecleanuptask.FieldCreatedBy:
+		return m.AddedCreatedBy()
+	case usagecleanuptask.FieldDeletedRows:
+		return m.AddedDeletedRows()
+	case usagecleanuptask.FieldCanceledBy:
+		return m.AddedCanceledBy()
+	}
+	return nil, false
+}
+
+// AddField adds the value to the field with the given name. It returns an error if
+// the field is not defined in the schema, or if the type mismatched the field
+// type.
+func (m *UsageCleanupTaskMutation) AddField(name string, value ent.Value) error { + switch name { + case usagecleanuptask.FieldCreatedBy: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCreatedBy(v) + return nil + case usagecleanuptask.FieldDeletedRows: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddDeletedRows(v) + return nil + case usagecleanuptask.FieldCanceledBy: + v, ok := value.(int64) + if !ok { + return fmt.Errorf("unexpected type %T for field %s", value, name) + } + m.AddCanceledBy(v) + return nil + } + return fmt.Errorf("unknown UsageCleanupTask numeric field %s", name) +} + +// ClearedFields returns all nullable fields that were cleared during this +// mutation. +func (m *UsageCleanupTaskMutation) ClearedFields() []string { + var fields []string + if m.FieldCleared(usagecleanuptask.FieldErrorMessage) { + fields = append(fields, usagecleanuptask.FieldErrorMessage) + } + if m.FieldCleared(usagecleanuptask.FieldCanceledBy) { + fields = append(fields, usagecleanuptask.FieldCanceledBy) + } + if m.FieldCleared(usagecleanuptask.FieldCanceledAt) { + fields = append(fields, usagecleanuptask.FieldCanceledAt) + } + if m.FieldCleared(usagecleanuptask.FieldStartedAt) { + fields = append(fields, usagecleanuptask.FieldStartedAt) + } + if m.FieldCleared(usagecleanuptask.FieldFinishedAt) { + fields = append(fields, usagecleanuptask.FieldFinishedAt) + } + return fields +} + +// FieldCleared returns a boolean indicating if a field with the given name was +// cleared in this mutation. +func (m *UsageCleanupTaskMutation) FieldCleared(name string) bool { + _, ok := m.clearedFields[name] + return ok +} + +// ClearField clears the value of the field with the given name. It returns an +// error if the field is not defined in the schema. +func (m *UsageCleanupTaskMutation) ClearField(name string) error { + switch name { + case usagecleanuptask.FieldErrorMessage: + m.ClearErrorMessage() + return nil + case usagecleanuptask.FieldCanceledBy: + m.ClearCanceledBy() + return nil + case usagecleanuptask.FieldCanceledAt: + m.ClearCanceledAt() + return nil + case usagecleanuptask.FieldStartedAt: + m.ClearStartedAt() + return nil + case usagecleanuptask.FieldFinishedAt: + m.ClearFinishedAt() + return nil + } + return fmt.Errorf("unknown UsageCleanupTask nullable field %s", name) +} + +// ResetField resets all changes in the mutation for the field with the given name. +// It returns an error if the field is not defined in the schema. 
+func (m *UsageCleanupTaskMutation) ResetField(name string) error { + switch name { + case usagecleanuptask.FieldCreatedAt: + m.ResetCreatedAt() + return nil + case usagecleanuptask.FieldUpdatedAt: + m.ResetUpdatedAt() + return nil + case usagecleanuptask.FieldStatus: + m.ResetStatus() + return nil + case usagecleanuptask.FieldFilters: + m.ResetFilters() + return nil + case usagecleanuptask.FieldCreatedBy: + m.ResetCreatedBy() + return nil + case usagecleanuptask.FieldDeletedRows: + m.ResetDeletedRows() + return nil + case usagecleanuptask.FieldErrorMessage: + m.ResetErrorMessage() + return nil + case usagecleanuptask.FieldCanceledBy: + m.ResetCanceledBy() + return nil + case usagecleanuptask.FieldCanceledAt: + m.ResetCanceledAt() + return nil + case usagecleanuptask.FieldStartedAt: + m.ResetStartedAt() + return nil + case usagecleanuptask.FieldFinishedAt: + m.ResetFinishedAt() + return nil + } + return fmt.Errorf("unknown UsageCleanupTask field %s", name) +} + +// AddedEdges returns all edge names that were set/added in this mutation. +func (m *UsageCleanupTaskMutation) AddedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// AddedIDs returns all IDs (to other nodes) that were added for the given edge +// name in this mutation. +func (m *UsageCleanupTaskMutation) AddedIDs(name string) []ent.Value { + return nil +} + +// RemovedEdges returns all edge names that were removed in this mutation. +func (m *UsageCleanupTaskMutation) RemovedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// RemovedIDs returns all IDs (to other nodes) that were removed for the edge with +// the given name in this mutation. +func (m *UsageCleanupTaskMutation) RemovedIDs(name string) []ent.Value { + return nil +} + +// ClearedEdges returns all edge names that were cleared in this mutation. +func (m *UsageCleanupTaskMutation) ClearedEdges() []string { + edges := make([]string, 0, 0) + return edges +} + +// EdgeCleared returns a boolean which indicates if the edge with the given name +// was cleared in this mutation. +func (m *UsageCleanupTaskMutation) EdgeCleared(name string) bool { + return false +} + +// ClearEdge clears the value of the edge with the given name. It returns an error +// if that edge is not defined in the schema. +func (m *UsageCleanupTaskMutation) ClearEdge(name string) error { + return fmt.Errorf("unknown UsageCleanupTask unique edge %s", name) +} + +// ResetEdge resets all changes to the edge with the given name in this mutation. +// It returns an error if the edge is not defined in the schema. +func (m *UsageCleanupTaskMutation) ResetEdge(name string) error { + return fmt.Errorf("unknown UsageCleanupTask edge %s", name) +} + // UsageLogMutation represents an operation that mutates the UsageLog nodes in the graph. type UsageLogMutation struct { config diff --git a/backend/ent/predicate/predicate.go b/backend/ent/predicate/predicate.go index 7a443c5d..785cb4e6 100644 --- a/backend/ent/predicate/predicate.go +++ b/backend/ent/predicate/predicate.go @@ -33,6 +33,9 @@ type RedeemCode func(*sql.Selector) // Setting is the predicate function for setting builders. type Setting func(*sql.Selector) +// UsageCleanupTask is the predicate function for usagecleanuptask builders. +type UsageCleanupTask func(*sql.Selector) + // UsageLog is the predicate function for usagelog builders. 
type UsageLog func(*sql.Selector) diff --git a/backend/ent/runtime/runtime.go b/backend/ent/runtime/runtime.go index 0cb10775..1e3f4cbe 100644 --- a/backend/ent/runtime/runtime.go +++ b/backend/ent/runtime/runtime.go @@ -15,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/ent/redeemcode" "github.com/Wei-Shaw/sub2api/ent/schema" "github.com/Wei-Shaw/sub2api/ent/setting" + "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" "github.com/Wei-Shaw/sub2api/ent/usagelog" "github.com/Wei-Shaw/sub2api/ent/user" "github.com/Wei-Shaw/sub2api/ent/userallowedgroup" @@ -495,6 +496,43 @@ func init() { setting.DefaultUpdatedAt = settingDescUpdatedAt.Default.(func() time.Time) // setting.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. setting.UpdateDefaultUpdatedAt = settingDescUpdatedAt.UpdateDefault.(func() time.Time) + usagecleanuptaskMixin := schema.UsageCleanupTask{}.Mixin() + usagecleanuptaskMixinFields0 := usagecleanuptaskMixin[0].Fields() + _ = usagecleanuptaskMixinFields0 + usagecleanuptaskFields := schema.UsageCleanupTask{}.Fields() + _ = usagecleanuptaskFields + // usagecleanuptaskDescCreatedAt is the schema descriptor for created_at field. + usagecleanuptaskDescCreatedAt := usagecleanuptaskMixinFields0[0].Descriptor() + // usagecleanuptask.DefaultCreatedAt holds the default value on creation for the created_at field. + usagecleanuptask.DefaultCreatedAt = usagecleanuptaskDescCreatedAt.Default.(func() time.Time) + // usagecleanuptaskDescUpdatedAt is the schema descriptor for updated_at field. + usagecleanuptaskDescUpdatedAt := usagecleanuptaskMixinFields0[1].Descriptor() + // usagecleanuptask.DefaultUpdatedAt holds the default value on creation for the updated_at field. + usagecleanuptask.DefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.Default.(func() time.Time) + // usagecleanuptask.UpdateDefaultUpdatedAt holds the default value on update for the updated_at field. + usagecleanuptask.UpdateDefaultUpdatedAt = usagecleanuptaskDescUpdatedAt.UpdateDefault.(func() time.Time) + // usagecleanuptaskDescStatus is the schema descriptor for status field. + usagecleanuptaskDescStatus := usagecleanuptaskFields[0].Descriptor() + // usagecleanuptask.StatusValidator is a validator for the "status" field. It is called by the builders before save. + usagecleanuptask.StatusValidator = func() func(string) error { + validators := usagecleanuptaskDescStatus.Validators + fns := [...]func(string) error{ + validators[0].(func(string) error), + validators[1].(func(string) error), + } + return func(status string) error { + for _, fn := range fns { + if err := fn(status); err != nil { + return err + } + } + return nil + } + }() + // usagecleanuptaskDescDeletedRows is the schema descriptor for deleted_rows field. + usagecleanuptaskDescDeletedRows := usagecleanuptaskFields[3].Descriptor() + // usagecleanuptask.DefaultDeletedRows holds the default value on creation for the deleted_rows field. + usagecleanuptask.DefaultDeletedRows = usagecleanuptaskDescDeletedRows.Default.(int64) usagelogFields := schema.UsageLog{}.Fields() _ = usagelogFields // usagelogDescRequestID is the schema descriptor for request_id field. 
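For reference, the runtime registration above chains the status field's two validators (the schema's MaxLen(20) rule and the validateUsageCleanupStatus whitelist defined in usage_cleanup_task.go further down) into a single StatusValidator that runs them in order, stopping at the first error. Below is a rough standalone sketch of that chaining; chain, maxLen20, and statusWhitelist are hypothetical names for this illustration, not part of the ent runtime.

package main

import "fmt"

// chain mirrors how ent's runtime composes per-field validators:
// each one runs in declaration order and the first error short-circuits.
func chain(fns ...func(string) error) func(string) error {
	return func(s string) error {
		for _, fn := range fns {
			if err := fn(s); err != nil {
				return err
			}
		}
		return nil
	}
}

func main() {
	// Stand-in for the schema's MaxLen(20) rule.
	maxLen20 := func(s string) error {
		if len(s) > 20 {
			return fmt.Errorf("value is longer than the max length 20")
		}
		return nil
	}
	// Stand-in for validateUsageCleanupStatus from the schema package.
	statusWhitelist := func(s string) error {
		switch s {
		case "pending", "running", "succeeded", "failed", "canceled":
			return nil
		default:
			return fmt.Errorf("invalid usage cleanup status: %s", s)
		}
	}
	validate := chain(maxLen20, statusWhitelist)
	fmt.Println(validate("running")) // <nil>
	fmt.Println(validate("paused"))  // invalid usage cleanup status: paused
}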
diff --git a/backend/ent/schema/mixins/soft_delete.go b/backend/ent/schema/mixins/soft_delete.go
index 9571bc9c..22eded3e 100644
--- a/backend/ent/schema/mixins/soft_delete.go
+++ b/backend/ent/schema/mixins/soft_delete.go
@@ -5,6 +5,7 @@ package mixins
 import (
 	"context"
 	"fmt"
+	"reflect"
 	"time"
 
 	"entgo.io/ent"
@@ -12,7 +13,6 @@ import (
 	"entgo.io/ent/dialect/sql"
 	"entgo.io/ent/schema/field"
 	"entgo.io/ent/schema/mixin"
 
-	dbent "github.com/Wei-Shaw/sub2api/ent"
 	"github.com/Wei-Shaw/sub2api/ent/intercept"
 )
@@ -113,7 +113,6 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
 				SetOp(ent.Op)
 				SetDeletedAt(time.Time)
 				WhereP(...func(*sql.Selector))
-				Client() *dbent.Client
 			})
 			if !ok {
 				return nil, fmt.Errorf("unexpected mutation type %T", m)
@@ -124,7 +123,7 @@ func (d SoftDeleteMixin) Hooks() []ent.Hook {
 			mx.SetOp(ent.OpUpdate)
 			// Set the deletion time to the current time
 			mx.SetDeletedAt(time.Now())
-			return mx.Client().Mutate(ctx, m)
+			return mutateWithClient(ctx, m, next)
 		})
 	},
 	}
@@ -137,3 +136,41 @@ func (d SoftDeleteMixin) applyPredicate(w interface{ WhereP(...func(*sql.Selecto
 		sql.FieldIsNull(d.Fields()[0].Descriptor().Name),
 	)
 }
+
+func mutateWithClient(ctx context.Context, m ent.Mutation, fallback ent.Mutator) (ent.Value, error) { // NOTE: fallback is currently unused; lookup failures return errors instead
+	clientMethod := reflect.ValueOf(m).MethodByName("Client")
+	if !clientMethod.IsValid() || clientMethod.Type().NumIn() != 0 || clientMethod.Type().NumOut() != 1 {
+		return nil, fmt.Errorf("soft delete: mutation client method not found for %T", m)
+	}
+	client := clientMethod.Call(nil)[0]
+	mutateMethod := client.MethodByName("Mutate")
+	if !mutateMethod.IsValid() {
+		return nil, fmt.Errorf("soft delete: mutation client missing Mutate for %T", m)
+	}
+	if mutateMethod.Type().NumIn() != 2 || mutateMethod.Type().NumOut() != 2 {
+		return nil, fmt.Errorf("soft delete: mutation client signature mismatch for %T", m)
+	}
+
+	results := mutateMethod.Call([]reflect.Value{reflect.ValueOf(ctx), reflect.ValueOf(m)})
+	value := results[0].Interface()
+	var err error
+	if !results[1].IsNil() {
+		errValue := results[1].Interface()
+		typedErr, ok := errValue.(error)
+		if !ok {
+			return nil, fmt.Errorf("soft delete: unexpected error type %T for %T", errValue, m)
+		}
+		err = typedErr
+	}
+	if err != nil {
+		return nil, err
+	}
+	if value == nil {
+		return nil, fmt.Errorf("soft delete: mutation client returned nil for %T", m)
+	}
+	v, ok := value.(ent.Value)
+	if !ok {
+		return nil, fmt.Errorf("soft delete: unexpected value type %T for %T", value, m)
+	}
+	return v, nil
+}
diff --git a/backend/ent/schema/usage_cleanup_task.go b/backend/ent/schema/usage_cleanup_task.go
new file mode 100644
index 00000000..753e6410
--- /dev/null
+++ b/backend/ent/schema/usage_cleanup_task.go
@@ -0,0 +1,75 @@
+package schema
+
+import (
+	"encoding/json"
+	"fmt"
+
+	"github.com/Wei-Shaw/sub2api/ent/schema/mixins"
+
+	"entgo.io/ent"
+	"entgo.io/ent/dialect/entsql"
+	"entgo.io/ent/schema"
+	"entgo.io/ent/schema/field"
+	"entgo.io/ent/schema/index"
+)
+
+// UsageCleanupTask defines the schema for usage-record cleanup tasks.
+type UsageCleanupTask struct {
+	ent.Schema
+}
+
+func (UsageCleanupTask) Annotations() []schema.Annotation {
+	return []schema.Annotation{
+		entsql.Annotation{Table: "usage_cleanup_tasks"},
+	}
+}
+
+func (UsageCleanupTask) Mixin() []ent.Mixin {
+	return []ent.Mixin{
+		mixins.TimeMixin{},
+	}
+}
+
+func (UsageCleanupTask) Fields() []ent.Field {
+	return []ent.Field{
+		field.String("status").
+			MaxLen(20).
+			Validate(validateUsageCleanupStatus),
+		field.JSON("filters", json.RawMessage{}),
+		field.Int64("created_by"),
+		field.Int64("deleted_rows").
+ Default(0), + field.String("error_message"). + Optional(). + Nillable(), + field.Int64("canceled_by"). + Optional(). + Nillable(), + field.Time("canceled_at"). + Optional(). + Nillable(), + field.Time("started_at"). + Optional(). + Nillable(), + field.Time("finished_at"). + Optional(). + Nillable(), + } +} + +func (UsageCleanupTask) Indexes() []ent.Index { + return []ent.Index{ + index.Fields("status", "created_at"), + index.Fields("created_at"), + index.Fields("canceled_at"), + } +} + +func validateUsageCleanupStatus(status string) error { + switch status { + case "pending", "running", "succeeded", "failed", "canceled": + return nil + default: + return fmt.Errorf("invalid usage cleanup status: %s", status) + } +} diff --git a/backend/ent/tx.go b/backend/ent/tx.go index 56df121a..7ff16ec8 100644 --- a/backend/ent/tx.go +++ b/backend/ent/tx.go @@ -32,6 +32,8 @@ type Tx struct { RedeemCode *RedeemCodeClient // Setting is the client for interacting with the Setting builders. Setting *SettingClient + // UsageCleanupTask is the client for interacting with the UsageCleanupTask builders. + UsageCleanupTask *UsageCleanupTaskClient // UsageLog is the client for interacting with the UsageLog builders. UsageLog *UsageLogClient // User is the client for interacting with the User builders. @@ -184,6 +186,7 @@ func (tx *Tx) init() { tx.Proxy = NewProxyClient(tx.config) tx.RedeemCode = NewRedeemCodeClient(tx.config) tx.Setting = NewSettingClient(tx.config) + tx.UsageCleanupTask = NewUsageCleanupTaskClient(tx.config) tx.UsageLog = NewUsageLogClient(tx.config) tx.User = NewUserClient(tx.config) tx.UserAllowedGroup = NewUserAllowedGroupClient(tx.config) diff --git a/backend/ent/usagecleanuptask.go b/backend/ent/usagecleanuptask.go new file mode 100644 index 00000000..e3a17b5a --- /dev/null +++ b/backend/ent/usagecleanuptask.go @@ -0,0 +1,236 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "encoding/json" + "fmt" + "strings" + "time" + + "entgo.io/ent" + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" +) + +// UsageCleanupTask is the model entity for the UsageCleanupTask schema. +type UsageCleanupTask struct { + config `json:"-"` + // ID of the ent. + ID int64 `json:"id,omitempty"` + // CreatedAt holds the value of the "created_at" field. + CreatedAt time.Time `json:"created_at,omitempty"` + // UpdatedAt holds the value of the "updated_at" field. + UpdatedAt time.Time `json:"updated_at,omitempty"` + // Status holds the value of the "status" field. + Status string `json:"status,omitempty"` + // Filters holds the value of the "filters" field. + Filters json.RawMessage `json:"filters,omitempty"` + // CreatedBy holds the value of the "created_by" field. + CreatedBy int64 `json:"created_by,omitempty"` + // DeletedRows holds the value of the "deleted_rows" field. + DeletedRows int64 `json:"deleted_rows,omitempty"` + // ErrorMessage holds the value of the "error_message" field. + ErrorMessage *string `json:"error_message,omitempty"` + // CanceledBy holds the value of the "canceled_by" field. + CanceledBy *int64 `json:"canceled_by,omitempty"` + // CanceledAt holds the value of the "canceled_at" field. + CanceledAt *time.Time `json:"canceled_at,omitempty"` + // StartedAt holds the value of the "started_at" field. + StartedAt *time.Time `json:"started_at,omitempty"` + // FinishedAt holds the value of the "finished_at" field. 
+ FinishedAt *time.Time `json:"finished_at,omitempty"` + selectValues sql.SelectValues +} + +// scanValues returns the types for scanning values from sql.Rows. +func (*UsageCleanupTask) scanValues(columns []string) ([]any, error) { + values := make([]any, len(columns)) + for i := range columns { + switch columns[i] { + case usagecleanuptask.FieldFilters: + values[i] = new([]byte) + case usagecleanuptask.FieldID, usagecleanuptask.FieldCreatedBy, usagecleanuptask.FieldDeletedRows, usagecleanuptask.FieldCanceledBy: + values[i] = new(sql.NullInt64) + case usagecleanuptask.FieldStatus, usagecleanuptask.FieldErrorMessage: + values[i] = new(sql.NullString) + case usagecleanuptask.FieldCreatedAt, usagecleanuptask.FieldUpdatedAt, usagecleanuptask.FieldCanceledAt, usagecleanuptask.FieldStartedAt, usagecleanuptask.FieldFinishedAt: + values[i] = new(sql.NullTime) + default: + values[i] = new(sql.UnknownType) + } + } + return values, nil +} + +// assignValues assigns the values that were returned from sql.Rows (after scanning) +// to the UsageCleanupTask fields. +func (_m *UsageCleanupTask) assignValues(columns []string, values []any) error { + if m, n := len(values), len(columns); m < n { + return fmt.Errorf("mismatch number of scan values: %d != %d", m, n) + } + for i := range columns { + switch columns[i] { + case usagecleanuptask.FieldID: + value, ok := values[i].(*sql.NullInt64) + if !ok { + return fmt.Errorf("unexpected type %T for field id", value) + } + _m.ID = int64(value.Int64) + case usagecleanuptask.FieldCreatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field created_at", values[i]) + } else if value.Valid { + _m.CreatedAt = value.Time + } + case usagecleanuptask.FieldUpdatedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field updated_at", values[i]) + } else if value.Valid { + _m.UpdatedAt = value.Time + } + case usagecleanuptask.FieldStatus: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field status", values[i]) + } else if value.Valid { + _m.Status = value.String + } + case usagecleanuptask.FieldFilters: + if value, ok := values[i].(*[]byte); !ok { + return fmt.Errorf("unexpected type %T for field filters", values[i]) + } else if value != nil && len(*value) > 0 { + if err := json.Unmarshal(*value, &_m.Filters); err != nil { + return fmt.Errorf("unmarshal field filters: %w", err) + } + } + case usagecleanuptask.FieldCreatedBy: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field created_by", values[i]) + } else if value.Valid { + _m.CreatedBy = value.Int64 + } + case usagecleanuptask.FieldDeletedRows: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field deleted_rows", values[i]) + } else if value.Valid { + _m.DeletedRows = value.Int64 + } + case usagecleanuptask.FieldErrorMessage: + if value, ok := values[i].(*sql.NullString); !ok { + return fmt.Errorf("unexpected type %T for field error_message", values[i]) + } else if value.Valid { + _m.ErrorMessage = new(string) + *_m.ErrorMessage = value.String + } + case usagecleanuptask.FieldCanceledBy: + if value, ok := values[i].(*sql.NullInt64); !ok { + return fmt.Errorf("unexpected type %T for field canceled_by", values[i]) + } else if value.Valid { + _m.CanceledBy = new(int64) + *_m.CanceledBy = value.Int64 + } + case usagecleanuptask.FieldCanceledAt: + if value, ok := 
values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field canceled_at", values[i]) + } else if value.Valid { + _m.CanceledAt = new(time.Time) + *_m.CanceledAt = value.Time + } + case usagecleanuptask.FieldStartedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field started_at", values[i]) + } else if value.Valid { + _m.StartedAt = new(time.Time) + *_m.StartedAt = value.Time + } + case usagecleanuptask.FieldFinishedAt: + if value, ok := values[i].(*sql.NullTime); !ok { + return fmt.Errorf("unexpected type %T for field finished_at", values[i]) + } else if value.Valid { + _m.FinishedAt = new(time.Time) + *_m.FinishedAt = value.Time + } + default: + _m.selectValues.Set(columns[i], values[i]) + } + } + return nil +} + +// Value returns the ent.Value that was dynamically selected and assigned to the UsageCleanupTask. +// This includes values selected through modifiers, order, etc. +func (_m *UsageCleanupTask) Value(name string) (ent.Value, error) { + return _m.selectValues.Get(name) +} + +// Update returns a builder for updating this UsageCleanupTask. +// Note that you need to call UsageCleanupTask.Unwrap() before calling this method if this UsageCleanupTask +// was returned from a transaction, and the transaction was committed or rolled back. +func (_m *UsageCleanupTask) Update() *UsageCleanupTaskUpdateOne { + return NewUsageCleanupTaskClient(_m.config).UpdateOne(_m) +} + +// Unwrap unwraps the UsageCleanupTask entity that was returned from a transaction after it was closed, +// so that all future queries will be executed through the driver which created the transaction. +func (_m *UsageCleanupTask) Unwrap() *UsageCleanupTask { + _tx, ok := _m.config.driver.(*txDriver) + if !ok { + panic("ent: UsageCleanupTask is not a transactional entity") + } + _m.config.driver = _tx.drv + return _m +} + +// String implements the fmt.Stringer. 
+func (_m *UsageCleanupTask) String() string { + var builder strings.Builder + builder.WriteString("UsageCleanupTask(") + builder.WriteString(fmt.Sprintf("id=%v, ", _m.ID)) + builder.WriteString("created_at=") + builder.WriteString(_m.CreatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("updated_at=") + builder.WriteString(_m.UpdatedAt.Format(time.ANSIC)) + builder.WriteString(", ") + builder.WriteString("status=") + builder.WriteString(_m.Status) + builder.WriteString(", ") + builder.WriteString("filters=") + builder.WriteString(fmt.Sprintf("%v", _m.Filters)) + builder.WriteString(", ") + builder.WriteString("created_by=") + builder.WriteString(fmt.Sprintf("%v", _m.CreatedBy)) + builder.WriteString(", ") + builder.WriteString("deleted_rows=") + builder.WriteString(fmt.Sprintf("%v", _m.DeletedRows)) + builder.WriteString(", ") + if v := _m.ErrorMessage; v != nil { + builder.WriteString("error_message=") + builder.WriteString(*v) + } + builder.WriteString(", ") + if v := _m.CanceledBy; v != nil { + builder.WriteString("canceled_by=") + builder.WriteString(fmt.Sprintf("%v", *v)) + } + builder.WriteString(", ") + if v := _m.CanceledAt; v != nil { + builder.WriteString("canceled_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.StartedAt; v != nil { + builder.WriteString("started_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteString(", ") + if v := _m.FinishedAt; v != nil { + builder.WriteString("finished_at=") + builder.WriteString(v.Format(time.ANSIC)) + } + builder.WriteByte(')') + return builder.String() +} + +// UsageCleanupTasks is a parsable slice of UsageCleanupTask. +type UsageCleanupTasks []*UsageCleanupTask diff --git a/backend/ent/usagecleanuptask/usagecleanuptask.go b/backend/ent/usagecleanuptask/usagecleanuptask.go new file mode 100644 index 00000000..a8ddd9a0 --- /dev/null +++ b/backend/ent/usagecleanuptask/usagecleanuptask.go @@ -0,0 +1,137 @@ +// Code generated by ent, DO NOT EDIT. + +package usagecleanuptask + +import ( + "time" + + "entgo.io/ent/dialect/sql" +) + +const ( + // Label holds the string label denoting the usagecleanuptask type in the database. + Label = "usage_cleanup_task" + // FieldID holds the string denoting the id field in the database. + FieldID = "id" + // FieldCreatedAt holds the string denoting the created_at field in the database. + FieldCreatedAt = "created_at" + // FieldUpdatedAt holds the string denoting the updated_at field in the database. + FieldUpdatedAt = "updated_at" + // FieldStatus holds the string denoting the status field in the database. + FieldStatus = "status" + // FieldFilters holds the string denoting the filters field in the database. + FieldFilters = "filters" + // FieldCreatedBy holds the string denoting the created_by field in the database. + FieldCreatedBy = "created_by" + // FieldDeletedRows holds the string denoting the deleted_rows field in the database. + FieldDeletedRows = "deleted_rows" + // FieldErrorMessage holds the string denoting the error_message field in the database. + FieldErrorMessage = "error_message" + // FieldCanceledBy holds the string denoting the canceled_by field in the database. + FieldCanceledBy = "canceled_by" + // FieldCanceledAt holds the string denoting the canceled_at field in the database. + FieldCanceledAt = "canceled_at" + // FieldStartedAt holds the string denoting the started_at field in the database. 
+ FieldStartedAt = "started_at" + // FieldFinishedAt holds the string denoting the finished_at field in the database. + FieldFinishedAt = "finished_at" + // Table holds the table name of the usagecleanuptask in the database. + Table = "usage_cleanup_tasks" +) + +// Columns holds all SQL columns for usagecleanuptask fields. +var Columns = []string{ + FieldID, + FieldCreatedAt, + FieldUpdatedAt, + FieldStatus, + FieldFilters, + FieldCreatedBy, + FieldDeletedRows, + FieldErrorMessage, + FieldCanceledBy, + FieldCanceledAt, + FieldStartedAt, + FieldFinishedAt, +} + +// ValidColumn reports if the column name is valid (part of the table columns). +func ValidColumn(column string) bool { + for i := range Columns { + if column == Columns[i] { + return true + } + } + return false +} + +var ( + // DefaultCreatedAt holds the default value on creation for the "created_at" field. + DefaultCreatedAt func() time.Time + // DefaultUpdatedAt holds the default value on creation for the "updated_at" field. + DefaultUpdatedAt func() time.Time + // UpdateDefaultUpdatedAt holds the default value on update for the "updated_at" field. + UpdateDefaultUpdatedAt func() time.Time + // StatusValidator is a validator for the "status" field. It is called by the builders before save. + StatusValidator func(string) error + // DefaultDeletedRows holds the default value on creation for the "deleted_rows" field. + DefaultDeletedRows int64 +) + +// OrderOption defines the ordering options for the UsageCleanupTask queries. +type OrderOption func(*sql.Selector) + +// ByID orders the results by the id field. +func ByID(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldID, opts...).ToFunc() +} + +// ByCreatedAt orders the results by the created_at field. +func ByCreatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedAt, opts...).ToFunc() +} + +// ByUpdatedAt orders the results by the updated_at field. +func ByUpdatedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldUpdatedAt, opts...).ToFunc() +} + +// ByStatus orders the results by the status field. +func ByStatus(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStatus, opts...).ToFunc() +} + +// ByCreatedBy orders the results by the created_by field. +func ByCreatedBy(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCreatedBy, opts...).ToFunc() +} + +// ByDeletedRows orders the results by the deleted_rows field. +func ByDeletedRows(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldDeletedRows, opts...).ToFunc() +} + +// ByErrorMessage orders the results by the error_message field. +func ByErrorMessage(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldErrorMessage, opts...).ToFunc() +} + +// ByCanceledBy orders the results by the canceled_by field. +func ByCanceledBy(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCanceledBy, opts...).ToFunc() +} + +// ByCanceledAt orders the results by the canceled_at field. +func ByCanceledAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldCanceledAt, opts...).ToFunc() +} + +// ByStartedAt orders the results by the started_at field. +func ByStartedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldStartedAt, opts...).ToFunc() +} + +// ByFinishedAt orders the results by the finished_at field. 
+func ByFinishedAt(opts ...sql.OrderTermOption) OrderOption { + return sql.OrderByField(FieldFinishedAt, opts...).ToFunc() +} diff --git a/backend/ent/usagecleanuptask/where.go b/backend/ent/usagecleanuptask/where.go new file mode 100644 index 00000000..99e790ca --- /dev/null +++ b/backend/ent/usagecleanuptask/where.go @@ -0,0 +1,620 @@ +// Code generated by ent, DO NOT EDIT. + +package usagecleanuptask + +import ( + "time" + + "entgo.io/ent/dialect/sql" + "github.com/Wei-Shaw/sub2api/ent/predicate" +) + +// ID filters vertices based on their ID field. +func ID(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id)) +} + +// IDEQ applies the EQ predicate on the ID field. +func IDEQ(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldID, id)) +} + +// IDNEQ applies the NEQ predicate on the ID field. +func IDNEQ(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldID, id)) +} + +// IDIn applies the In predicate on the ID field. +func IDIn(ids ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldID, ids...)) +} + +// IDNotIn applies the NotIn predicate on the ID field. +func IDNotIn(ids ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldID, ids...)) +} + +// IDGT applies the GT predicate on the ID field. +func IDGT(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldID, id)) +} + +// IDGTE applies the GTE predicate on the ID field. +func IDGTE(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldID, id)) +} + +// IDLT applies the LT predicate on the ID field. +func IDLT(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldID, id)) +} + +// IDLTE applies the LTE predicate on the ID field. +func IDLTE(id int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldID, id)) +} + +// CreatedAt applies equality check predicate on the "created_at" field. It's identical to CreatedAtEQ. +func CreatedAt(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v)) +} + +// UpdatedAt applies equality check predicate on the "updated_at" field. It's identical to UpdatedAtEQ. +func UpdatedAt(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// Status applies equality check predicate on the "status" field. It's identical to StatusEQ. +func Status(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v)) +} + +// CreatedBy applies equality check predicate on the "created_by" field. It's identical to CreatedByEQ. +func CreatedBy(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v)) +} + +// DeletedRows applies equality check predicate on the "deleted_rows" field. It's identical to DeletedRowsEQ. +func DeletedRows(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v)) +} + +// ErrorMessage applies equality check predicate on the "error_message" field. It's identical to ErrorMessageEQ. +func ErrorMessage(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v)) +} + +// CanceledBy applies equality check predicate on the "canceled_by" field. 
It's identical to CanceledByEQ. +func CanceledBy(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v)) +} + +// CanceledAt applies equality check predicate on the "canceled_at" field. It's identical to CanceledAtEQ. +func CanceledAt(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v)) +} + +// StartedAt applies equality check predicate on the "started_at" field. It's identical to StartedAtEQ. +func StartedAt(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v)) +} + +// FinishedAt applies equality check predicate on the "finished_at" field. It's identical to FinishedAtEQ. +func FinishedAt(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v)) +} + +// CreatedAtEQ applies the EQ predicate on the "created_at" field. +func CreatedAtEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedAt, v)) +} + +// CreatedAtNEQ applies the NEQ predicate on the "created_at" field. +func CreatedAtNEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedAt, v)) +} + +// CreatedAtIn applies the In predicate on the "created_at" field. +func CreatedAtIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedAt, vs...)) +} + +// CreatedAtNotIn applies the NotIn predicate on the "created_at" field. +func CreatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedAt, vs...)) +} + +// CreatedAtGT applies the GT predicate on the "created_at" field. +func CreatedAtGT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedAt, v)) +} + +// CreatedAtGTE applies the GTE predicate on the "created_at" field. +func CreatedAtGTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedAt, v)) +} + +// CreatedAtLT applies the LT predicate on the "created_at" field. +func CreatedAtLT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedAt, v)) +} + +// CreatedAtLTE applies the LTE predicate on the "created_at" field. +func CreatedAtLTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedAt, v)) +} + +// UpdatedAtEQ applies the EQ predicate on the "updated_at" field. +func UpdatedAtEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtNEQ applies the NEQ predicate on the "updated_at" field. +func UpdatedAtNEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldUpdatedAt, v)) +} + +// UpdatedAtIn applies the In predicate on the "updated_at" field. +func UpdatedAtIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtNotIn applies the NotIn predicate on the "updated_at" field. +func UpdatedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldUpdatedAt, vs...)) +} + +// UpdatedAtGT applies the GT predicate on the "updated_at" field. 
+func UpdatedAtGT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldUpdatedAt, v)) +} + +// UpdatedAtGTE applies the GTE predicate on the "updated_at" field. +func UpdatedAtGTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldUpdatedAt, v)) +} + +// UpdatedAtLT applies the LT predicate on the "updated_at" field. +func UpdatedAtLT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldUpdatedAt, v)) +} + +// UpdatedAtLTE applies the LTE predicate on the "updated_at" field. +func UpdatedAtLTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldUpdatedAt, v)) +} + +// StatusEQ applies the EQ predicate on the "status" field. +func StatusEQ(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldStatus, v)) +} + +// StatusNEQ applies the NEQ predicate on the "status" field. +func StatusNEQ(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStatus, v)) +} + +// StatusIn applies the In predicate on the "status" field. +func StatusIn(vs ...string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldStatus, vs...)) +} + +// StatusNotIn applies the NotIn predicate on the "status" field. +func StatusNotIn(vs ...string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStatus, vs...)) +} + +// StatusGT applies the GT predicate on the "status" field. +func StatusGT(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldStatus, v)) +} + +// StatusGTE applies the GTE predicate on the "status" field. +func StatusGTE(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldStatus, v)) +} + +// StatusLT applies the LT predicate on the "status" field. +func StatusLT(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldStatus, v)) +} + +// StatusLTE applies the LTE predicate on the "status" field. +func StatusLTE(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldStatus, v)) +} + +// StatusContains applies the Contains predicate on the "status" field. +func StatusContains(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldContains(FieldStatus, v)) +} + +// StatusHasPrefix applies the HasPrefix predicate on the "status" field. +func StatusHasPrefix(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldStatus, v)) +} + +// StatusHasSuffix applies the HasSuffix predicate on the "status" field. +func StatusHasSuffix(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldStatus, v)) +} + +// StatusEqualFold applies the EqualFold predicate on the "status" field. +func StatusEqualFold(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldStatus, v)) +} + +// StatusContainsFold applies the ContainsFold predicate on the "status" field. +func StatusContainsFold(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldStatus, v)) +} + +// CreatedByEQ applies the EQ predicate on the "created_by" field. 
+func CreatedByEQ(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCreatedBy, v)) +} + +// CreatedByNEQ applies the NEQ predicate on the "created_by" field. +func CreatedByNEQ(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCreatedBy, v)) +} + +// CreatedByIn applies the In predicate on the "created_by" field. +func CreatedByIn(vs ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldCreatedBy, vs...)) +} + +// CreatedByNotIn applies the NotIn predicate on the "created_by" field. +func CreatedByNotIn(vs ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCreatedBy, vs...)) +} + +// CreatedByGT applies the GT predicate on the "created_by" field. +func CreatedByGT(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldCreatedBy, v)) +} + +// CreatedByGTE applies the GTE predicate on the "created_by" field. +func CreatedByGTE(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldCreatedBy, v)) +} + +// CreatedByLT applies the LT predicate on the "created_by" field. +func CreatedByLT(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldCreatedBy, v)) +} + +// CreatedByLTE applies the LTE predicate on the "created_by" field. +func CreatedByLTE(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldCreatedBy, v)) +} + +// DeletedRowsEQ applies the EQ predicate on the "deleted_rows" field. +func DeletedRowsEQ(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldDeletedRows, v)) +} + +// DeletedRowsNEQ applies the NEQ predicate on the "deleted_rows" field. +func DeletedRowsNEQ(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldDeletedRows, v)) +} + +// DeletedRowsIn applies the In predicate on the "deleted_rows" field. +func DeletedRowsIn(vs ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldDeletedRows, vs...)) +} + +// DeletedRowsNotIn applies the NotIn predicate on the "deleted_rows" field. +func DeletedRowsNotIn(vs ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldDeletedRows, vs...)) +} + +// DeletedRowsGT applies the GT predicate on the "deleted_rows" field. +func DeletedRowsGT(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldDeletedRows, v)) +} + +// DeletedRowsGTE applies the GTE predicate on the "deleted_rows" field. +func DeletedRowsGTE(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldDeletedRows, v)) +} + +// DeletedRowsLT applies the LT predicate on the "deleted_rows" field. +func DeletedRowsLT(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldDeletedRows, v)) +} + +// DeletedRowsLTE applies the LTE predicate on the "deleted_rows" field. +func DeletedRowsLTE(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldDeletedRows, v)) +} + +// ErrorMessageEQ applies the EQ predicate on the "error_message" field. +func ErrorMessageEQ(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldErrorMessage, v)) +} + +// ErrorMessageNEQ applies the NEQ predicate on the "error_message" field. 
+func ErrorMessageNEQ(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldErrorMessage, v)) +} + +// ErrorMessageIn applies the In predicate on the "error_message" field. +func ErrorMessageIn(vs ...string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldErrorMessage, vs...)) +} + +// ErrorMessageNotIn applies the NotIn predicate on the "error_message" field. +func ErrorMessageNotIn(vs ...string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldErrorMessage, vs...)) +} + +// ErrorMessageGT applies the GT predicate on the "error_message" field. +func ErrorMessageGT(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldErrorMessage, v)) +} + +// ErrorMessageGTE applies the GTE predicate on the "error_message" field. +func ErrorMessageGTE(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldErrorMessage, v)) +} + +// ErrorMessageLT applies the LT predicate on the "error_message" field. +func ErrorMessageLT(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldErrorMessage, v)) +} + +// ErrorMessageLTE applies the LTE predicate on the "error_message" field. +func ErrorMessageLTE(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldErrorMessage, v)) +} + +// ErrorMessageContains applies the Contains predicate on the "error_message" field. +func ErrorMessageContains(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldContains(FieldErrorMessage, v)) +} + +// ErrorMessageHasPrefix applies the HasPrefix predicate on the "error_message" field. +func ErrorMessageHasPrefix(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldHasPrefix(FieldErrorMessage, v)) +} + +// ErrorMessageHasSuffix applies the HasSuffix predicate on the "error_message" field. +func ErrorMessageHasSuffix(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldHasSuffix(FieldErrorMessage, v)) +} + +// ErrorMessageIsNil applies the IsNil predicate on the "error_message" field. +func ErrorMessageIsNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIsNull(FieldErrorMessage)) +} + +// ErrorMessageNotNil applies the NotNil predicate on the "error_message" field. +func ErrorMessageNotNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotNull(FieldErrorMessage)) +} + +// ErrorMessageEqualFold applies the EqualFold predicate on the "error_message" field. +func ErrorMessageEqualFold(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEqualFold(FieldErrorMessage, v)) +} + +// ErrorMessageContainsFold applies the ContainsFold predicate on the "error_message" field. +func ErrorMessageContainsFold(v string) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldContainsFold(FieldErrorMessage, v)) +} + +// CanceledByEQ applies the EQ predicate on the "canceled_by" field. +func CanceledByEQ(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledBy, v)) +} + +// CanceledByNEQ applies the NEQ predicate on the "canceled_by" field. +func CanceledByNEQ(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledBy, v)) +} + +// CanceledByIn applies the In predicate on the "canceled_by" field. 
+func CanceledByIn(vs ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledBy, vs...)) +} + +// CanceledByNotIn applies the NotIn predicate on the "canceled_by" field. +func CanceledByNotIn(vs ...int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledBy, vs...)) +} + +// CanceledByGT applies the GT predicate on the "canceled_by" field. +func CanceledByGT(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledBy, v)) +} + +// CanceledByGTE applies the GTE predicate on the "canceled_by" field. +func CanceledByGTE(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledBy, v)) +} + +// CanceledByLT applies the LT predicate on the "canceled_by" field. +func CanceledByLT(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledBy, v)) +} + +// CanceledByLTE applies the LTE predicate on the "canceled_by" field. +func CanceledByLTE(v int64) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledBy, v)) +} + +// CanceledByIsNil applies the IsNil predicate on the "canceled_by" field. +func CanceledByIsNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledBy)) +} + +// CanceledByNotNil applies the NotNil predicate on the "canceled_by" field. +func CanceledByNotNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledBy)) +} + +// CanceledAtEQ applies the EQ predicate on the "canceled_at" field. +func CanceledAtEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldCanceledAt, v)) +} + +// CanceledAtNEQ applies the NEQ predicate on the "canceled_at" field. +func CanceledAtNEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldCanceledAt, v)) +} + +// CanceledAtIn applies the In predicate on the "canceled_at" field. +func CanceledAtIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldCanceledAt, vs...)) +} + +// CanceledAtNotIn applies the NotIn predicate on the "canceled_at" field. +func CanceledAtNotIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldCanceledAt, vs...)) +} + +// CanceledAtGT applies the GT predicate on the "canceled_at" field. +func CanceledAtGT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldCanceledAt, v)) +} + +// CanceledAtGTE applies the GTE predicate on the "canceled_at" field. +func CanceledAtGTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldCanceledAt, v)) +} + +// CanceledAtLT applies the LT predicate on the "canceled_at" field. +func CanceledAtLT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldCanceledAt, v)) +} + +// CanceledAtLTE applies the LTE predicate on the "canceled_at" field. +func CanceledAtLTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldCanceledAt, v)) +} + +// CanceledAtIsNil applies the IsNil predicate on the "canceled_at" field. +func CanceledAtIsNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIsNull(FieldCanceledAt)) +} + +// CanceledAtNotNil applies the NotNil predicate on the "canceled_at" field. 
+func CanceledAtNotNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotNull(FieldCanceledAt)) +} + +// StartedAtEQ applies the EQ predicate on the "started_at" field. +func StartedAtEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldStartedAt, v)) +} + +// StartedAtNEQ applies the NEQ predicate on the "started_at" field. +func StartedAtNEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldStartedAt, v)) +} + +// StartedAtIn applies the In predicate on the "started_at" field. +func StartedAtIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldStartedAt, vs...)) +} + +// StartedAtNotIn applies the NotIn predicate on the "started_at" field. +func StartedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldStartedAt, vs...)) +} + +// StartedAtGT applies the GT predicate on the "started_at" field. +func StartedAtGT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldStartedAt, v)) +} + +// StartedAtGTE applies the GTE predicate on the "started_at" field. +func StartedAtGTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldStartedAt, v)) +} + +// StartedAtLT applies the LT predicate on the "started_at" field. +func StartedAtLT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldStartedAt, v)) +} + +// StartedAtLTE applies the LTE predicate on the "started_at" field. +func StartedAtLTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldStartedAt, v)) +} + +// StartedAtIsNil applies the IsNil predicate on the "started_at" field. +func StartedAtIsNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIsNull(FieldStartedAt)) +} + +// StartedAtNotNil applies the NotNil predicate on the "started_at" field. +func StartedAtNotNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotNull(FieldStartedAt)) +} + +// FinishedAtEQ applies the EQ predicate on the "finished_at" field. +func FinishedAtEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldEQ(FieldFinishedAt, v)) +} + +// FinishedAtNEQ applies the NEQ predicate on the "finished_at" field. +func FinishedAtNEQ(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNEQ(FieldFinishedAt, v)) +} + +// FinishedAtIn applies the In predicate on the "finished_at" field. +func FinishedAtIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIn(FieldFinishedAt, vs...)) +} + +// FinishedAtNotIn applies the NotIn predicate on the "finished_at" field. +func FinishedAtNotIn(vs ...time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotIn(FieldFinishedAt, vs...)) +} + +// FinishedAtGT applies the GT predicate on the "finished_at" field. +func FinishedAtGT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGT(FieldFinishedAt, v)) +} + +// FinishedAtGTE applies the GTE predicate on the "finished_at" field. +func FinishedAtGTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldGTE(FieldFinishedAt, v)) +} + +// FinishedAtLT applies the LT predicate on the "finished_at" field. 
+func FinishedAtLT(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLT(FieldFinishedAt, v)) +} + +// FinishedAtLTE applies the LTE predicate on the "finished_at" field. +func FinishedAtLTE(v time.Time) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldLTE(FieldFinishedAt, v)) +} + +// FinishedAtIsNil applies the IsNil predicate on the "finished_at" field. +func FinishedAtIsNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldIsNull(FieldFinishedAt)) +} + +// FinishedAtNotNil applies the NotNil predicate on the "finished_at" field. +func FinishedAtNotNil() predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.FieldNotNull(FieldFinishedAt)) +} + +// And groups predicates with the AND operator between them. +func And(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.AndPredicates(predicates...)) +} + +// Or groups predicates with the OR operator between them. +func Or(predicates ...predicate.UsageCleanupTask) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.OrPredicates(predicates...)) +} + +// Not applies the not operator on the given predicate. +func Not(p predicate.UsageCleanupTask) predicate.UsageCleanupTask { + return predicate.UsageCleanupTask(sql.NotPredicates(p)) +} diff --git a/backend/ent/usagecleanuptask_create.go b/backend/ent/usagecleanuptask_create.go new file mode 100644 index 00000000..0b1dcff5 --- /dev/null +++ b/backend/ent/usagecleanuptask_create.go @@ -0,0 +1,1190 @@ +// Code generated by ent, DO NOT EDIT. + +package ent + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "entgo.io/ent/dialect/sql" + "entgo.io/ent/dialect/sql/sqlgraph" + "entgo.io/ent/schema/field" + "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" +) + +// UsageCleanupTaskCreate is the builder for creating a UsageCleanupTask entity. +type UsageCleanupTaskCreate struct { + config + mutation *UsageCleanupTaskMutation + hooks []Hook + conflict []sql.ConflictOption +} + +// SetCreatedAt sets the "created_at" field. +func (_c *UsageCleanupTaskCreate) SetCreatedAt(v time.Time) *UsageCleanupTaskCreate { + _c.mutation.SetCreatedAt(v) + return _c +} + +// SetNillableCreatedAt sets the "created_at" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableCreatedAt(v *time.Time) *UsageCleanupTaskCreate { + if v != nil { + _c.SetCreatedAt(*v) + } + return _c +} + +// SetUpdatedAt sets the "updated_at" field. +func (_c *UsageCleanupTaskCreate) SetUpdatedAt(v time.Time) *UsageCleanupTaskCreate { + _c.mutation.SetUpdatedAt(v) + return _c +} + +// SetNillableUpdatedAt sets the "updated_at" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableUpdatedAt(v *time.Time) *UsageCleanupTaskCreate { + if v != nil { + _c.SetUpdatedAt(*v) + } + return _c +} + +// SetStatus sets the "status" field. +func (_c *UsageCleanupTaskCreate) SetStatus(v string) *UsageCleanupTaskCreate { + _c.mutation.SetStatus(v) + return _c +} + +// SetFilters sets the "filters" field. +func (_c *UsageCleanupTaskCreate) SetFilters(v json.RawMessage) *UsageCleanupTaskCreate { + _c.mutation.SetFilters(v) + return _c +} + +// SetCreatedBy sets the "created_by" field. +func (_c *UsageCleanupTaskCreate) SetCreatedBy(v int64) *UsageCleanupTaskCreate { + _c.mutation.SetCreatedBy(v) + return _c +} + +// SetDeletedRows sets the "deleted_rows" field. 
+func (_c *UsageCleanupTaskCreate) SetDeletedRows(v int64) *UsageCleanupTaskCreate { + _c.mutation.SetDeletedRows(v) + return _c +} + +// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskCreate { + if v != nil { + _c.SetDeletedRows(*v) + } + return _c +} + +// SetErrorMessage sets the "error_message" field. +func (_c *UsageCleanupTaskCreate) SetErrorMessage(v string) *UsageCleanupTaskCreate { + _c.mutation.SetErrorMessage(v) + return _c +} + +// SetNillableErrorMessage sets the "error_message" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableErrorMessage(v *string) *UsageCleanupTaskCreate { + if v != nil { + _c.SetErrorMessage(*v) + } + return _c +} + +// SetCanceledBy sets the "canceled_by" field. +func (_c *UsageCleanupTaskCreate) SetCanceledBy(v int64) *UsageCleanupTaskCreate { + _c.mutation.SetCanceledBy(v) + return _c +} + +// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskCreate { + if v != nil { + _c.SetCanceledBy(*v) + } + return _c +} + +// SetCanceledAt sets the "canceled_at" field. +func (_c *UsageCleanupTaskCreate) SetCanceledAt(v time.Time) *UsageCleanupTaskCreate { + _c.mutation.SetCanceledAt(v) + return _c +} + +// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskCreate { + if v != nil { + _c.SetCanceledAt(*v) + } + return _c +} + +// SetStartedAt sets the "started_at" field. +func (_c *UsageCleanupTaskCreate) SetStartedAt(v time.Time) *UsageCleanupTaskCreate { + _c.mutation.SetStartedAt(v) + return _c +} + +// SetNillableStartedAt sets the "started_at" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskCreate { + if v != nil { + _c.SetStartedAt(*v) + } + return _c +} + +// SetFinishedAt sets the "finished_at" field. +func (_c *UsageCleanupTaskCreate) SetFinishedAt(v time.Time) *UsageCleanupTaskCreate { + _c.mutation.SetFinishedAt(v) + return _c +} + +// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil. +func (_c *UsageCleanupTaskCreate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskCreate { + if v != nil { + _c.SetFinishedAt(*v) + } + return _c +} + +// Mutation returns the UsageCleanupTaskMutation object of the builder. +func (_c *UsageCleanupTaskCreate) Mutation() *UsageCleanupTaskMutation { + return _c.mutation +} + +// Save creates the UsageCleanupTask in the database. +func (_c *UsageCleanupTaskCreate) Save(ctx context.Context) (*UsageCleanupTask, error) { + _c.defaults() + return withHooks(ctx, _c.sqlSave, _c.mutation, _c.hooks) +} + +// SaveX calls Save and panics if Save returns an error. +func (_c *UsageCleanupTaskCreate) SaveX(ctx context.Context) *UsageCleanupTask { + v, err := _c.Save(ctx) + if err != nil { + panic(err) + } + return v +} + +// Exec executes the query. +func (_c *UsageCleanupTaskCreate) Exec(ctx context.Context) error { + _, err := _c.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_c *UsageCleanupTaskCreate) ExecX(ctx context.Context) { + if err := _c.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. 
+func (_c *UsageCleanupTaskCreate) defaults() { + if _, ok := _c.mutation.CreatedAt(); !ok { + v := usagecleanuptask.DefaultCreatedAt() + _c.mutation.SetCreatedAt(v) + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + v := usagecleanuptask.DefaultUpdatedAt() + _c.mutation.SetUpdatedAt(v) + } + if _, ok := _c.mutation.DeletedRows(); !ok { + v := usagecleanuptask.DefaultDeletedRows + _c.mutation.SetDeletedRows(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_c *UsageCleanupTaskCreate) check() error { + if _, ok := _c.mutation.CreatedAt(); !ok { + return &ValidationError{Name: "created_at", err: errors.New(`ent: missing required field "UsageCleanupTask.created_at"`)} + } + if _, ok := _c.mutation.UpdatedAt(); !ok { + return &ValidationError{Name: "updated_at", err: errors.New(`ent: missing required field "UsageCleanupTask.updated_at"`)} + } + if _, ok := _c.mutation.Status(); !ok { + return &ValidationError{Name: "status", err: errors.New(`ent: missing required field "UsageCleanupTask.status"`)} + } + if v, ok := _c.mutation.Status(); ok { + if err := usagecleanuptask.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)} + } + } + if _, ok := _c.mutation.Filters(); !ok { + return &ValidationError{Name: "filters", err: errors.New(`ent: missing required field "UsageCleanupTask.filters"`)} + } + if _, ok := _c.mutation.CreatedBy(); !ok { + return &ValidationError{Name: "created_by", err: errors.New(`ent: missing required field "UsageCleanupTask.created_by"`)} + } + if _, ok := _c.mutation.DeletedRows(); !ok { + return &ValidationError{Name: "deleted_rows", err: errors.New(`ent: missing required field "UsageCleanupTask.deleted_rows"`)} + } + return nil +} + +func (_c *UsageCleanupTaskCreate) sqlSave(ctx context.Context) (*UsageCleanupTask, error) { + if err := _c.check(); err != nil { + return nil, err + } + _node, _spec := _c.createSpec() + if err := sqlgraph.CreateNode(ctx, _c.driver, _spec); err != nil { + if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + id := _spec.ID.Value.(int64) + _node.ID = int64(id) + _c.mutation.id = &_node.ID + _c.mutation.done = true + return _node, nil +} + +func (_c *UsageCleanupTaskCreate) createSpec() (*UsageCleanupTask, *sqlgraph.CreateSpec) { + var ( + _node = &UsageCleanupTask{config: _c.config} + _spec = sqlgraph.NewCreateSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64)) + ) + _spec.OnConflict = _c.conflict + if value, ok := _c.mutation.CreatedAt(); ok { + _spec.SetField(usagecleanuptask.FieldCreatedAt, field.TypeTime, value) + _node.CreatedAt = value + } + if value, ok := _c.mutation.UpdatedAt(); ok { + _spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value) + _node.UpdatedAt = value + } + if value, ok := _c.mutation.Status(); ok { + _spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value) + _node.Status = value + } + if value, ok := _c.mutation.Filters(); ok { + _spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value) + _node.Filters = value + } + if value, ok := _c.mutation.CreatedBy(); ok { + _spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value) + _node.CreatedBy = value + } + if value, ok := _c.mutation.DeletedRows(); ok { + _spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value) + _node.DeletedRows = value + } + if value, 
ok := _c.mutation.ErrorMessage(); ok {
+		_spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value)
+		_node.ErrorMessage = &value
+	}
+	if value, ok := _c.mutation.CanceledBy(); ok {
+		_spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value)
+		_node.CanceledBy = &value
+	}
+	if value, ok := _c.mutation.CanceledAt(); ok {
+		_spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value)
+		_node.CanceledAt = &value
+	}
+	if value, ok := _c.mutation.StartedAt(); ok {
+		_spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value)
+		_node.StartedAt = &value
+	}
+	if value, ok := _c.mutation.FinishedAt(); ok {
+		_spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value)
+		_node.FinishedAt = &value
+	}
+	return _node, _spec
+}
+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.UsageCleanupTask.Create().
+//		SetCreatedAt(v).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.UsageCleanupTaskUpsert) {
+//			u.SetCreatedAt(v + v)
+//		}).
+//		Exec(ctx)
+func (_c *UsageCleanupTaskCreate) OnConflict(opts ...sql.ConflictOption) *UsageCleanupTaskUpsertOne {
+	_c.conflict = opts
+	return &UsageCleanupTaskUpsertOne{
+		create: _c,
+	}
+}
+
+// OnConflictColumns calls `OnConflict` and configures the columns
+// as conflict target. Using this option is equivalent to using:
+//
+//	client.UsageCleanupTask.Create().
+//		OnConflict(sql.ConflictColumns(columns...)).
+//		Exec(ctx)
+func (_c *UsageCleanupTaskCreate) OnConflictColumns(columns ...string) *UsageCleanupTaskUpsertOne {
+	_c.conflict = append(_c.conflict, sql.ConflictColumns(columns...))
+	return &UsageCleanupTaskUpsertOne{
+		create: _c,
+	}
+}
+
+type (
+	// UsageCleanupTaskUpsertOne is the builder for "upsert"-ing
+	// one UsageCleanupTask node.
+	UsageCleanupTaskUpsertOne struct {
+		create *UsageCleanupTaskCreate
+	}
+
+	// UsageCleanupTaskUpsert is the "OnConflict" setter.
+	UsageCleanupTaskUpsert struct {
+		*sql.UpdateSet
+	}
+)
+
+// SetUpdatedAt sets the "updated_at" field.
+func (u *UsageCleanupTaskUpsert) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpsert {
+	u.Set(usagecleanuptask.FieldUpdatedAt, v)
+	return u
+}
+
+// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateUpdatedAt() *UsageCleanupTaskUpsert {
+	u.SetExcluded(usagecleanuptask.FieldUpdatedAt)
+	return u
+}
+
+// SetStatus sets the "status" field.
+func (u *UsageCleanupTaskUpsert) SetStatus(v string) *UsageCleanupTaskUpsert {
+	u.Set(usagecleanuptask.FieldStatus, v)
+	return u
+}
+
+// UpdateStatus sets the "status" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateStatus() *UsageCleanupTaskUpsert {
+	u.SetExcluded(usagecleanuptask.FieldStatus)
+	return u
+}
+
+// SetFilters sets the "filters" field.
+func (u *UsageCleanupTaskUpsert) SetFilters(v json.RawMessage) *UsageCleanupTaskUpsert {
+	u.Set(usagecleanuptask.FieldFilters, v)
+	return u
+}
+
+// UpdateFilters sets the "filters" field to the value that was provided on create.
+func (u *UsageCleanupTaskUpsert) UpdateFilters() *UsageCleanupTaskUpsert {
+	u.SetExcluded(usagecleanuptask.FieldFilters)
+	return u
+}
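A compact usage sketch of the upsert surface above may help review; everything project-specific in it is an assumption (the "pending" status literal, the created_by value, and the conflict target, since this PR defines no unique column to conflict on):

```go
package main

import (
	"context"
	"encoding/json"

	"github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)

// upsertTask is a hypothetical helper: insert a task, and on conflict keep
// the status that was proposed for insertion. The conflict target below is
// illustrative only and not taken from this PR.
func upsertTask(ctx context.Context, client *ent.Client) (int64, error) {
	return client.UsageCleanupTask.Create().
		SetStatus("pending"). // assumed status literal
		SetFilters(json.RawMessage(`{}`)).
		SetCreatedBy(1).
		OnConflictColumns(usagecleanuptask.FieldID). // illustrative conflict target
		Update(func(u *ent.UsageCleanupTaskUpsert) {
			u.UpdateStatus() // take the status proposed for insertion
		}).
		ID(ctx)
}
```

+
+// SetCreatedBy sets the "created_by" field.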
+func (u *UsageCleanupTaskUpsert) SetCreatedBy(v int64) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldCreatedBy, v) + return u +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateCreatedBy() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldCreatedBy) + return u +} + +// AddCreatedBy adds v to the "created_by" field. +func (u *UsageCleanupTaskUpsert) AddCreatedBy(v int64) *UsageCleanupTaskUpsert { + u.Add(usagecleanuptask.FieldCreatedBy, v) + return u +} + +// SetDeletedRows sets the "deleted_rows" field. +func (u *UsageCleanupTaskUpsert) SetDeletedRows(v int64) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldDeletedRows, v) + return u +} + +// UpdateDeletedRows sets the "deleted_rows" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateDeletedRows() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldDeletedRows) + return u +} + +// AddDeletedRows adds v to the "deleted_rows" field. +func (u *UsageCleanupTaskUpsert) AddDeletedRows(v int64) *UsageCleanupTaskUpsert { + u.Add(usagecleanuptask.FieldDeletedRows, v) + return u +} + +// SetErrorMessage sets the "error_message" field. +func (u *UsageCleanupTaskUpsert) SetErrorMessage(v string) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldErrorMessage, v) + return u +} + +// UpdateErrorMessage sets the "error_message" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateErrorMessage() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldErrorMessage) + return u +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (u *UsageCleanupTaskUpsert) ClearErrorMessage() *UsageCleanupTaskUpsert { + u.SetNull(usagecleanuptask.FieldErrorMessage) + return u +} + +// SetCanceledBy sets the "canceled_by" field. +func (u *UsageCleanupTaskUpsert) SetCanceledBy(v int64) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldCanceledBy, v) + return u +} + +// UpdateCanceledBy sets the "canceled_by" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateCanceledBy() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldCanceledBy) + return u +} + +// AddCanceledBy adds v to the "canceled_by" field. +func (u *UsageCleanupTaskUpsert) AddCanceledBy(v int64) *UsageCleanupTaskUpsert { + u.Add(usagecleanuptask.FieldCanceledBy, v) + return u +} + +// ClearCanceledBy clears the value of the "canceled_by" field. +func (u *UsageCleanupTaskUpsert) ClearCanceledBy() *UsageCleanupTaskUpsert { + u.SetNull(usagecleanuptask.FieldCanceledBy) + return u +} + +// SetCanceledAt sets the "canceled_at" field. +func (u *UsageCleanupTaskUpsert) SetCanceledAt(v time.Time) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldCanceledAt, v) + return u +} + +// UpdateCanceledAt sets the "canceled_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateCanceledAt() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldCanceledAt) + return u +} + +// ClearCanceledAt clears the value of the "canceled_at" field. +func (u *UsageCleanupTaskUpsert) ClearCanceledAt() *UsageCleanupTaskUpsert { + u.SetNull(usagecleanuptask.FieldCanceledAt) + return u +} + +// SetStartedAt sets the "started_at" field. 
+func (u *UsageCleanupTaskUpsert) SetStartedAt(v time.Time) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldStartedAt, v) + return u +} + +// UpdateStartedAt sets the "started_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateStartedAt() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldStartedAt) + return u +} + +// ClearStartedAt clears the value of the "started_at" field. +func (u *UsageCleanupTaskUpsert) ClearStartedAt() *UsageCleanupTaskUpsert { + u.SetNull(usagecleanuptask.FieldStartedAt) + return u +} + +// SetFinishedAt sets the "finished_at" field. +func (u *UsageCleanupTaskUpsert) SetFinishedAt(v time.Time) *UsageCleanupTaskUpsert { + u.Set(usagecleanuptask.FieldFinishedAt, v) + return u +} + +// UpdateFinishedAt sets the "finished_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsert) UpdateFinishedAt() *UsageCleanupTaskUpsert { + u.SetExcluded(usagecleanuptask.FieldFinishedAt) + return u +} + +// ClearFinishedAt clears the value of the "finished_at" field. +func (u *UsageCleanupTaskUpsert) ClearFinishedAt() *UsageCleanupTaskUpsert { + u.SetNull(usagecleanuptask.FieldFinishedAt) + return u +} + +// UpdateNewValues updates the mutable fields using the new values that were set on create. +// Using this option is equivalent to using: +// +// client.UsageCleanupTask.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UsageCleanupTaskUpsertOne) UpdateNewValues() *UsageCleanupTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + if _, exists := u.create.mutation.CreatedAt(); exists { + s.SetIgnore(usagecleanuptask.FieldCreatedAt) + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UsageCleanupTask.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UsageCleanupTaskUpsertOne) Ignore() *UsageCleanupTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UsageCleanupTaskUpsertOne) DoNothing() *UsageCleanupTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UsageCleanupTaskCreate.OnConflict +// documentation for more info. +func (u *UsageCleanupTaskUpsertOne) Update(set func(*UsageCleanupTaskUpsert)) *UsageCleanupTaskUpsertOne { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UsageCleanupTaskUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UsageCleanupTaskUpsertOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateUpdatedAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetStatus sets the "status" field. 
+func (u *UsageCleanupTaskUpsertOne) SetStatus(v string) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateStatus() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateStatus() + }) +} + +// SetFilters sets the "filters" field. +func (u *UsageCleanupTaskUpsertOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetFilters(v) + }) +} + +// UpdateFilters sets the "filters" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateFilters() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateFilters() + }) +} + +// SetCreatedBy sets the "created_by" field. +func (u *UsageCleanupTaskUpsertOne) SetCreatedBy(v int64) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetCreatedBy(v) + }) +} + +// AddCreatedBy adds v to the "created_by" field. +func (u *UsageCleanupTaskUpsertOne) AddCreatedBy(v int64) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.AddCreatedBy(v) + }) +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateCreatedBy() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateCreatedBy() + }) +} + +// SetDeletedRows sets the "deleted_rows" field. +func (u *UsageCleanupTaskUpsertOne) SetDeletedRows(v int64) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetDeletedRows(v) + }) +} + +// AddDeletedRows adds v to the "deleted_rows" field. +func (u *UsageCleanupTaskUpsertOne) AddDeletedRows(v int64) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.AddDeletedRows(v) + }) +} + +// UpdateDeletedRows sets the "deleted_rows" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateDeletedRows() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateDeletedRows() + }) +} + +// SetErrorMessage sets the "error_message" field. +func (u *UsageCleanupTaskUpsertOne) SetErrorMessage(v string) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetErrorMessage(v) + }) +} + +// UpdateErrorMessage sets the "error_message" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateErrorMessage() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateErrorMessage() + }) +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (u *UsageCleanupTaskUpsertOne) ClearErrorMessage() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearErrorMessage() + }) +} + +// SetCanceledBy sets the "canceled_by" field. +func (u *UsageCleanupTaskUpsertOne) SetCanceledBy(v int64) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetCanceledBy(v) + }) +} + +// AddCanceledBy adds v to the "canceled_by" field. 
+func (u *UsageCleanupTaskUpsertOne) AddCanceledBy(v int64) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.AddCanceledBy(v) + }) +} + +// UpdateCanceledBy sets the "canceled_by" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateCanceledBy() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateCanceledBy() + }) +} + +// ClearCanceledBy clears the value of the "canceled_by" field. +func (u *UsageCleanupTaskUpsertOne) ClearCanceledBy() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearCanceledBy() + }) +} + +// SetCanceledAt sets the "canceled_at" field. +func (u *UsageCleanupTaskUpsertOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetCanceledAt(v) + }) +} + +// UpdateCanceledAt sets the "canceled_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateCanceledAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateCanceledAt() + }) +} + +// ClearCanceledAt clears the value of the "canceled_at" field. +func (u *UsageCleanupTaskUpsertOne) ClearCanceledAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearCanceledAt() + }) +} + +// SetStartedAt sets the "started_at" field. +func (u *UsageCleanupTaskUpsertOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetStartedAt(v) + }) +} + +// UpdateStartedAt sets the "started_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateStartedAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateStartedAt() + }) +} + +// ClearStartedAt clears the value of the "started_at" field. +func (u *UsageCleanupTaskUpsertOne) ClearStartedAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearStartedAt() + }) +} + +// SetFinishedAt sets the "finished_at" field. +func (u *UsageCleanupTaskUpsertOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetFinishedAt(v) + }) +} + +// UpdateFinishedAt sets the "finished_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertOne) UpdateFinishedAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateFinishedAt() + }) +} + +// ClearFinishedAt clears the value of the "finished_at" field. +func (u *UsageCleanupTaskUpsertOne) ClearFinishedAt() *UsageCleanupTaskUpsertOne { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearFinishedAt() + }) +} + +// Exec executes the query. +func (u *UsageCleanupTaskUpsertOne) Exec(ctx context.Context) error { + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UsageCleanupTaskCreate.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UsageCleanupTaskUpsertOne) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} + +// Exec executes the UPSERT query and returns the inserted/updated ID. 
+func (u *UsageCleanupTaskUpsertOne) ID(ctx context.Context) (id int64, err error) {
+	node, err := u.create.Save(ctx)
+	if err != nil {
+		return id, err
+	}
+	return node.ID, nil
+}
+
+// IDX is like ID, but panics if an error occurs.
+func (u *UsageCleanupTaskUpsertOne) IDX(ctx context.Context) int64 {
+	id, err := u.ID(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return id
+}
+
+// UsageCleanupTaskCreateBulk is the builder for creating many UsageCleanupTask entities in bulk.
+type UsageCleanupTaskCreateBulk struct {
+	config
+	err      error
+	builders []*UsageCleanupTaskCreate
+	conflict []sql.ConflictOption
+}
+
+// Save creates the UsageCleanupTask entities in the database.
+func (_c *UsageCleanupTaskCreateBulk) Save(ctx context.Context) ([]*UsageCleanupTask, error) {
+	if _c.err != nil {
+		return nil, _c.err
+	}
+	specs := make([]*sqlgraph.CreateSpec, len(_c.builders))
+	nodes := make([]*UsageCleanupTask, len(_c.builders))
+	mutators := make([]Mutator, len(_c.builders))
+	for i := range _c.builders {
+		func(i int, root context.Context) {
+			builder := _c.builders[i]
+			builder.defaults()
+			var mut Mutator = MutateFunc(func(ctx context.Context, m Mutation) (Value, error) {
+				mutation, ok := m.(*UsageCleanupTaskMutation)
+				if !ok {
+					return nil, fmt.Errorf("unexpected mutation type %T", m)
+				}
+				if err := builder.check(); err != nil {
+					return nil, err
+				}
+				builder.mutation = mutation
+				var err error
+				nodes[i], specs[i] = builder.createSpec()
+				if i < len(mutators)-1 {
+					_, err = mutators[i+1].Mutate(root, _c.builders[i+1].mutation)
+				} else {
+					spec := &sqlgraph.BatchCreateSpec{Nodes: specs}
+					spec.OnConflict = _c.conflict
+					// Invoke the actual operation on the latest mutation in the chain.
+					if err = sqlgraph.BatchCreate(ctx, _c.driver, spec); err != nil {
+						if sqlgraph.IsConstraintError(err) {
+							err = &ConstraintError{msg: err.Error(), wrap: err}
+						}
+					}
+				}
+				if err != nil {
+					return nil, err
+				}
+				mutation.id = &nodes[i].ID
+				if specs[i].ID.Value != nil {
+					id := specs[i].ID.Value.(int64)
+					nodes[i].ID = int64(id)
+				}
+				mutation.done = true
+				return nodes[i], nil
+			})
+			for i := len(builder.hooks) - 1; i >= 0; i-- {
+				mut = builder.hooks[i](mut)
+			}
+			mutators[i] = mut
+		}(i, ctx)
+	}
+	if len(mutators) > 0 {
+		if _, err := mutators[0].Mutate(ctx, _c.builders[0].mutation); err != nil {
+			return nil, err
+		}
+	}
+	return nodes, nil
+}
+
+// SaveX is like Save, but panics if an error occurs.
+func (_c *UsageCleanupTaskCreateBulk) SaveX(ctx context.Context) []*UsageCleanupTask {
+	v, err := _c.Save(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return v
+}
+
+// Exec executes the query.
+func (_c *UsageCleanupTaskCreateBulk) Exec(ctx context.Context) error {
+	_, err := _c.Save(ctx)
+	return err
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_c *UsageCleanupTaskCreateBulk) ExecX(ctx context.Context) {
+	if err := _c.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
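The bulk path above flushes every builder in a single INSERT through sqlgraph.BatchCreate, so a constraint violation on any row fails the whole batch. A hedged sketch (the caller signature and "pending" status literal are illustrative only):

```go
package main

import (
	"context"
	"encoding/json"

	"github.com/Wei-Shaw/sub2api/ent"
)

// createTasks is a hypothetical sketch of UsageCleanupTaskCreateBulk usage.
func createTasks(ctx context.Context, client *ent.Client, filters []json.RawMessage) ([]*ent.UsageCleanupTask, error) {
	builders := make([]*ent.UsageCleanupTaskCreate, 0, len(filters))
	for _, f := range filters {
		builders = append(builders, client.UsageCleanupTask.Create().
			SetStatus("pending"). // assumed status literal
			SetFilters(f).
			SetCreatedBy(1))
	}
	// One INSERT for all rows; hooks run per builder.
	return client.UsageCleanupTask.CreateBulk(builders...).Save(ctx)
}
```

+
+// OnConflict allows configuring the `ON CONFLICT` / `ON DUPLICATE KEY` clause
+// of the `INSERT` statement. For example:
+//
+//	client.UsageCleanupTask.CreateBulk(builders...).
+//		OnConflict(
+//			// Update the row with the new values
+//			// that were proposed for insertion.
+//			sql.ResolveWithNewValues(),
+//		).
+//		// Override some of the fields with custom
+//		// update values.
+//		Update(func(u *ent.UsageCleanupTaskUpsert) {
+//			u.SetCreatedAt(v + v)
+//		}).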
+// Exec(ctx) +func (_c *UsageCleanupTaskCreateBulk) OnConflict(opts ...sql.ConflictOption) *UsageCleanupTaskUpsertBulk { + _c.conflict = opts + return &UsageCleanupTaskUpsertBulk{ + create: _c, + } +} + +// OnConflictColumns calls `OnConflict` and configures the columns +// as conflict target. Using this option is equivalent to using: +// +// client.UsageCleanupTask.Create(). +// OnConflict(sql.ConflictColumns(columns...)). +// Exec(ctx) +func (_c *UsageCleanupTaskCreateBulk) OnConflictColumns(columns ...string) *UsageCleanupTaskUpsertBulk { + _c.conflict = append(_c.conflict, sql.ConflictColumns(columns...)) + return &UsageCleanupTaskUpsertBulk{ + create: _c, + } +} + +// UsageCleanupTaskUpsertBulk is the builder for "upsert"-ing +// a bulk of UsageCleanupTask nodes. +type UsageCleanupTaskUpsertBulk struct { + create *UsageCleanupTaskCreateBulk +} + +// UpdateNewValues updates the mutable fields using the new values that +// were set on create. Using this option is equivalent to using: +// +// client.UsageCleanupTask.Create(). +// OnConflict( +// sql.ResolveWithNewValues(), +// ). +// Exec(ctx) +func (u *UsageCleanupTaskUpsertBulk) UpdateNewValues() *UsageCleanupTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithNewValues()) + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(s *sql.UpdateSet) { + for _, b := range u.create.builders { + if _, exists := b.mutation.CreatedAt(); exists { + s.SetIgnore(usagecleanuptask.FieldCreatedAt) + } + } + })) + return u +} + +// Ignore sets each column to itself in case of conflict. +// Using this option is equivalent to using: +// +// client.UsageCleanupTask.Create(). +// OnConflict(sql.ResolveWithIgnore()). +// Exec(ctx) +func (u *UsageCleanupTaskUpsertBulk) Ignore() *UsageCleanupTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWithIgnore()) + return u +} + +// DoNothing configures the conflict_action to `DO NOTHING`. +// Supported only by SQLite and PostgreSQL. +func (u *UsageCleanupTaskUpsertBulk) DoNothing() *UsageCleanupTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.DoNothing()) + return u +} + +// Update allows overriding fields `UPDATE` values. See the UsageCleanupTaskCreateBulk.OnConflict +// documentation for more info. +func (u *UsageCleanupTaskUpsertBulk) Update(set func(*UsageCleanupTaskUpsert)) *UsageCleanupTaskUpsertBulk { + u.create.conflict = append(u.create.conflict, sql.ResolveWith(func(update *sql.UpdateSet) { + set(&UsageCleanupTaskUpsert{UpdateSet: update}) + })) + return u +} + +// SetUpdatedAt sets the "updated_at" field. +func (u *UsageCleanupTaskUpsertBulk) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetUpdatedAt(v) + }) +} + +// UpdateUpdatedAt sets the "updated_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateUpdatedAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateUpdatedAt() + }) +} + +// SetStatus sets the "status" field. +func (u *UsageCleanupTaskUpsertBulk) SetStatus(v string) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetStatus(v) + }) +} + +// UpdateStatus sets the "status" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateStatus() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateStatus() + }) +} + +// SetFilters sets the "filters" field. 
+func (u *UsageCleanupTaskUpsertBulk) SetFilters(v json.RawMessage) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetFilters(v) + }) +} + +// UpdateFilters sets the "filters" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateFilters() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateFilters() + }) +} + +// SetCreatedBy sets the "created_by" field. +func (u *UsageCleanupTaskUpsertBulk) SetCreatedBy(v int64) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetCreatedBy(v) + }) +} + +// AddCreatedBy adds v to the "created_by" field. +func (u *UsageCleanupTaskUpsertBulk) AddCreatedBy(v int64) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.AddCreatedBy(v) + }) +} + +// UpdateCreatedBy sets the "created_by" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateCreatedBy() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateCreatedBy() + }) +} + +// SetDeletedRows sets the "deleted_rows" field. +func (u *UsageCleanupTaskUpsertBulk) SetDeletedRows(v int64) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetDeletedRows(v) + }) +} + +// AddDeletedRows adds v to the "deleted_rows" field. +func (u *UsageCleanupTaskUpsertBulk) AddDeletedRows(v int64) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.AddDeletedRows(v) + }) +} + +// UpdateDeletedRows sets the "deleted_rows" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateDeletedRows() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateDeletedRows() + }) +} + +// SetErrorMessage sets the "error_message" field. +func (u *UsageCleanupTaskUpsertBulk) SetErrorMessage(v string) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetErrorMessage(v) + }) +} + +// UpdateErrorMessage sets the "error_message" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateErrorMessage() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateErrorMessage() + }) +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (u *UsageCleanupTaskUpsertBulk) ClearErrorMessage() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearErrorMessage() + }) +} + +// SetCanceledBy sets the "canceled_by" field. +func (u *UsageCleanupTaskUpsertBulk) SetCanceledBy(v int64) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetCanceledBy(v) + }) +} + +// AddCanceledBy adds v to the "canceled_by" field. +func (u *UsageCleanupTaskUpsertBulk) AddCanceledBy(v int64) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.AddCanceledBy(v) + }) +} + +// UpdateCanceledBy sets the "canceled_by" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateCanceledBy() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateCanceledBy() + }) +} + +// ClearCanceledBy clears the value of the "canceled_by" field. 
+func (u *UsageCleanupTaskUpsertBulk) ClearCanceledBy() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearCanceledBy() + }) +} + +// SetCanceledAt sets the "canceled_at" field. +func (u *UsageCleanupTaskUpsertBulk) SetCanceledAt(v time.Time) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetCanceledAt(v) + }) +} + +// UpdateCanceledAt sets the "canceled_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateCanceledAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateCanceledAt() + }) +} + +// ClearCanceledAt clears the value of the "canceled_at" field. +func (u *UsageCleanupTaskUpsertBulk) ClearCanceledAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearCanceledAt() + }) +} + +// SetStartedAt sets the "started_at" field. +func (u *UsageCleanupTaskUpsertBulk) SetStartedAt(v time.Time) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetStartedAt(v) + }) +} + +// UpdateStartedAt sets the "started_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateStartedAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateStartedAt() + }) +} + +// ClearStartedAt clears the value of the "started_at" field. +func (u *UsageCleanupTaskUpsertBulk) ClearStartedAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearStartedAt() + }) +} + +// SetFinishedAt sets the "finished_at" field. +func (u *UsageCleanupTaskUpsertBulk) SetFinishedAt(v time.Time) *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.SetFinishedAt(v) + }) +} + +// UpdateFinishedAt sets the "finished_at" field to the value that was provided on create. +func (u *UsageCleanupTaskUpsertBulk) UpdateFinishedAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.UpdateFinishedAt() + }) +} + +// ClearFinishedAt clears the value of the "finished_at" field. +func (u *UsageCleanupTaskUpsertBulk) ClearFinishedAt() *UsageCleanupTaskUpsertBulk { + return u.Update(func(s *UsageCleanupTaskUpsert) { + s.ClearFinishedAt() + }) +} + +// Exec executes the query. +func (u *UsageCleanupTaskUpsertBulk) Exec(ctx context.Context) error { + if u.create.err != nil { + return u.create.err + } + for i, b := range u.create.builders { + if len(b.conflict) != 0 { + return fmt.Errorf("ent: OnConflict was set for builder %d. Set it on the UsageCleanupTaskCreateBulk instead", i) + } + } + if len(u.create.conflict) == 0 { + return errors.New("ent: missing options for UsageCleanupTaskCreateBulk.OnConflict") + } + return u.create.Exec(ctx) +} + +// ExecX is like Exec, but panics if an error occurs. +func (u *UsageCleanupTaskUpsertBulk) ExecX(ctx context.Context) { + if err := u.create.Exec(ctx); err != nil { + panic(err) + } +} diff --git a/backend/ent/usagecleanuptask_delete.go b/backend/ent/usagecleanuptask_delete.go new file mode 100644 index 00000000..158555f7 --- /dev/null +++ b/backend/ent/usagecleanuptask_delete.go @@ -0,0 +1,88 @@ +// Code generated by ent, DO NOT EDIT. 
+
+package ent
+
+import (
+	"context"
+
+	"entgo.io/ent/dialect/sql"
+	"entgo.io/ent/dialect/sql/sqlgraph"
+	"entgo.io/ent/schema/field"
+	"github.com/Wei-Shaw/sub2api/ent/predicate"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskDelete is the builder for deleting a UsageCleanupTask entity.
+type UsageCleanupTaskDelete struct {
+	config
+	hooks    []Hook
+	mutation *UsageCleanupTaskMutation
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskDelete builder.
+func (_d *UsageCleanupTaskDelete) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDelete {
+	_d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query and returns how many vertices were deleted.
+func (_d *UsageCleanupTaskDelete) Exec(ctx context.Context) (int, error) {
+	return withHooks(ctx, _d.sqlExec, _d.mutation, _d.hooks)
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *UsageCleanupTaskDelete) ExecX(ctx context.Context) int {
+	n, err := _d.Exec(ctx)
+	if err != nil {
+		panic(err)
+	}
+	return n
+}
+
+func (_d *UsageCleanupTaskDelete) sqlExec(ctx context.Context) (int, error) {
+	_spec := sqlgraph.NewDeleteSpec(usagecleanuptask.Table, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64))
+	if ps := _d.mutation.predicates; len(ps) > 0 {
+		_spec.Predicate = func(selector *sql.Selector) {
+			for i := range ps {
+				ps[i](selector)
+			}
+		}
+	}
+	affected, err := sqlgraph.DeleteNodes(ctx, _d.driver, _spec)
+	if err != nil && sqlgraph.IsConstraintError(err) {
+		err = &ConstraintError{msg: err.Error(), wrap: err}
+	}
+	_d.mutation.done = true
+	return affected, err
+}
+
+// UsageCleanupTaskDeleteOne is the builder for deleting a single UsageCleanupTask entity.
+type UsageCleanupTaskDeleteOne struct {
+	_d *UsageCleanupTaskDelete
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskDelete builder.
+func (_d *UsageCleanupTaskDeleteOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskDeleteOne {
+	_d._d.mutation.Where(ps...)
+	return _d
+}
+
+// Exec executes the deletion query.
+func (_d *UsageCleanupTaskDeleteOne) Exec(ctx context.Context) error {
+	n, err := _d._d.Exec(ctx)
+	switch {
+	case err != nil:
+		return err
+	case n == 0:
+		return &NotFoundError{usagecleanuptask.Label}
+	default:
+		return nil
+	}
+}
+
+// ExecX is like Exec, but panics if an error occurs.
+func (_d *UsageCleanupTaskDeleteOne) ExecX(ctx context.Context) {
+	if err := _d.Exec(ctx); err != nil {
+		panic(err)
+	}
+}
diff --git a/backend/ent/usagecleanuptask_query.go b/backend/ent/usagecleanuptask_query.go
new file mode 100644
index 00000000..9d8d5410
--- /dev/null
+++ b/backend/ent/usagecleanuptask_query.go
@@ -0,0 +1,564 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+	"context"
+	"fmt"
+	"math"
+
+	"entgo.io/ent"
+	"entgo.io/ent/dialect"
+	"entgo.io/ent/dialect/sql"
+	"entgo.io/ent/dialect/sql/sqlgraph"
+	"entgo.io/ent/schema/field"
+	"github.com/Wei-Shaw/sub2api/ent/predicate"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskQuery is the builder for querying UsageCleanupTask entities.
+type UsageCleanupTaskQuery struct {
+	config
+	ctx        *QueryContext
+	order      []usagecleanuptask.OrderOption
+	inters     []Interceptor
+	predicates []predicate.UsageCleanupTask
+	modifiers  []func(*sql.Selector)
+	// intermediate query (i.e. traversal path).
+	sql  *sql.Selector
+	path func(context.Context) (*sql.Selector, error)
+}
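The delete builder above composes with the generated predicates (FinishedAtLT, FinishedAtNotNil, and friends from the usagecleanuptask package). A minimal retention sketch, assuming an initialized *ent.Client and an arbitrary 30-day cutoff:

```go
package main

import (
	"context"
	"time"

	"github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)

// pruneOldTasks is a hypothetical sketch: delete task records that finished
// more than 30 days ago and report how many rows were removed.
func pruneOldTasks(ctx context.Context, client *ent.Client) (int, error) {
	cutoff := time.Now().AddDate(0, 0, -30)
	return client.UsageCleanupTask.Delete().
		Where(
			usagecleanuptask.FinishedAtNotNil(),
			usagecleanuptask.FinishedAtLT(cutoff),
		).
		Exec(ctx)
}
```

+
+// Where adds a new predicate for the UsageCleanupTaskQuery builder.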
+func (_q *UsageCleanupTaskQuery) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskQuery { + _q.predicates = append(_q.predicates, ps...) + return _q +} + +// Limit the number of records to be returned by this query. +func (_q *UsageCleanupTaskQuery) Limit(limit int) *UsageCleanupTaskQuery { + _q.ctx.Limit = &limit + return _q +} + +// Offset to start from. +func (_q *UsageCleanupTaskQuery) Offset(offset int) *UsageCleanupTaskQuery { + _q.ctx.Offset = &offset + return _q +} + +// Unique configures the query builder to filter duplicate records on query. +// By default, unique is set to true, and can be disabled using this method. +func (_q *UsageCleanupTaskQuery) Unique(unique bool) *UsageCleanupTaskQuery { + _q.ctx.Unique = &unique + return _q +} + +// Order specifies how the records should be ordered. +func (_q *UsageCleanupTaskQuery) Order(o ...usagecleanuptask.OrderOption) *UsageCleanupTaskQuery { + _q.order = append(_q.order, o...) + return _q +} + +// First returns the first UsageCleanupTask entity from the query. +// Returns a *NotFoundError when no UsageCleanupTask was found. +func (_q *UsageCleanupTaskQuery) First(ctx context.Context) (*UsageCleanupTask, error) { + nodes, err := _q.Limit(1).All(setContextOp(ctx, _q.ctx, ent.OpQueryFirst)) + if err != nil { + return nil, err + } + if len(nodes) == 0 { + return nil, &NotFoundError{usagecleanuptask.Label} + } + return nodes[0], nil +} + +// FirstX is like First, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) FirstX(ctx context.Context) *UsageCleanupTask { + node, err := _q.First(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return node +} + +// FirstID returns the first UsageCleanupTask ID from the query. +// Returns a *NotFoundError when no UsageCleanupTask ID was found. +func (_q *UsageCleanupTaskQuery) FirstID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(1).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryFirstID)); err != nil { + return + } + if len(ids) == 0 { + err = &NotFoundError{usagecleanuptask.Label} + return + } + return ids[0], nil +} + +// FirstIDX is like FirstID, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) FirstIDX(ctx context.Context) int64 { + id, err := _q.FirstID(ctx) + if err != nil && !IsNotFound(err) { + panic(err) + } + return id +} + +// Only returns a single UsageCleanupTask entity found by the query, ensuring it only returns one. +// Returns a *NotSingularError when more than one UsageCleanupTask entity is found. +// Returns a *NotFoundError when no UsageCleanupTask entities are found. +func (_q *UsageCleanupTaskQuery) Only(ctx context.Context) (*UsageCleanupTask, error) { + nodes, err := _q.Limit(2).All(setContextOp(ctx, _q.ctx, ent.OpQueryOnly)) + if err != nil { + return nil, err + } + switch len(nodes) { + case 1: + return nodes[0], nil + case 0: + return nil, &NotFoundError{usagecleanuptask.Label} + default: + return nil, &NotSingularError{usagecleanuptask.Label} + } +} + +// OnlyX is like Only, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) OnlyX(ctx context.Context) *UsageCleanupTask { + node, err := _q.Only(ctx) + if err != nil { + panic(err) + } + return node +} + +// OnlyID is like Only, but returns the only UsageCleanupTask ID in the query. +// Returns a *NotSingularError when more than one UsageCleanupTask ID is found. +// Returns a *NotFoundError when no entities are found. 
+func (_q *UsageCleanupTaskQuery) OnlyID(ctx context.Context) (id int64, err error) { + var ids []int64 + if ids, err = _q.Limit(2).IDs(setContextOp(ctx, _q.ctx, ent.OpQueryOnlyID)); err != nil { + return + } + switch len(ids) { + case 1: + id = ids[0] + case 0: + err = &NotFoundError{usagecleanuptask.Label} + default: + err = &NotSingularError{usagecleanuptask.Label} + } + return +} + +// OnlyIDX is like OnlyID, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) OnlyIDX(ctx context.Context) int64 { + id, err := _q.OnlyID(ctx) + if err != nil { + panic(err) + } + return id +} + +// All executes the query and returns a list of UsageCleanupTasks. +func (_q *UsageCleanupTaskQuery) All(ctx context.Context) ([]*UsageCleanupTask, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryAll) + if err := _q.prepareQuery(ctx); err != nil { + return nil, err + } + qr := querierAll[[]*UsageCleanupTask, *UsageCleanupTaskQuery]() + return withInterceptors[[]*UsageCleanupTask](ctx, _q, qr, _q.inters) +} + +// AllX is like All, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) AllX(ctx context.Context) []*UsageCleanupTask { + nodes, err := _q.All(ctx) + if err != nil { + panic(err) + } + return nodes +} + +// IDs executes the query and returns a list of UsageCleanupTask IDs. +func (_q *UsageCleanupTaskQuery) IDs(ctx context.Context) (ids []int64, err error) { + if _q.ctx.Unique == nil && _q.path != nil { + _q.Unique(true) + } + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryIDs) + if err = _q.Select(usagecleanuptask.FieldID).Scan(ctx, &ids); err != nil { + return nil, err + } + return ids, nil +} + +// IDsX is like IDs, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) IDsX(ctx context.Context) []int64 { + ids, err := _q.IDs(ctx) + if err != nil { + panic(err) + } + return ids +} + +// Count returns the count of the given query. +func (_q *UsageCleanupTaskQuery) Count(ctx context.Context) (int, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryCount) + if err := _q.prepareQuery(ctx); err != nil { + return 0, err + } + return withInterceptors[int](ctx, _q, querierCount[*UsageCleanupTaskQuery](), _q.inters) +} + +// CountX is like Count, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) CountX(ctx context.Context) int { + count, err := _q.Count(ctx) + if err != nil { + panic(err) + } + return count +} + +// Exist returns true if the query has elements in the graph. +func (_q *UsageCleanupTaskQuery) Exist(ctx context.Context) (bool, error) { + ctx = setContextOp(ctx, _q.ctx, ent.OpQueryExist) + switch _, err := _q.FirstID(ctx); { + case IsNotFound(err): + return false, nil + case err != nil: + return false, fmt.Errorf("ent: check existence: %w", err) + default: + return true, nil + } +} + +// ExistX is like Exist, but panics if an error occurs. +func (_q *UsageCleanupTaskQuery) ExistX(ctx context.Context) bool { + exist, err := _q.Exist(ctx) + if err != nil { + panic(err) + } + return exist +} + +// Clone returns a duplicate of the UsageCleanupTaskQuery builder, including all associated steps. It can be +// used to prepare common query builders and use them differently after the clone is made. 
+func (_q *UsageCleanupTaskQuery) Clone() *UsageCleanupTaskQuery {
+	if _q == nil {
+		return nil
+	}
+	return &UsageCleanupTaskQuery{
+		config:     _q.config,
+		ctx:        _q.ctx.Clone(),
+		order:      append([]usagecleanuptask.OrderOption{}, _q.order...),
+		inters:     append([]Interceptor{}, _q.inters...),
+		predicates: append([]predicate.UsageCleanupTask{}, _q.predicates...),
+		// clone intermediate query.
+		sql:  _q.sql.Clone(),
+		path: _q.path,
+	}
+}
+
+// GroupBy is used to group vertices by one or more fields/columns.
+// It is often used with aggregate functions, like: count, max, mean, min, sum.
+//
+// Example:
+//
+//	var v []struct {
+//		CreatedAt time.Time `json:"created_at,omitempty"`
+//		Count int `json:"count,omitempty"`
+//	}
+//
+//	client.UsageCleanupTask.Query().
+//		GroupBy(usagecleanuptask.FieldCreatedAt).
+//		Aggregate(ent.Count()).
+//		Scan(ctx, &v)
+func (_q *UsageCleanupTaskQuery) GroupBy(field string, fields ...string) *UsageCleanupTaskGroupBy {
+	_q.ctx.Fields = append([]string{field}, fields...)
+	grbuild := &UsageCleanupTaskGroupBy{build: _q}
+	grbuild.flds = &_q.ctx.Fields
+	grbuild.label = usagecleanuptask.Label
+	grbuild.scan = grbuild.Scan
+	return grbuild
+}
+
+// Select allows the selection of one or more fields/columns for the given query,
+// instead of selecting all fields in the entity.
+//
+// Example:
+//
+//	var v []struct {
+//		CreatedAt time.Time `json:"created_at,omitempty"`
+//	}
+//
+//	client.UsageCleanupTask.Query().
+//		Select(usagecleanuptask.FieldCreatedAt).
+//		Scan(ctx, &v)
+func (_q *UsageCleanupTaskQuery) Select(fields ...string) *UsageCleanupTaskSelect {
+	_q.ctx.Fields = append(_q.ctx.Fields, fields...)
+	sbuild := &UsageCleanupTaskSelect{UsageCleanupTaskQuery: _q}
+	sbuild.label = usagecleanuptask.Label
+	sbuild.flds, sbuild.scan = &_q.ctx.Fields, sbuild.Scan
+	return sbuild
+}
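To make the GroupBy/Select split concrete, here is a hedged aggregation sketch counting tasks per status; the ad-hoc result struct follows the pattern shown in the generated doc comments above:

```go
package main

import (
	"context"

	"github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)

// countByStatus is a minimal sketch: SELECT status, COUNT(*) ... GROUP BY status.
func countByStatus(ctx context.Context, client *ent.Client) error {
	var v []struct {
		Status string `json:"status"`
		Count  int    `json:"count"`
	}
	if err := client.UsageCleanupTask.Query().
		GroupBy(usagecleanuptask.FieldStatus).
		Aggregate(ent.Count()).
		Scan(ctx, &v); err != nil {
		return err
	}
	for _, row := range v {
		// e.g. log or export the per-status counts
		_ = row
	}
	return nil
}
```

+
+// Aggregate returns a UsageCleanupTaskSelect configured with the given aggregations.
+func (_q *UsageCleanupTaskQuery) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
+	return _q.Select().Aggregate(fns...)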
+} + +func (_q *UsageCleanupTaskQuery) prepareQuery(ctx context.Context) error { + for _, inter := range _q.inters { + if inter == nil { + return fmt.Errorf("ent: uninitialized interceptor (forgotten import ent/runtime?)") + } + if trv, ok := inter.(Traverser); ok { + if err := trv.Traverse(ctx, _q); err != nil { + return err + } + } + } + for _, f := range _q.ctx.Fields { + if !usagecleanuptask.ValidColumn(f) { + return &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + } + if _q.path != nil { + prev, err := _q.path(ctx) + if err != nil { + return err + } + _q.sql = prev + } + return nil +} + +func (_q *UsageCleanupTaskQuery) sqlAll(ctx context.Context, hooks ...queryHook) ([]*UsageCleanupTask, error) { + var ( + nodes = []*UsageCleanupTask{} + _spec = _q.querySpec() + ) + _spec.ScanValues = func(columns []string) ([]any, error) { + return (*UsageCleanupTask).scanValues(nil, columns) + } + _spec.Assign = func(columns []string, values []any) error { + node := &UsageCleanupTask{config: _q.config} + nodes = append(nodes, node) + return node.assignValues(columns, values) + } + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + for i := range hooks { + hooks[i](ctx, _spec) + } + if err := sqlgraph.QueryNodes(ctx, _q.driver, _spec); err != nil { + return nil, err + } + if len(nodes) == 0 { + return nodes, nil + } + return nodes, nil +} + +func (_q *UsageCleanupTaskQuery) sqlCount(ctx context.Context) (int, error) { + _spec := _q.querySpec() + if len(_q.modifiers) > 0 { + _spec.Modifiers = _q.modifiers + } + _spec.Node.Columns = _q.ctx.Fields + if len(_q.ctx.Fields) > 0 { + _spec.Unique = _q.ctx.Unique != nil && *_q.ctx.Unique + } + return sqlgraph.CountNodes(ctx, _q.driver, _spec) +} + +func (_q *UsageCleanupTaskQuery) querySpec() *sqlgraph.QuerySpec { + _spec := sqlgraph.NewQuerySpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64)) + _spec.From = _q.sql + if unique := _q.ctx.Unique; unique != nil { + _spec.Unique = *unique + } else if _q.path != nil { + _spec.Unique = true + } + if fields := _q.ctx.Fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID) + for i := range fields { + if fields[i] != usagecleanuptask.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, fields[i]) + } + } + } + if ps := _q.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if limit := _q.ctx.Limit; limit != nil { + _spec.Limit = *limit + } + if offset := _q.ctx.Offset; offset != nil { + _spec.Offset = *offset + } + if ps := _q.order; len(ps) > 0 { + _spec.Order = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + return _spec +} + +func (_q *UsageCleanupTaskQuery) sqlQuery(ctx context.Context) *sql.Selector { + builder := sql.Dialect(_q.driver.Dialect()) + t1 := builder.Table(usagecleanuptask.Table) + columns := _q.ctx.Fields + if len(columns) == 0 { + columns = usagecleanuptask.Columns + } + selector := builder.Select(t1.Columns(columns...)...).From(t1) + if _q.sql != nil { + selector = _q.sql + selector.Select(selector.Columns(columns...)...) 
+	}
+	if _q.ctx.Unique != nil && *_q.ctx.Unique {
+		selector.Distinct()
+	}
+	for _, m := range _q.modifiers {
+		m(selector)
+	}
+	for _, p := range _q.predicates {
+		p(selector)
+	}
+	for _, p := range _q.order {
+		p(selector)
+	}
+	if offset := _q.ctx.Offset; offset != nil {
+		// limit is mandatory for offset clause. We start
+		// with default value, and override it below if needed.
+		selector.Offset(*offset).Limit(math.MaxInt32)
+	}
+	if limit := _q.ctx.Limit; limit != nil {
+		selector.Limit(*limit)
+	}
+	return selector
+}
+
+// ForUpdate locks the selected rows against concurrent updates, and prevents them from being
+// updated, deleted or "selected ... for update" by other sessions, until the transaction is
+// either committed or rolled-back.
+func (_q *UsageCleanupTaskQuery) ForUpdate(opts ...sql.LockOption) *UsageCleanupTaskQuery {
+	if _q.driver.Dialect() == dialect.Postgres {
+		_q.Unique(false)
+	}
+	_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+		s.ForUpdate(opts...)
+	})
+	return _q
+}
+
+// ForShare behaves similarly to ForUpdate, except that it acquires a shared mode lock
+// on any rows that are read. Other sessions can read the rows, but cannot modify them
+// until your transaction commits.
+func (_q *UsageCleanupTaskQuery) ForShare(opts ...sql.LockOption) *UsageCleanupTaskQuery {
+	if _q.driver.Dialect() == dialect.Postgres {
+		_q.Unique(false)
+	}
+	_q.modifiers = append(_q.modifiers, func(s *sql.Selector) {
+		s.ForShare(opts...)
+	})
+	return _q
+}
+
+// UsageCleanupTaskGroupBy is the group-by builder for UsageCleanupTask entities.
+type UsageCleanupTaskGroupBy struct {
+	selector
+	build *UsageCleanupTaskQuery
+}
+
+// Aggregate adds the given aggregation functions to the group-by query.
+func (_g *UsageCleanupTaskGroupBy) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskGroupBy {
+	_g.fns = append(_g.fns, fns...)
+	return _g
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_g *UsageCleanupTaskGroupBy) Scan(ctx context.Context, v any) error {
+	ctx = setContextOp(ctx, _g.build.ctx, ent.OpQueryGroupBy)
+	if err := _g.build.prepareQuery(ctx); err != nil {
+		return err
+	}
+	return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskGroupBy](ctx, _g.build, _g, _g.build.inters, v)
+}
+
+func (_g *UsageCleanupTaskGroupBy) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
+	selector := root.sqlQuery(ctx).Select()
+	aggregation := make([]string, 0, len(_g.fns))
+	for _, fn := range _g.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	if len(selector.SelectedColumns()) == 0 {
+		columns := make([]string, 0, len(*_g.flds)+len(_g.fns))
+		for _, f := range *_g.flds {
+			columns = append(columns, selector.C(f))
+		}
+		columns = append(columns, aggregation...)
+		selector.Select(columns...)
+	}
+	selector.GroupBy(selector.Columns(*_g.flds...)...)
+	if err := selector.Err(); err != nil {
+		return err
+	}
+	rows := &sql.Rows{}
+	query, args := selector.Query()
+	if err := _g.build.driver.Query(ctx, query, args, rows); err != nil {
+		return err
+	}
+	defer rows.Close()
+	return sql.ScanSlice(rows, v)
+}
+
+// UsageCleanupTaskSelect is the builder for selecting fields of UsageCleanupTask entities.
+type UsageCleanupTaskSelect struct {
+	*UsageCleanupTaskQuery
+	selector
+}
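ForUpdate and ForShare only attach SQL lock modifiers, so they are meaningful inside a transaction. Below is a sketch of how a worker might claim a pending task row under SELECT ... FOR UPDATE; the "pending"/"running" status literals and the ByCreatedAt order helper are assumptions not shown in this diff:

```go
package main

import (
	"context"

	"github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)

// claimTask is a hypothetical worker sketch: lock the oldest pending task so
// concurrent workers cannot claim it twice, then mark it running.
func claimTask(ctx context.Context, client *ent.Client) (*ent.UsageCleanupTask, error) {
	tx, err := client.Tx(ctx)
	if err != nil {
		return nil, err
	}
	task, err := tx.UsageCleanupTask.Query().
		Where(usagecleanuptask.StatusEQ("pending")). // assumed status literal
		Order(usagecleanuptask.ByCreatedAt()).       // assumed order helper
		ForUpdate().
		First(ctx)
	if err != nil {
		_ = tx.Rollback()
		return nil, err
	}
	if err := tx.UsageCleanupTask.UpdateOne(task).SetStatus("running").Exec(ctx); err != nil {
		_ = tx.Rollback()
		return nil, err
	}
	return task, tx.Commit()
}
```

+
+// Aggregate adds the given aggregation functions to the selector query.
+func (_s *UsageCleanupTaskSelect) Aggregate(fns ...AggregateFunc) *UsageCleanupTaskSelect {
+	_s.fns = append(_s.fns, fns...)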
+	return _s
+}
+
+// Scan applies the selector query and scans the result into the given value.
+func (_s *UsageCleanupTaskSelect) Scan(ctx context.Context, v any) error {
+	ctx = setContextOp(ctx, _s.ctx, ent.OpQuerySelect)
+	if err := _s.prepareQuery(ctx); err != nil {
+		return err
+	}
+	return scanWithInterceptors[*UsageCleanupTaskQuery, *UsageCleanupTaskSelect](ctx, _s.UsageCleanupTaskQuery, _s, _s.inters, v)
+}
+
+func (_s *UsageCleanupTaskSelect) sqlScan(ctx context.Context, root *UsageCleanupTaskQuery, v any) error {
+	selector := root.sqlQuery(ctx)
+	aggregation := make([]string, 0, len(_s.fns))
+	for _, fn := range _s.fns {
+		aggregation = append(aggregation, fn(selector))
+	}
+	switch n := len(*_s.selector.flds); {
+	case n == 0 && len(aggregation) > 0:
+		selector.Select(aggregation...)
+	case n != 0 && len(aggregation) > 0:
+		selector.AppendSelect(aggregation...)
+	}
+	rows := &sql.Rows{}
+	query, args := selector.Query()
+	if err := _s.driver.Query(ctx, query, args, rows); err != nil {
+		return err
+	}
+	defer rows.Close()
+	return sql.ScanSlice(rows, v)
+}
diff --git a/backend/ent/usagecleanuptask_update.go b/backend/ent/usagecleanuptask_update.go
new file mode 100644
index 00000000..604202c6
--- /dev/null
+++ b/backend/ent/usagecleanuptask_update.go
@@ -0,0 +1,702 @@
+// Code generated by ent, DO NOT EDIT.
+
+package ent
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"time"
+
+	"entgo.io/ent/dialect/sql"
+	"entgo.io/ent/dialect/sql/sqlgraph"
+	"entgo.io/ent/dialect/sql/sqljson"
+	"entgo.io/ent/schema/field"
+	"github.com/Wei-Shaw/sub2api/ent/predicate"
+	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
+)
+
+// UsageCleanupTaskUpdate is the builder for updating UsageCleanupTask entities.
+type UsageCleanupTaskUpdate struct {
+	config
+	hooks    []Hook
+	mutation *UsageCleanupTaskMutation
+}
+
+// Where appends a list of predicates to the UsageCleanupTaskUpdate builder.
+func (_u *UsageCleanupTaskUpdate) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdate {
+	_u.mutation.Where(ps...)
+	return _u
+}
+
+// SetUpdatedAt sets the "updated_at" field.
+func (_u *UsageCleanupTaskUpdate) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdate {
+	_u.mutation.SetUpdatedAt(v)
+	return _u
+}
+
+// SetStatus sets the "status" field.
+func (_u *UsageCleanupTaskUpdate) SetStatus(v string) *UsageCleanupTaskUpdate {
+	_u.mutation.SetStatus(v)
+	return _u
+}
+
+// SetNillableStatus sets the "status" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableStatus(v *string) *UsageCleanupTaskUpdate {
+	if v != nil {
+		_u.SetStatus(*v)
+	}
+	return _u
+}
+
+// SetFilters sets the "filters" field.
+func (_u *UsageCleanupTaskUpdate) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
+	_u.mutation.SetFilters(v)
+	return _u
+}
+
+// AppendFilters appends value to the "filters" field.
+func (_u *UsageCleanupTaskUpdate) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdate {
+	_u.mutation.AppendFilters(v)
+	return _u
+}
+
+// SetCreatedBy sets the "created_by" field.
+func (_u *UsageCleanupTaskUpdate) SetCreatedBy(v int64) *UsageCleanupTaskUpdate {
+	_u.mutation.ResetCreatedBy()
+	_u.mutation.SetCreatedBy(v)
+	return _u
+}
+
+// SetNillableCreatedBy sets the "created_by" field if the given value is not nil.
+func (_u *UsageCleanupTaskUpdate) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdate {
+	if v != nil {
+		_u.SetCreatedBy(*v)
+	}
+	return _u
+}
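A short sketch of the predicate-scoped update builder defined in this file, closing out a task; the "completed" status literal is an assumption, and Save reports the number of affected rows:

```go
package main

import (
	"context"
	"time"

	"github.com/Wei-Shaw/sub2api/ent"
	"github.com/Wei-Shaw/sub2api/ent/usagecleanuptask"
)

// markFinished is a hypothetical sketch: record the rows a cleanup run deleted
// and stamp its completion time in one UPDATE.
func markFinished(ctx context.Context, client *ent.Client, taskID, rows int64) (int, error) {
	return client.UsageCleanupTask.Update().
		Where(usagecleanuptask.IDEQ(taskID)).
		SetStatus("completed"). // assumed status literal
		AddDeletedRows(rows).
		SetFinishedAt(time.Now()).
		Save(ctx)
}
```

+
+// AddCreatedBy adds value to the "created_by" field.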
+func (_u *UsageCleanupTaskUpdate) AddCreatedBy(v int64) *UsageCleanupTaskUpdate { + _u.mutation.AddCreatedBy(v) + return _u +} + +// SetDeletedRows sets the "deleted_rows" field. +func (_u *UsageCleanupTaskUpdate) SetDeletedRows(v int64) *UsageCleanupTaskUpdate { + _u.mutation.ResetDeletedRows() + _u.mutation.SetDeletedRows(v) + return _u +} + +// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdate) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdate { + if v != nil { + _u.SetDeletedRows(*v) + } + return _u +} + +// AddDeletedRows adds value to the "deleted_rows" field. +func (_u *UsageCleanupTaskUpdate) AddDeletedRows(v int64) *UsageCleanupTaskUpdate { + _u.mutation.AddDeletedRows(v) + return _u +} + +// SetErrorMessage sets the "error_message" field. +func (_u *UsageCleanupTaskUpdate) SetErrorMessage(v string) *UsageCleanupTaskUpdate { + _u.mutation.SetErrorMessage(v) + return _u +} + +// SetNillableErrorMessage sets the "error_message" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdate) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdate { + if v != nil { + _u.SetErrorMessage(*v) + } + return _u +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (_u *UsageCleanupTaskUpdate) ClearErrorMessage() *UsageCleanupTaskUpdate { + _u.mutation.ClearErrorMessage() + return _u +} + +// SetCanceledBy sets the "canceled_by" field. +func (_u *UsageCleanupTaskUpdate) SetCanceledBy(v int64) *UsageCleanupTaskUpdate { + _u.mutation.ResetCanceledBy() + _u.mutation.SetCanceledBy(v) + return _u +} + +// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdate) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdate { + if v != nil { + _u.SetCanceledBy(*v) + } + return _u +} + +// AddCanceledBy adds value to the "canceled_by" field. +func (_u *UsageCleanupTaskUpdate) AddCanceledBy(v int64) *UsageCleanupTaskUpdate { + _u.mutation.AddCanceledBy(v) + return _u +} + +// ClearCanceledBy clears the value of the "canceled_by" field. +func (_u *UsageCleanupTaskUpdate) ClearCanceledBy() *UsageCleanupTaskUpdate { + _u.mutation.ClearCanceledBy() + return _u +} + +// SetCanceledAt sets the "canceled_at" field. +func (_u *UsageCleanupTaskUpdate) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdate { + _u.mutation.SetCanceledAt(v) + return _u +} + +// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdate) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdate { + if v != nil { + _u.SetCanceledAt(*v) + } + return _u +} + +// ClearCanceledAt clears the value of the "canceled_at" field. +func (_u *UsageCleanupTaskUpdate) ClearCanceledAt() *UsageCleanupTaskUpdate { + _u.mutation.ClearCanceledAt() + return _u +} + +// SetStartedAt sets the "started_at" field. +func (_u *UsageCleanupTaskUpdate) SetStartedAt(v time.Time) *UsageCleanupTaskUpdate { + _u.mutation.SetStartedAt(v) + return _u +} + +// SetNillableStartedAt sets the "started_at" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdate) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdate { + if v != nil { + _u.SetStartedAt(*v) + } + return _u +} + +// ClearStartedAt clears the value of the "started_at" field. +func (_u *UsageCleanupTaskUpdate) ClearStartedAt() *UsageCleanupTaskUpdate { + _u.mutation.ClearStartedAt() + return _u +} + +// SetFinishedAt sets the "finished_at" field. 
+func (_u *UsageCleanupTaskUpdate) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdate { + _u.mutation.SetFinishedAt(v) + return _u +} + +// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdate) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdate { + if v != nil { + _u.SetFinishedAt(*v) + } + return _u +} + +// ClearFinishedAt clears the value of the "finished_at" field. +func (_u *UsageCleanupTaskUpdate) ClearFinishedAt() *UsageCleanupTaskUpdate { + _u.mutation.ClearFinishedAt() + return _u +} + +// Mutation returns the UsageCleanupTaskMutation object of the builder. +func (_u *UsageCleanupTaskUpdate) Mutation() *UsageCleanupTaskMutation { + return _u.mutation +} + +// Save executes the query and returns the number of nodes affected by the update operation. +func (_u *UsageCleanupTaskUpdate) Save(ctx context.Context) (int, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. +func (_u *UsageCleanupTaskUpdate) SaveX(ctx context.Context) int { + affected, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return affected +} + +// Exec executes the query. +func (_u *UsageCleanupTaskUpdate) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UsageCleanupTaskUpdate) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *UsageCleanupTaskUpdate) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := usagecleanuptask.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. 
+func (_u *UsageCleanupTaskUpdate) check() error { + if v, ok := _u.mutation.Status(); ok { + if err := usagecleanuptask.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)} + } + } + return nil +} + +func (_u *UsageCleanupTaskUpdate) sqlSave(ctx context.Context) (_node int, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64)) + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.Filters(); ok { + _spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedFilters(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, usagecleanuptask.FieldFilters, value) + }) + } + if value, ok := _u.mutation.CreatedBy(); ok { + _spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedCreatedBy(); ok { + _spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.DeletedRows(); ok { + _spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedDeletedRows(); ok { + _spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value) + } + if value, ok := _u.mutation.ErrorMessage(); ok { + _spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value) + } + if _u.mutation.ErrorMessageCleared() { + _spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString) + } + if value, ok := _u.mutation.CanceledBy(); ok { + _spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedCanceledBy(); ok { + _spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value) + } + if _u.mutation.CanceledByCleared() { + _spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64) + } + if value, ok := _u.mutation.CanceledAt(); ok { + _spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value) + } + if _u.mutation.CanceledAtCleared() { + _spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime) + } + if value, ok := _u.mutation.StartedAt(); ok { + _spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value) + } + if _u.mutation.StartedAtCleared() { + _spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime) + } + if value, ok := _u.mutation.FinishedAt(); ok { + _spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value) + } + if _u.mutation.FinishedAtCleared() { + _spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime) + } + if _node, err = sqlgraph.UpdateNodes(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{usagecleanuptask.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return 0, err + } + _u.mutation.done = true + return _node, nil +} + +// UsageCleanupTaskUpdateOne is the builder for 
updating a single UsageCleanupTask entity. +type UsageCleanupTaskUpdateOne struct { + config + fields []string + hooks []Hook + mutation *UsageCleanupTaskMutation +} + +// SetUpdatedAt sets the "updated_at" field. +func (_u *UsageCleanupTaskUpdateOne) SetUpdatedAt(v time.Time) *UsageCleanupTaskUpdateOne { + _u.mutation.SetUpdatedAt(v) + return _u +} + +// SetStatus sets the "status" field. +func (_u *UsageCleanupTaskUpdateOne) SetStatus(v string) *UsageCleanupTaskUpdateOne { + _u.mutation.SetStatus(v) + return _u +} + +// SetNillableStatus sets the "status" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableStatus(v *string) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetStatus(*v) + } + return _u +} + +// SetFilters sets the "filters" field. +func (_u *UsageCleanupTaskUpdateOne) SetFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne { + _u.mutation.SetFilters(v) + return _u +} + +// AppendFilters appends value to the "filters" field. +func (_u *UsageCleanupTaskUpdateOne) AppendFilters(v json.RawMessage) *UsageCleanupTaskUpdateOne { + _u.mutation.AppendFilters(v) + return _u +} + +// SetCreatedBy sets the "created_by" field. +func (_u *UsageCleanupTaskUpdateOne) SetCreatedBy(v int64) *UsageCleanupTaskUpdateOne { + _u.mutation.ResetCreatedBy() + _u.mutation.SetCreatedBy(v) + return _u +} + +// SetNillableCreatedBy sets the "created_by" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableCreatedBy(v *int64) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetCreatedBy(*v) + } + return _u +} + +// AddCreatedBy adds value to the "created_by" field. +func (_u *UsageCleanupTaskUpdateOne) AddCreatedBy(v int64) *UsageCleanupTaskUpdateOne { + _u.mutation.AddCreatedBy(v) + return _u +} + +// SetDeletedRows sets the "deleted_rows" field. +func (_u *UsageCleanupTaskUpdateOne) SetDeletedRows(v int64) *UsageCleanupTaskUpdateOne { + _u.mutation.ResetDeletedRows() + _u.mutation.SetDeletedRows(v) + return _u +} + +// SetNillableDeletedRows sets the "deleted_rows" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableDeletedRows(v *int64) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetDeletedRows(*v) + } + return _u +} + +// AddDeletedRows adds value to the "deleted_rows" field. +func (_u *UsageCleanupTaskUpdateOne) AddDeletedRows(v int64) *UsageCleanupTaskUpdateOne { + _u.mutation.AddDeletedRows(v) + return _u +} + +// SetErrorMessage sets the "error_message" field. +func (_u *UsageCleanupTaskUpdateOne) SetErrorMessage(v string) *UsageCleanupTaskUpdateOne { + _u.mutation.SetErrorMessage(v) + return _u +} + +// SetNillableErrorMessage sets the "error_message" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableErrorMessage(v *string) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetErrorMessage(*v) + } + return _u +} + +// ClearErrorMessage clears the value of the "error_message" field. +func (_u *UsageCleanupTaskUpdateOne) ClearErrorMessage() *UsageCleanupTaskUpdateOne { + _u.mutation.ClearErrorMessage() + return _u +} + +// SetCanceledBy sets the "canceled_by" field. +func (_u *UsageCleanupTaskUpdateOne) SetCanceledBy(v int64) *UsageCleanupTaskUpdateOne { + _u.mutation.ResetCanceledBy() + _u.mutation.SetCanceledBy(v) + return _u +} + +// SetNillableCanceledBy sets the "canceled_by" field if the given value is not nil. 
+func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledBy(v *int64) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetCanceledBy(*v) + } + return _u +} + +// AddCanceledBy adds value to the "canceled_by" field. +func (_u *UsageCleanupTaskUpdateOne) AddCanceledBy(v int64) *UsageCleanupTaskUpdateOne { + _u.mutation.AddCanceledBy(v) + return _u +} + +// ClearCanceledBy clears the value of the "canceled_by" field. +func (_u *UsageCleanupTaskUpdateOne) ClearCanceledBy() *UsageCleanupTaskUpdateOne { + _u.mutation.ClearCanceledBy() + return _u +} + +// SetCanceledAt sets the "canceled_at" field. +func (_u *UsageCleanupTaskUpdateOne) SetCanceledAt(v time.Time) *UsageCleanupTaskUpdateOne { + _u.mutation.SetCanceledAt(v) + return _u +} + +// SetNillableCanceledAt sets the "canceled_at" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableCanceledAt(v *time.Time) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetCanceledAt(*v) + } + return _u +} + +// ClearCanceledAt clears the value of the "canceled_at" field. +func (_u *UsageCleanupTaskUpdateOne) ClearCanceledAt() *UsageCleanupTaskUpdateOne { + _u.mutation.ClearCanceledAt() + return _u +} + +// SetStartedAt sets the "started_at" field. +func (_u *UsageCleanupTaskUpdateOne) SetStartedAt(v time.Time) *UsageCleanupTaskUpdateOne { + _u.mutation.SetStartedAt(v) + return _u +} + +// SetNillableStartedAt sets the "started_at" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableStartedAt(v *time.Time) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetStartedAt(*v) + } + return _u +} + +// ClearStartedAt clears the value of the "started_at" field. +func (_u *UsageCleanupTaskUpdateOne) ClearStartedAt() *UsageCleanupTaskUpdateOne { + _u.mutation.ClearStartedAt() + return _u +} + +// SetFinishedAt sets the "finished_at" field. +func (_u *UsageCleanupTaskUpdateOne) SetFinishedAt(v time.Time) *UsageCleanupTaskUpdateOne { + _u.mutation.SetFinishedAt(v) + return _u +} + +// SetNillableFinishedAt sets the "finished_at" field if the given value is not nil. +func (_u *UsageCleanupTaskUpdateOne) SetNillableFinishedAt(v *time.Time) *UsageCleanupTaskUpdateOne { + if v != nil { + _u.SetFinishedAt(*v) + } + return _u +} + +// ClearFinishedAt clears the value of the "finished_at" field. +func (_u *UsageCleanupTaskUpdateOne) ClearFinishedAt() *UsageCleanupTaskUpdateOne { + _u.mutation.ClearFinishedAt() + return _u +} + +// Mutation returns the UsageCleanupTaskMutation object of the builder. +func (_u *UsageCleanupTaskUpdateOne) Mutation() *UsageCleanupTaskMutation { + return _u.mutation +} + +// Where appends a list predicates to the UsageCleanupTaskUpdate builder. +func (_u *UsageCleanupTaskUpdateOne) Where(ps ...predicate.UsageCleanupTask) *UsageCleanupTaskUpdateOne { + _u.mutation.Where(ps...) + return _u +} + +// Select allows selecting one or more fields (columns) of the returned entity. +// The default is selecting all fields defined in the entity schema. +func (_u *UsageCleanupTaskUpdateOne) Select(field string, fields ...string) *UsageCleanupTaskUpdateOne { + _u.fields = append([]string{field}, fields...) + return _u +} + +// Save executes the query and returns the updated UsageCleanupTask entity. +func (_u *UsageCleanupTaskUpdateOne) Save(ctx context.Context) (*UsageCleanupTask, error) { + _u.defaults() + return withHooks(ctx, _u.sqlSave, _u.mutation, _u.hooks) +} + +// SaveX is like Save, but panics if an error occurs. 
+func (_u *UsageCleanupTaskUpdateOne) SaveX(ctx context.Context) *UsageCleanupTask { + node, err := _u.Save(ctx) + if err != nil { + panic(err) + } + return node +} + +// Exec executes the query on the entity. +func (_u *UsageCleanupTaskUpdateOne) Exec(ctx context.Context) error { + _, err := _u.Save(ctx) + return err +} + +// ExecX is like Exec, but panics if an error occurs. +func (_u *UsageCleanupTaskUpdateOne) ExecX(ctx context.Context) { + if err := _u.Exec(ctx); err != nil { + panic(err) + } +} + +// defaults sets the default values of the builder before save. +func (_u *UsageCleanupTaskUpdateOne) defaults() { + if _, ok := _u.mutation.UpdatedAt(); !ok { + v := usagecleanuptask.UpdateDefaultUpdatedAt() + _u.mutation.SetUpdatedAt(v) + } +} + +// check runs all checks and user-defined validators on the builder. +func (_u *UsageCleanupTaskUpdateOne) check() error { + if v, ok := _u.mutation.Status(); ok { + if err := usagecleanuptask.StatusValidator(v); err != nil { + return &ValidationError{Name: "status", err: fmt.Errorf(`ent: validator failed for field "UsageCleanupTask.status": %w`, err)} + } + } + return nil +} + +func (_u *UsageCleanupTaskUpdateOne) sqlSave(ctx context.Context) (_node *UsageCleanupTask, err error) { + if err := _u.check(); err != nil { + return _node, err + } + _spec := sqlgraph.NewUpdateSpec(usagecleanuptask.Table, usagecleanuptask.Columns, sqlgraph.NewFieldSpec(usagecleanuptask.FieldID, field.TypeInt64)) + id, ok := _u.mutation.ID() + if !ok { + return nil, &ValidationError{Name: "id", err: errors.New(`ent: missing "UsageCleanupTask.id" for update`)} + } + _spec.Node.ID.Value = id + if fields := _u.fields; len(fields) > 0 { + _spec.Node.Columns = make([]string, 0, len(fields)) + _spec.Node.Columns = append(_spec.Node.Columns, usagecleanuptask.FieldID) + for _, f := range fields { + if !usagecleanuptask.ValidColumn(f) { + return nil, &ValidationError{Name: f, err: fmt.Errorf("ent: invalid field %q for query", f)} + } + if f != usagecleanuptask.FieldID { + _spec.Node.Columns = append(_spec.Node.Columns, f) + } + } + } + if ps := _u.mutation.predicates; len(ps) > 0 { + _spec.Predicate = func(selector *sql.Selector) { + for i := range ps { + ps[i](selector) + } + } + } + if value, ok := _u.mutation.UpdatedAt(); ok { + _spec.SetField(usagecleanuptask.FieldUpdatedAt, field.TypeTime, value) + } + if value, ok := _u.mutation.Status(); ok { + _spec.SetField(usagecleanuptask.FieldStatus, field.TypeString, value) + } + if value, ok := _u.mutation.Filters(); ok { + _spec.SetField(usagecleanuptask.FieldFilters, field.TypeJSON, value) + } + if value, ok := _u.mutation.AppendedFilters(); ok { + _spec.AddModifier(func(u *sql.UpdateBuilder) { + sqljson.Append(u, usagecleanuptask.FieldFilters, value) + }) + } + if value, ok := _u.mutation.CreatedBy(); ok { + _spec.SetField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedCreatedBy(); ok { + _spec.AddField(usagecleanuptask.FieldCreatedBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.DeletedRows(); ok { + _spec.SetField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedDeletedRows(); ok { + _spec.AddField(usagecleanuptask.FieldDeletedRows, field.TypeInt64, value) + } + if value, ok := _u.mutation.ErrorMessage(); ok { + _spec.SetField(usagecleanuptask.FieldErrorMessage, field.TypeString, value) + } + if _u.mutation.ErrorMessageCleared() { + _spec.ClearField(usagecleanuptask.FieldErrorMessage, field.TypeString) + } + if value, 
ok := _u.mutation.CanceledBy(); ok { + _spec.SetField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value) + } + if value, ok := _u.mutation.AddedCanceledBy(); ok { + _spec.AddField(usagecleanuptask.FieldCanceledBy, field.TypeInt64, value) + } + if _u.mutation.CanceledByCleared() { + _spec.ClearField(usagecleanuptask.FieldCanceledBy, field.TypeInt64) + } + if value, ok := _u.mutation.CanceledAt(); ok { + _spec.SetField(usagecleanuptask.FieldCanceledAt, field.TypeTime, value) + } + if _u.mutation.CanceledAtCleared() { + _spec.ClearField(usagecleanuptask.FieldCanceledAt, field.TypeTime) + } + if value, ok := _u.mutation.StartedAt(); ok { + _spec.SetField(usagecleanuptask.FieldStartedAt, field.TypeTime, value) + } + if _u.mutation.StartedAtCleared() { + _spec.ClearField(usagecleanuptask.FieldStartedAt, field.TypeTime) + } + if value, ok := _u.mutation.FinishedAt(); ok { + _spec.SetField(usagecleanuptask.FieldFinishedAt, field.TypeTime, value) + } + if _u.mutation.FinishedAtCleared() { + _spec.ClearField(usagecleanuptask.FieldFinishedAt, field.TypeTime) + } + _node = &UsageCleanupTask{config: _u.config} + _spec.Assign = _node.assignValues + _spec.ScanValues = _node.scanValues + if err = sqlgraph.UpdateNode(ctx, _u.driver, _spec); err != nil { + if _, ok := err.(*sqlgraph.NotFoundError); ok { + err = &NotFoundError{usagecleanuptask.Label} + } else if sqlgraph.IsConstraintError(err) { + err = &ConstraintError{msg: err.Error(), wrap: err} + } + return nil, err + } + _u.mutation.done = true + return _node, nil +} diff --git a/backend/go.mod b/backend/go.mod index 4ac6ba14..fd429b07 100644 --- a/backend/go.mod +++ b/backend/go.mod @@ -31,6 +31,7 @@ require ( ariga.io/atlas v0.32.1-0.20250325101103-175b25e1c1b9 // indirect dario.cat/mergo v1.0.2 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agext/levenshtein v1.2.3 // indirect github.com/andybalholm/brotli v1.2.0 // indirect @@ -97,6 +98,7 @@ require ( github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect github.com/modern-go/reflect2 v1.0.2 // indirect github.com/morikuni/aec v1.0.0 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/olekukonko/tablewriter v0.0.5 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect github.com/opencontainers/image-spec v1.1.1 // indirect @@ -107,6 +109,7 @@ require ( github.com/quic-go/qpack v0.6.0 // indirect github.com/quic-go/quic-go v0.57.1 // indirect github.com/refraction-networking/utls v1.8.1 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/sagikazarmark/locafero v0.4.0 // indirect @@ -139,7 +142,7 @@ require ( go.uber.org/automaxprocs v1.6.0 // indirect go.uber.org/multierr v1.9.0 // indirect golang.org/x/arch v0.3.0 // indirect - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 // indirect golang.org/x/mod v0.30.0 // indirect golang.org/x/sys v0.39.0 // indirect golang.org/x/text v0.32.0 // indirect @@ -148,4 +151,8 @@ require ( google.golang.org/grpc v1.75.1 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/ini.v1 v1.67.0 // indirect + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + 
modernc.org/sqlite v1.44.1 // indirect ) diff --git a/backend/go.sum b/backend/go.sum index 415e73a7..aa10718c 100644 --- a/backend/go.sum +++ b/backend/go.sum @@ -141,6 +141,7 @@ github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= @@ -199,6 +200,8 @@ github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9G github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec= github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= @@ -224,6 +227,8 @@ github.com/redis/go-redis/v9 v9.17.2 h1:P2EGsA4qVIM3Pp+aPocCJ7DguDHhqrXNhVcEp4Vi github.com/redis/go-redis/v9 v9.17.2/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/refraction-networking/utls v1.8.1 h1:yNY1kapmQU8JeM1sSw2H2asfTIwWxIkrMJI0pRUOCAo= github.com/refraction-networking/utls v1.8.1/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= @@ -338,6 +343,8 @@ golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= @@ -365,6 +372,7 @@ golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= golang.org/x/tools v0.39.0/go.mod 
h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ=
golang.org/x/tools/go/expect v0.1.0-deprecated h1:jY2C5HGYR5lqex3gEniOQL0r7Dq5+VGVgY1nudX5lXY=
golang.org/x/tools/go/expect v0.1.0-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY=
+golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM=
golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated/go.mod h1:RVAQXBGNv1ib0J382/DPCRS/BPnsGebyM1Gj5VSDpG8=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
@@ -387,4 +395,12 @@
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
+modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI=
+modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE=
+modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU=
+modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg=
+modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI=
+modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw=
+modernc.org/sqlite v1.44.1 h1:qybx/rNpfQipX/t47OxbHmkkJuv2JWifCMH8SVUiDas=
+modernc.org/sqlite v1.44.1/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go
index 655169cc..00a78480 100644
--- a/backend/internal/config/config.go
+++ b/backend/internal/config/config.go
@@ -55,6 +55,7 @@ type Config struct {
 	APIKeyAuth   APIKeyAuthCacheConfig      `mapstructure:"api_key_auth_cache"`
 	Dashboard    DashboardCacheConfig       `mapstructure:"dashboard_cache"`
 	DashboardAgg DashboardAggregationConfig `mapstructure:"dashboard_aggregation"`
+	UsageCleanup UsageCleanupConfig         `mapstructure:"usage_cleanup"`
 	Concurrency  ConcurrencyConfig          `mapstructure:"concurrency"`
 	TokenRefresh TokenRefreshConfig         `mapstructure:"token_refresh"`
 	RunMode      string                     `mapstructure:"run_mode" yaml:"run_mode"`
@@ -257,8 +258,43 @@ type GatewayConfig struct {
 	// Whether certain 400 errors may trigger failover (off by default to avoid changing semantics)
 	FailoverOn400 bool `mapstructure:"failover_on_400"`
 
+	// Maximum number of account switches (upper bound on switching to another account after an upstream error)
+	MaxAccountSwitches int `mapstructure:"max_account_switches"`
+	// Maximum account switches for Gemini (configured separately; the Gemini API limits are stricter)
+	MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
+
+	// Antigravity 429 fallback cooldown (minutes), used when parsing the reset time fails
+	AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
+
+	// Scheduling: account scheduling configuration
 	Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
+
+	// TLSFingerprint: TLS fingerprint masquerading configuration
+	TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
+}
+
+// TLSFingerprintConfig configures TLS fingerprint masquerading.
+// It mimics the TLS handshake characteristics of the Claude CLI (Node.js) so
+// requests are not flagged as coming from an unofficial client.
+type TLSFingerprintConfig struct {
+	// Enabled: whether the TLS fingerprint feature is enabled globally
+	Enabled bool `mapstructure:"enabled"`
+	// Profiles: predefined TLS fingerprint templates,
+	// keyed by template name, e.g. "claude_cli_v2", "chrome_120"
+	Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
+}
+
+// TLSProfileConfig is the configuration of a single TLS fingerprint template.
+type TLSProfileConfig struct {
+	// Name: display name of the template
+	Name string `mapstructure:"name"`
+	// EnableGREASE: whether to enable the GREASE extension (Chrome uses it, Node.js does not)
+	EnableGREASE bool `mapstructure:"enable_grease"`
+	// CipherSuites: TLS cipher suite list (empty uses the built-in defaults)
+	CipherSuites []uint16 `mapstructure:"cipher_suites"`
+	// Curves: elliptic curve list (empty uses the built-in defaults)
+	Curves []uint16 `mapstructure:"curves"`
+	// PointFormats: point format list (empty uses the built-in defaults)
+	PointFormats []uint8 `mapstructure:"point_formats"`
 }
 
 // GatewaySchedulingConfig accounts scheduling configuration.
@@ -271,6 +307,9 @@ type GatewaySchedulingConfig struct {
 	FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
 	FallbackMaxWaiting  int           `mapstructure:"fallback_max_waiting"`
 
+	// Fallback-tier account selection strategy: "last_used" (ordered by last-used time, the default) or "random"
+	FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
+
 	// Load calculation
 	LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
 
@@ -493,6 +532,20 @@ type DashboardAggregationRetentionConfig struct {
 	DailyDays int `mapstructure:"daily_days"`
 }
 
+// UsageCleanupConfig configures the usage-record cleanup task.
+type UsageCleanupConfig struct {
+	// Enabled: whether the cleanup task executor is enabled
+	Enabled bool `mapstructure:"enabled"`
+	// MaxRangeDays: maximum time span (in days) allowed for a single task
+	MaxRangeDays int `mapstructure:"max_range_days"`
+	// BatchSize: number of rows deleted per batch
+	BatchSize int `mapstructure:"batch_size"`
+	// WorkerIntervalSeconds: polling interval of the background worker (seconds)
+	WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
+	// TaskTimeoutSeconds: maximum execution time of a single task (seconds)
+	TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
+}
+
 func NormalizeRunMode(value string) string {
 	normalized := strings.ToLower(strings.TrimSpace(value))
 	switch normalized {
@@ -753,12 +806,22 @@ func setDefaults() {
 	viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
 	viper.SetDefault("dashboard_aggregation.recompute_days", 2)
 
+	// Usage cleanup task
+	viper.SetDefault("usage_cleanup.enabled", true)
+	viper.SetDefault("usage_cleanup.max_range_days", 31)
+	viper.SetDefault("usage_cleanup.batch_size", 5000)
+	viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
+	viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
+
 	// Gateway
 	viper.SetDefault("gateway.response_header_timeout", 600) // 600s (10 minutes) to wait for upstream response headers; LLM backends may queue for a long time under heavy load
 	viper.SetDefault("gateway.log_upstream_error_body", true)
 	viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
 	viper.SetDefault("gateway.inject_beta_for_apikey", false)
 	viper.SetDefault("gateway.failover_on_400", false)
+	viper.SetDefault("gateway.max_account_switches", 10)
+	viper.SetDefault("gateway.max_account_switches_gemini", 3)
+	viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
 	viper.SetDefault("gateway.max_body_size", int64(100*1024*1024))
 	viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
 	// HTTP upstream connection pool settings (tuned for 5000+ concurrent users)
@@ -771,11 +834,12 @@ func setDefaults() {
 	viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // concurrency slot TTL (supports very long requests)
 	viper.SetDefault("gateway.stream_data_interval_timeout", 180)
 	viper.SetDefault("gateway.stream_keepalive_interval", 10)
-	viper.SetDefault("gateway.max_line_size", 10*1024*1024)
+	viper.SetDefault("gateway.max_line_size", 40*1024*1024)
 	viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
 	viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
 	viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
 	viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
+	viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true) viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) @@ -787,6 +851,8 @@ func setDefaults() { viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) + // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) + viper.SetDefault("gateway.tls_fingerprint.enabled", true) viper.SetDefault("concurrency.ping_interval", 10) // TokenRefresh @@ -989,6 +1055,33 @@ func (c *Config) Validate() error { return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") } } + if c.UsageCleanup.Enabled { + if c.UsageCleanup.MaxRangeDays <= 0 { + return fmt.Errorf("usage_cleanup.max_range_days must be positive") + } + if c.UsageCleanup.BatchSize <= 0 { + return fmt.Errorf("usage_cleanup.batch_size must be positive") + } + if c.UsageCleanup.WorkerIntervalSeconds <= 0 { + return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive") + } + if c.UsageCleanup.TaskTimeoutSeconds <= 0 { + return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive") + } + } else { + if c.UsageCleanup.MaxRangeDays < 0 { + return fmt.Errorf("usage_cleanup.max_range_days must be non-negative") + } + if c.UsageCleanup.BatchSize < 0 { + return fmt.Errorf("usage_cleanup.batch_size must be non-negative") + } + if c.UsageCleanup.WorkerIntervalSeconds < 0 { + return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative") + } + if c.UsageCleanup.TaskTimeoutSeconds < 0 { + return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") + } + } if c.Gateway.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } diff --git a/backend/internal/config/config_test.go b/backend/internal/config/config_test.go index 4637989e..f734619f 100644 --- a/backend/internal/config/config_test.go +++ b/backend/internal/config/config_test.go @@ -280,3 +280,573 @@ func TestValidateDashboardAggregationBackfillMaxDays(t *testing.T) { t.Fatalf("Validate() expected backfill_max_days error, got: %v", err) } } + +func TestLoadDefaultUsageCleanupConfig(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if !cfg.UsageCleanup.Enabled { + t.Fatalf("UsageCleanup.Enabled = false, want true") + } + if cfg.UsageCleanup.MaxRangeDays != 31 { + t.Fatalf("UsageCleanup.MaxRangeDays = %d, want 31", cfg.UsageCleanup.MaxRangeDays) + } + if cfg.UsageCleanup.BatchSize != 5000 { + t.Fatalf("UsageCleanup.BatchSize = %d, want 5000", cfg.UsageCleanup.BatchSize) + } + if cfg.UsageCleanup.WorkerIntervalSeconds != 10 { + t.Fatalf("UsageCleanup.WorkerIntervalSeconds = %d, want 10", cfg.UsageCleanup.WorkerIntervalSeconds) + } + if cfg.UsageCleanup.TaskTimeoutSeconds != 1800 { + t.Fatalf("UsageCleanup.TaskTimeoutSeconds = %d, want 1800", cfg.UsageCleanup.TaskTimeoutSeconds) + } +} + +func TestValidateUsageCleanupConfigEnabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.UsageCleanup.Enabled = true + cfg.UsageCleanup.MaxRangeDays = 0 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for usage_cleanup.max_range_days, got nil") + } + if !strings.Contains(err.Error(), "usage_cleanup.max_range_days") { + t.Fatalf("Validate() expected max_range_days error, got: 
%v", err) + } +} + +func TestValidateUsageCleanupConfigDisabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.UsageCleanup.Enabled = false + cfg.UsageCleanup.BatchSize = -1 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for usage_cleanup.batch_size, got nil") + } + if !strings.Contains(err.Error(), "usage_cleanup.batch_size") { + t.Fatalf("Validate() expected batch_size error, got: %v", err) + } +} + +func TestConfigAddressHelpers(t *testing.T) { + server := ServerConfig{Host: "127.0.0.1", Port: 9000} + if server.Address() != "127.0.0.1:9000" { + t.Fatalf("ServerConfig.Address() = %q", server.Address()) + } + + dbCfg := DatabaseConfig{ + Host: "localhost", + Port: 5432, + User: "postgres", + Password: "", + DBName: "sub2api", + SSLMode: "disable", + } + if !strings.Contains(dbCfg.DSN(), "password=") { + } else { + t.Fatalf("DatabaseConfig.DSN() should not include password when empty") + } + + dbCfg.Password = "secret" + if !strings.Contains(dbCfg.DSN(), "password=secret") { + t.Fatalf("DatabaseConfig.DSN() missing password") + } + + dbCfg.Password = "" + if strings.Contains(dbCfg.DSNWithTimezone("UTC"), "password=") { + t.Fatalf("DatabaseConfig.DSNWithTimezone() should omit password when empty") + } + + if !strings.Contains(dbCfg.DSNWithTimezone(""), "TimeZone=Asia/Shanghai") { + t.Fatalf("DatabaseConfig.DSNWithTimezone() should use default timezone") + } + if !strings.Contains(dbCfg.DSNWithTimezone("UTC"), "TimeZone=UTC") { + t.Fatalf("DatabaseConfig.DSNWithTimezone() should use provided timezone") + } + + redis := RedisConfig{Host: "redis", Port: 6379} + if redis.Address() != "redis:6379" { + t.Fatalf("RedisConfig.Address() = %q", redis.Address()) + } +} + +func TestNormalizeStringSlice(t *testing.T) { + values := normalizeStringSlice([]string{" a ", "", "b", " ", "c"}) + if len(values) != 3 || values[0] != "a" || values[1] != "b" || values[2] != "c" { + t.Fatalf("normalizeStringSlice() unexpected result: %#v", values) + } + if normalizeStringSlice(nil) != nil { + t.Fatalf("normalizeStringSlice(nil) expected nil slice") + } +} + +func TestGetServerAddressFromEnv(t *testing.T) { + t.Setenv("SERVER_HOST", "127.0.0.1") + t.Setenv("SERVER_PORT", "9090") + + address := GetServerAddress() + if address != "127.0.0.1:9090" { + t.Fatalf("GetServerAddress() = %q", address) + } +} + +func TestValidateAbsoluteHTTPURL(t *testing.T) { + if err := ValidateAbsoluteHTTPURL("https://example.com/path"); err != nil { + t.Fatalf("ValidateAbsoluteHTTPURL valid url error: %v", err) + } + if err := ValidateAbsoluteHTTPURL(""); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject empty url") + } + if err := ValidateAbsoluteHTTPURL("/relative"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject relative url") + } + if err := ValidateAbsoluteHTTPURL("ftp://example.com"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject ftp scheme") + } + if err := ValidateAbsoluteHTTPURL("https://example.com/#frag"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject fragment") + } +} + +func TestValidateFrontendRedirectURL(t *testing.T) { + if err := ValidateFrontendRedirectURL("/auth/callback"); err != nil { + t.Fatalf("ValidateFrontendRedirectURL relative error: %v", err) + } + if err := ValidateFrontendRedirectURL("https://example.com/auth"); err != nil { + t.Fatalf("ValidateFrontendRedirectURL absolute error: %v", err) + } + if err := 
ValidateFrontendRedirectURL("example.com/path"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject non-absolute url") + } + if err := ValidateFrontendRedirectURL("//evil.com"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject // prefix") + } + if err := ValidateFrontendRedirectURL("javascript:alert(1)"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject javascript scheme") + } +} + +func TestWarnIfInsecureURL(t *testing.T) { + warnIfInsecureURL("test", "http://example.com") + warnIfInsecureURL("test", "bad://url") +} + +func TestGenerateJWTSecretDefaultLength(t *testing.T) { + secret, err := generateJWTSecret(0) + if err != nil { + t.Fatalf("generateJWTSecret error: %v", err) + } + if len(secret) == 0 { + t.Fatalf("generateJWTSecret returned empty string") + } +} + +func TestValidateOpsCleanupScheduleRequired(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + cfg.Ops.Cleanup.Enabled = true + cfg.Ops.Cleanup.Schedule = "" + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for ops.cleanup.schedule") + } + if !strings.Contains(err.Error(), "ops.cleanup.schedule") { + t.Fatalf("Validate() expected ops.cleanup.schedule error, got: %v", err) + } +} + +func TestValidateConcurrencyPingInterval(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + cfg.Concurrency.PingInterval = 3 + err = cfg.Validate() + if err == nil { + t.Fatalf("Validate() expected error for concurrency.ping_interval") + } + if !strings.Contains(err.Error(), "concurrency.ping_interval") { + t.Fatalf("Validate() expected concurrency.ping_interval error, got: %v", err) + } +} + +func TestProvideConfig(t *testing.T) { + viper.Reset() + if _, err := ProvideConfig(); err != nil { + t.Fatalf("ProvideConfig() error: %v", err) + } +} + +func TestValidateConfigWithLinuxDoEnabled(t *testing.T) { + viper.Reset() + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + cfg.Security.CSP.Enabled = true + cfg.Security.CSP.Policy = "default-src 'self'" + + cfg.LinuxDo.Enabled = true + cfg.LinuxDo.ClientID = "client" + cfg.LinuxDo.ClientSecret = "secret" + cfg.LinuxDo.AuthorizeURL = "https://example.com/oauth2/authorize" + cfg.LinuxDo.TokenURL = "https://example.com/oauth2/token" + cfg.LinuxDo.UserInfoURL = "https://example.com/oauth2/userinfo" + cfg.LinuxDo.RedirectURL = "https://example.com/api/v1/auth/oauth/linuxdo/callback" + cfg.LinuxDo.FrontendRedirectURL = "/auth/linuxdo/callback" + cfg.LinuxDo.TokenAuthMethod = "client_secret_post" + + if err := cfg.Validate(); err != nil { + t.Fatalf("Validate() unexpected error: %v", err) + } +} + +func TestValidateJWTSecretStrength(t *testing.T) { + if !isWeakJWTSecret("change-me-in-production") { + t.Fatalf("isWeakJWTSecret should detect weak secret") + } + if isWeakJWTSecret("StrongSecretValue") { + t.Fatalf("isWeakJWTSecret should accept strong secret") + } +} + +func TestGenerateJWTSecretWithLength(t *testing.T) { + secret, err := generateJWTSecret(16) + if err != nil { + t.Fatalf("generateJWTSecret error: %v", err) + } + if len(secret) == 0 { + t.Fatalf("generateJWTSecret returned empty string") + } +} + +func TestValidateAbsoluteHTTPURLMissingHost(t *testing.T) { + if err := ValidateAbsoluteHTTPURL("https://"); err == nil { + t.Fatalf("ValidateAbsoluteHTTPURL should reject missing host") + } +} + +func TestValidateFrontendRedirectURLInvalidChars(t *testing.T) { + if 
err := ValidateFrontendRedirectURL("/auth/\ncallback"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject invalid chars") + } + if err := ValidateFrontendRedirectURL("http://"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject missing host") + } + if err := ValidateFrontendRedirectURL("mailto:user@example.com"); err == nil { + t.Fatalf("ValidateFrontendRedirectURL should reject mailto") + } +} + +func TestWarnIfInsecureURLHTTPS(t *testing.T) { + warnIfInsecureURL("secure", "https://example.com") +} + +func TestValidateConfigErrors(t *testing.T) { + buildValid := func(t *testing.T) *Config { + t.Helper() + viper.Reset() + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + return cfg + } + + cases := []struct { + name string + mutate func(*Config) + wantErr string + }{ + { + name: "jwt expire hour positive", + mutate: func(c *Config) { c.JWT.ExpireHour = 0 }, + wantErr: "jwt.expire_hour must be positive", + }, + { + name: "jwt expire hour max", + mutate: func(c *Config) { c.JWT.ExpireHour = 200 }, + wantErr: "jwt.expire_hour must be <= 168", + }, + { + name: "csp policy required", + mutate: func(c *Config) { c.Security.CSP.Enabled = true; c.Security.CSP.Policy = "" }, + wantErr: "security.csp.policy", + }, + { + name: "linuxdo client id required", + mutate: func(c *Config) { + c.LinuxDo.Enabled = true + c.LinuxDo.ClientID = "" + }, + wantErr: "linuxdo_connect.client_id", + }, + { + name: "linuxdo token auth method", + mutate: func(c *Config) { + c.LinuxDo.Enabled = true + c.LinuxDo.ClientID = "client" + c.LinuxDo.ClientSecret = "secret" + c.LinuxDo.AuthorizeURL = "https://example.com/authorize" + c.LinuxDo.TokenURL = "https://example.com/token" + c.LinuxDo.UserInfoURL = "https://example.com/userinfo" + c.LinuxDo.RedirectURL = "https://example.com/callback" + c.LinuxDo.FrontendRedirectURL = "/auth/callback" + c.LinuxDo.TokenAuthMethod = "invalid" + }, + wantErr: "linuxdo_connect.token_auth_method", + }, + { + name: "billing circuit breaker threshold", + mutate: func(c *Config) { c.Billing.CircuitBreaker.FailureThreshold = 0 }, + wantErr: "billing.circuit_breaker.failure_threshold", + }, + { + name: "billing circuit breaker reset", + mutate: func(c *Config) { c.Billing.CircuitBreaker.ResetTimeoutSeconds = 0 }, + wantErr: "billing.circuit_breaker.reset_timeout_seconds", + }, + { + name: "billing circuit breaker half open", + mutate: func(c *Config) { c.Billing.CircuitBreaker.HalfOpenRequests = 0 }, + wantErr: "billing.circuit_breaker.half_open_requests", + }, + { + name: "database max open conns", + mutate: func(c *Config) { c.Database.MaxOpenConns = 0 }, + wantErr: "database.max_open_conns", + }, + { + name: "database max lifetime", + mutate: func(c *Config) { c.Database.ConnMaxLifetimeMinutes = -1 }, + wantErr: "database.conn_max_lifetime_minutes", + }, + { + name: "database idle exceeds open", + mutate: func(c *Config) { c.Database.MaxIdleConns = c.Database.MaxOpenConns + 1 }, + wantErr: "database.max_idle_conns cannot exceed", + }, + { + name: "redis dial timeout", + mutate: func(c *Config) { c.Redis.DialTimeoutSeconds = 0 }, + wantErr: "redis.dial_timeout_seconds", + }, + { + name: "redis read timeout", + mutate: func(c *Config) { c.Redis.ReadTimeoutSeconds = 0 }, + wantErr: "redis.read_timeout_seconds", + }, + { + name: "redis write timeout", + mutate: func(c *Config) { c.Redis.WriteTimeoutSeconds = 0 }, + wantErr: "redis.write_timeout_seconds", + }, + { + name: "redis pool size", + mutate: func(c *Config) { c.Redis.PoolSize = 
0 }, + wantErr: "redis.pool_size", + }, + { + name: "redis idle exceeds pool", + mutate: func(c *Config) { c.Redis.MinIdleConns = c.Redis.PoolSize + 1 }, + wantErr: "redis.min_idle_conns cannot exceed", + }, + { + name: "dashboard cache disabled negative", + mutate: func(c *Config) { c.Dashboard.Enabled = false; c.Dashboard.StatsTTLSeconds = -1 }, + wantErr: "dashboard_cache.stats_ttl_seconds", + }, + { + name: "dashboard cache fresh ttl positive", + mutate: func(c *Config) { c.Dashboard.Enabled = true; c.Dashboard.StatsFreshTTLSeconds = 0 }, + wantErr: "dashboard_cache.stats_fresh_ttl_seconds", + }, + { + name: "dashboard aggregation enabled interval", + mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.IntervalSeconds = 0 }, + wantErr: "dashboard_aggregation.interval_seconds", + }, + { + name: "dashboard aggregation backfill positive", + mutate: func(c *Config) { + c.DashboardAgg.Enabled = true + c.DashboardAgg.BackfillEnabled = true + c.DashboardAgg.BackfillMaxDays = 0 + }, + wantErr: "dashboard_aggregation.backfill_max_days", + }, + { + name: "dashboard aggregation retention", + mutate: func(c *Config) { c.DashboardAgg.Enabled = true; c.DashboardAgg.Retention.UsageLogsDays = 0 }, + wantErr: "dashboard_aggregation.retention.usage_logs_days", + }, + { + name: "dashboard aggregation disabled interval", + mutate: func(c *Config) { c.DashboardAgg.Enabled = false; c.DashboardAgg.IntervalSeconds = -1 }, + wantErr: "dashboard_aggregation.interval_seconds", + }, + { + name: "usage cleanup max range", + mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.MaxRangeDays = 0 }, + wantErr: "usage_cleanup.max_range_days", + }, + { + name: "usage cleanup worker interval", + mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.WorkerIntervalSeconds = 0 }, + wantErr: "usage_cleanup.worker_interval_seconds", + }, + { + name: "usage cleanup batch size", + mutate: func(c *Config) { c.UsageCleanup.Enabled = true; c.UsageCleanup.BatchSize = 0 }, + wantErr: "usage_cleanup.batch_size", + }, + { + name: "usage cleanup disabled negative", + mutate: func(c *Config) { c.UsageCleanup.Enabled = false; c.UsageCleanup.BatchSize = -1 }, + wantErr: "usage_cleanup.batch_size", + }, + { + name: "gateway max body size", + mutate: func(c *Config) { c.Gateway.MaxBodySize = 0 }, + wantErr: "gateway.max_body_size", + }, + { + name: "gateway max idle conns", + mutate: func(c *Config) { c.Gateway.MaxIdleConns = 0 }, + wantErr: "gateway.max_idle_conns", + }, + { + name: "gateway max idle conns per host", + mutate: func(c *Config) { c.Gateway.MaxIdleConnsPerHost = 0 }, + wantErr: "gateway.max_idle_conns_per_host", + }, + { + name: "gateway idle timeout", + mutate: func(c *Config) { c.Gateway.IdleConnTimeoutSeconds = 0 }, + wantErr: "gateway.idle_conn_timeout_seconds", + }, + { + name: "gateway max upstream clients", + mutate: func(c *Config) { c.Gateway.MaxUpstreamClients = 0 }, + wantErr: "gateway.max_upstream_clients", + }, + { + name: "gateway client idle ttl", + mutate: func(c *Config) { c.Gateway.ClientIdleTTLSeconds = 0 }, + wantErr: "gateway.client_idle_ttl_seconds", + }, + { + name: "gateway concurrency slot ttl", + mutate: func(c *Config) { c.Gateway.ConcurrencySlotTTLMinutes = 0 }, + wantErr: "gateway.concurrency_slot_ttl_minutes", + }, + { + name: "gateway max conns per host", + mutate: func(c *Config) { c.Gateway.MaxConnsPerHost = -1 }, + wantErr: "gateway.max_conns_per_host", + }, + { + name: "gateway connection isolation", + mutate: func(c *Config) { 
c.Gateway.ConnectionPoolIsolation = "invalid" }, + wantErr: "gateway.connection_pool_isolation", + }, + { + name: "gateway stream keepalive range", + mutate: func(c *Config) { c.Gateway.StreamKeepaliveInterval = 4 }, + wantErr: "gateway.stream_keepalive_interval", + }, + { + name: "gateway stream data interval range", + mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = 5 }, + wantErr: "gateway.stream_data_interval_timeout", + }, + { + name: "gateway stream data interval negative", + mutate: func(c *Config) { c.Gateway.StreamDataIntervalTimeout = -1 }, + wantErr: "gateway.stream_data_interval_timeout must be non-negative", + }, + { + name: "gateway max line size", + mutate: func(c *Config) { c.Gateway.MaxLineSize = 1024 }, + wantErr: "gateway.max_line_size must be at least", + }, + { + name: "gateway max line size negative", + mutate: func(c *Config) { c.Gateway.MaxLineSize = -1 }, + wantErr: "gateway.max_line_size must be non-negative", + }, + { + name: "gateway scheduling sticky waiting", + mutate: func(c *Config) { c.Gateway.Scheduling.StickySessionMaxWaiting = 0 }, + wantErr: "gateway.scheduling.sticky_session_max_waiting", + }, + { + name: "gateway scheduling outbox poll", + mutate: func(c *Config) { c.Gateway.Scheduling.OutboxPollIntervalSeconds = 0 }, + wantErr: "gateway.scheduling.outbox_poll_interval_seconds", + }, + { + name: "gateway scheduling outbox failures", + mutate: func(c *Config) { c.Gateway.Scheduling.OutboxLagRebuildFailures = 0 }, + wantErr: "gateway.scheduling.outbox_lag_rebuild_failures", + }, + { + name: "gateway outbox lag rebuild", + mutate: func(c *Config) { + c.Gateway.Scheduling.OutboxLagWarnSeconds = 10 + c.Gateway.Scheduling.OutboxLagRebuildSeconds = 5 + }, + wantErr: "gateway.scheduling.outbox_lag_rebuild_seconds", + }, + { + name: "ops metrics collector ttl", + mutate: func(c *Config) { c.Ops.MetricsCollectorCache.TTL = -1 }, + wantErr: "ops.metrics_collector_cache.ttl", + }, + { + name: "ops cleanup retention", + mutate: func(c *Config) { c.Ops.Cleanup.ErrorLogRetentionDays = -1 }, + wantErr: "ops.cleanup.error_log_retention_days", + }, + { + name: "ops cleanup minute retention", + mutate: func(c *Config) { c.Ops.Cleanup.MinuteMetricsRetentionDays = -1 }, + wantErr: "ops.cleanup.minute_metrics_retention_days", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + cfg := buildValid(t) + tt.mutate(cfg) + err := cfg.Validate() + if err == nil || !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("Validate() error = %v, want %q", err, tt.wantErr) + } + }) + } +} diff --git a/backend/internal/handler/admin/account_handler.go b/backend/internal/handler/admin/account_handler.go index 33c91dae..9cc2540d 100644 --- a/backend/internal/handler/admin/account_handler.go +++ b/backend/internal/handler/admin/account_handler.go @@ -45,6 +45,7 @@ type AccountHandler struct { concurrencyService *service.ConcurrencyService crsSyncService *service.CRSSyncService sessionLimitCache service.SessionLimitCache + tokenCacheInvalidator service.TokenCacheInvalidator } // NewAccountHandler creates a new admin account handler @@ -60,6 +61,7 @@ func NewAccountHandler( concurrencyService *service.ConcurrencyService, crsSyncService *service.CRSSyncService, sessionLimitCache service.SessionLimitCache, + tokenCacheInvalidator service.TokenCacheInvalidator, ) *AccountHandler { return &AccountHandler{ adminService: adminService, @@ -73,6 +75,7 @@ func NewAccountHandler( concurrencyService: concurrencyService, crsSyncService: crsSyncService, 
 		sessionLimitCache:     sessionLimitCache,
+		tokenCacheInvalidator: tokenCacheInvalidator,
 	}
 }
 
@@ -173,6 +176,7 @@ func (h *AccountHandler) List(c *gin.Context) {
 	// Identify accounts that need window-cost and session-count lookups (Anthropic OAuth/SetupToken with the corresponding features enabled)
 	windowCostAccountIDs := make([]int64, 0)
 	sessionLimitAccountIDs := make([]int64, 0)
+	sessionIdleTimeouts := make(map[int64]time.Duration) // per-account session idle timeout settings
 	for i := range accounts {
 		acc := &accounts[i]
 		if acc.IsAnthropicOAuthOrSetupToken() {
@@ -181,6 +185,7 @@
 			}
 			if acc.GetMaxSessions() > 0 {
 				sessionLimitAccountIDs = append(sessionLimitAccountIDs, acc.ID)
+				sessionIdleTimeouts[acc.ID] = time.Duration(acc.GetSessionIdleTimeoutMinutes()) * time.Minute
 			}
 		}
 	}
@@ -189,9 +194,9 @@
 	var windowCosts map[int64]float64
 	var activeSessions map[int64]int
 
-	// Fetch active session counts (batch query)
+	// Fetch active session counts (batch query, passing each account's idleTimeout setting)
 	if len(sessionLimitAccountIDs) > 0 && h.sessionLimitCache != nil {
-		activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs)
+		activeSessions, _ = h.sessionLimitCache.GetActiveSessionCountBatch(c.Request.Context(), sessionLimitAccountIDs, sessionIdleTimeouts)
 		if activeSessions == nil {
 			activeSessions = make(map[int64]int)
 		}
@@ -211,12 +216,8 @@
 		}
 		accCopy := acc // captured by the closure
 		g.Go(func() error {
-			var startTime time.Time
-			if accCopy.SessionWindowStart != nil {
-				startTime = *accCopy.SessionWindowStart
-			} else {
-				startTime = time.Now().Add(-5 * time.Hour)
-			}
+			// Use the unified window start time calculation (handles expired windows)
+			startTime := accCopy.GetCurrentWindowStartTime()
 			stats, err := h.accountUsageService.GetAccountWindowStats(gctx, accCopy.ID, startTime)
 			if err == nil && stats != nil {
 				mu.Lock()
@@ -545,6 +546,36 @@ func (h *AccountHandler) Refresh(c *gin.Context) {
 				newCredentials[k] = v
 			}
 		}
+
+		// If fetching project_id failed, update the credentials first, then mark the account as error
+		if tokenInfo.ProjectIDMissing {
+			// Update the credentials first
+			_, updateErr := h.adminService.UpdateAccount(c.Request.Context(), accountID, &service.UpdateAccountInput{
+				Credentials: newCredentials,
+			})
+			if updateErr != nil {
+				response.InternalError(c, "Failed to update credentials: "+updateErr.Error())
+				return
+			}
+			// Mark the account as error
+			if setErr := h.adminService.SetAccountError(c.Request.Context(), accountID, "missing_project_id: the account has no project id and may not work with Antigravity"); setErr != nil {
+				response.InternalError(c, "Failed to set account error: "+setErr.Error())
+				return
+			}
+			response.Success(c, gin.H{
+				"message": "Token refreshed but project_id is missing, account marked as error",
+				"warning": "missing_project_id",
+			})
+			return
+		}
+
+		// project_id was obtained successfully; clear a previous missing_project_id error if one was set
+		if account.Status == service.StatusError && strings.Contains(account.ErrorMessage, "missing_project_id:") {
+			if _, clearErr := h.adminService.ClearAccountError(c.Request.Context(), accountID); clearErr != nil {
+				response.InternalError(c, "Failed to clear account error: "+clearErr.Error())
+				return
+			}
+		}
 	} else {
 		// Use Anthropic/Claude OAuth service to refresh token
 		tokenInfo, err := h.oauthService.RefreshAccountToken(c.Request.Context(), account)
@@ -580,6 +611,14 @@
 		return
 	}
 
+	// After a successful refresh, invalidate the token cache so the next request uses the new token
+	if h.tokenCacheInvalidator != nil {
+		if invalidateErr := h.tokenCacheInvalidator.InvalidateToken(c.Request.Context(), updatedAccount); invalidateErr != nil {
+			// A failed cache invalidation is only logged and does not affect the main flow
+			_ = c.Error(invalidateErr)
+		}
+	}
+
 	response.Success(c, 
dto.AccountFromService(updatedAccount)) } diff --git a/backend/internal/handler/admin/admin_basic_handlers_test.go b/backend/internal/handler/admin/admin_basic_handlers_test.go new file mode 100644 index 00000000..e0f731e1 --- /dev/null +++ b/backend/internal/handler/admin/admin_basic_handlers_test.go @@ -0,0 +1,262 @@ +package admin + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func setupAdminRouter() (*gin.Engine, *stubAdminService) { + gin.SetMode(gin.TestMode) + router := gin.New() + adminSvc := newStubAdminService() + + userHandler := NewUserHandler(adminSvc) + groupHandler := NewGroupHandler(adminSvc) + proxyHandler := NewProxyHandler(adminSvc) + redeemHandler := NewRedeemHandler(adminSvc) + + router.GET("/api/v1/admin/users", userHandler.List) + router.GET("/api/v1/admin/users/:id", userHandler.GetByID) + router.POST("/api/v1/admin/users", userHandler.Create) + router.PUT("/api/v1/admin/users/:id", userHandler.Update) + router.DELETE("/api/v1/admin/users/:id", userHandler.Delete) + router.POST("/api/v1/admin/users/:id/balance", userHandler.UpdateBalance) + router.GET("/api/v1/admin/users/:id/api-keys", userHandler.GetUserAPIKeys) + router.GET("/api/v1/admin/users/:id/usage", userHandler.GetUserUsage) + + router.GET("/api/v1/admin/groups", groupHandler.List) + router.GET("/api/v1/admin/groups/all", groupHandler.GetAll) + router.GET("/api/v1/admin/groups/:id", groupHandler.GetByID) + router.POST("/api/v1/admin/groups", groupHandler.Create) + router.PUT("/api/v1/admin/groups/:id", groupHandler.Update) + router.DELETE("/api/v1/admin/groups/:id", groupHandler.Delete) + router.GET("/api/v1/admin/groups/:id/stats", groupHandler.GetStats) + router.GET("/api/v1/admin/groups/:id/api-keys", groupHandler.GetGroupAPIKeys) + + router.GET("/api/v1/admin/proxies", proxyHandler.List) + router.GET("/api/v1/admin/proxies/all", proxyHandler.GetAll) + router.GET("/api/v1/admin/proxies/:id", proxyHandler.GetByID) + router.POST("/api/v1/admin/proxies", proxyHandler.Create) + router.PUT("/api/v1/admin/proxies/:id", proxyHandler.Update) + router.DELETE("/api/v1/admin/proxies/:id", proxyHandler.Delete) + router.POST("/api/v1/admin/proxies/batch-delete", proxyHandler.BatchDelete) + router.POST("/api/v1/admin/proxies/:id/test", proxyHandler.Test) + router.GET("/api/v1/admin/proxies/:id/stats", proxyHandler.GetStats) + router.GET("/api/v1/admin/proxies/:id/accounts", proxyHandler.GetProxyAccounts) + + router.GET("/api/v1/admin/redeem-codes", redeemHandler.List) + router.GET("/api/v1/admin/redeem-codes/:id", redeemHandler.GetByID) + router.POST("/api/v1/admin/redeem-codes", redeemHandler.Generate) + router.DELETE("/api/v1/admin/redeem-codes/:id", redeemHandler.Delete) + router.POST("/api/v1/admin/redeem-codes/batch-delete", redeemHandler.BatchDelete) + router.POST("/api/v1/admin/redeem-codes/:id/expire", redeemHandler.Expire) + router.GET("/api/v1/admin/redeem-codes/:id/stats", redeemHandler.GetStats) + + return router, adminSvc +} + +func TestUserHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/users?page=1&page_size=20", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + 
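Stepping outside the diff for a moment: the Refresh changes in account_handler.go above encode an ordering worth noting: persist the refreshed credentials first, then invalidate the cached token, and treat invalidation failure as non-fatal. Below is a compact, self-contained sketch of that pattern. TokenCacheInvalidator matches the call shape used in the hunk; Account and the persist callback are hypothetical stand-ins, not the real service types.

package main

import (
	"context"
	"log"
)

// Account is a stand-in for the service-layer account type.
type Account struct{ ID int64 }

// TokenCacheInvalidator mirrors the interface used by the handler:
// InvalidateToken drops any cached access token for the account.
type TokenCacheInvalidator interface {
	InvalidateToken(ctx context.Context, account *Account) error
}

// refreshThenInvalidate persists new credentials before touching the cache.
// If invalidation fails we only log: the worst case is one more request served
// from the soon-to-expire cached token, whereas failing the whole refresh
// would wrongly suggest the new credentials were not saved.
func refreshThenInvalidate(ctx context.Context, acc *Account, persist func(context.Context, *Account) error, inv TokenCacheInvalidator) error {
	if err := persist(ctx, acc); err != nil {
		return err
	}
	if inv != nil { // the handler also guards against a nil invalidator
		if err := inv.InvalidateToken(ctx, acc); err != nil {
			log.Printf("invalidate token cache for account %d: %v", acc.ID, err)
		}
	}
	return nil
}

func main() {
	acc := &Account{ID: 1}
	if err := refreshThenInvalidate(context.Background(), acc,
		func(context.Context, *Account) error { return nil }, // no-op persist for the demo
		nil,
	); err != nil {
		log.Fatal(err)
	}
}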
createBody := map[string]any{"email": "new@example.com", "password": "pass123", "balance": 1, "concurrency": 2} + body, _ := json.Marshal(createBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + updateBody := map[string]any{"email": "updated@example.com"} + body, _ = json.Marshal(updateBody) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/users/1", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/users/1", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/users/1/balance", bytes.NewBufferString(`{"balance":1,"operation":"add"}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/api-keys", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/users/1/usage?period=today", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestGroupHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/all", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ := json.Marshal(map[string]any{"name": "new", "platform": "anthropic", "subscription_type": "standard"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/groups", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ = json.Marshal(map[string]any{"name": "update"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/groups/2", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/groups/2", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/stats", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/groups/2/api-keys", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestProxyHandlerEndpoints(t *testing.T) { + router, _ := 
setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/all", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ := json.Marshal(map[string]any{"name": "proxy", "protocol": "http", "host": "localhost", "port": 8080}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ = json.Marshal(map[string]any{"name": "proxy2"}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPut, "/api/v1/admin/proxies/4", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/proxies/4", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/proxies/4/test", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/stats", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/proxies/4/accounts", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} + +func TestRedeemHandlerEndpoints(t *testing.T) { + router, _ := setupAdminRouter() + + rec := httptest.NewRecorder() + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + body, _ := json.Marshal(map[string]any{"count": 1, "type": "balance", "value": 10}) + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodDelete, "/api/v1/admin/redeem-codes/5", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/batch-delete", bytes.NewBufferString(`{"ids":[1,2]}`)) + req.Header.Set("Content-Type", "application/json") + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, 
rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodPost, "/api/v1/admin/redeem-codes/5/expire", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) + + rec = httptest.NewRecorder() + req = httptest.NewRequest(http.MethodGet, "/api/v1/admin/redeem-codes/5/stats", nil) + router.ServeHTTP(rec, req) + require.Equal(t, http.StatusOK, rec.Code) +} diff --git a/backend/internal/handler/admin/admin_helpers_test.go b/backend/internal/handler/admin/admin_helpers_test.go new file mode 100644 index 00000000..863c755c --- /dev/null +++ b/backend/internal/handler/admin/admin_helpers_test.go @@ -0,0 +1,134 @@ +package admin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestParseTimeRange(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + req := httptest.NewRequest(http.MethodGet, "/?start_date=2024-01-01&end_date=2024-01-02&timezone=UTC", nil) + c.Request = req + + start, end := parseTimeRange(c) + require.Equal(t, time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), start) + require.Equal(t, time.Date(2024, 1, 3, 0, 0, 0, 0, time.UTC), end) + + req = httptest.NewRequest(http.MethodGet, "/?start_date=bad&timezone=UTC", nil) + c.Request = req + start, end = parseTimeRange(c) + require.False(t, start.IsZero()) + require.False(t, end.IsZero()) +} + +func TestParseOpsViewParam(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?view=excluded", nil) + require.Equal(t, opsListViewExcluded, parseOpsViewParam(c)) + + c2, _ := gin.CreateTestContext(w) + c2.Request = httptest.NewRequest(http.MethodGet, "/?view=all", nil) + require.Equal(t, opsListViewAll, parseOpsViewParam(c2)) + + c3, _ := gin.CreateTestContext(w) + c3.Request = httptest.NewRequest(http.MethodGet, "/?view=unknown", nil) + require.Equal(t, opsListViewErrors, parseOpsViewParam(c3)) + + require.Equal(t, "", parseOpsViewParam(nil)) +} + +func TestParseOpsDuration(t *testing.T) { + dur, ok := parseOpsDuration("1h") + require.True(t, ok) + require.Equal(t, time.Hour, dur) + + _, ok = parseOpsDuration("invalid") + require.False(t, ok) +} + +func TestParseOpsTimeRange(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + now := time.Now().UTC() + startStr := now.Add(-time.Hour).Format(time.RFC3339) + endStr := now.Format(time.RFC3339) + c.Request = httptest.NewRequest(http.MethodGet, "/?start_time="+startStr+"&end_time="+endStr, nil) + start, end, err := parseOpsTimeRange(c, "1h") + require.NoError(t, err) + require.True(t, start.Before(end)) + + c2, _ := gin.CreateTestContext(w) + c2.Request = httptest.NewRequest(http.MethodGet, "/?start_time=bad", nil) + _, _, err = parseOpsTimeRange(c2, "1h") + require.Error(t, err) +} + +func TestParseOpsRealtimeWindow(t *testing.T) { + dur, label, ok := parseOpsRealtimeWindow("5m") + require.True(t, ok) + require.Equal(t, 5*time.Minute, dur) + require.Equal(t, "5min", label) + + _, _, ok = parseOpsRealtimeWindow("invalid") + require.False(t, ok) +} + +func TestPickThroughputBucketSeconds(t *testing.T) { + require.Equal(t, 60, pickThroughputBucketSeconds(30*time.Minute)) + require.Equal(t, 300, pickThroughputBucketSeconds(6*time.Hour)) + require.Equal(t, 3600, 
pickThroughputBucketSeconds(48*time.Hour)) +} + +func TestParseOpsQueryMode(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/?mode=raw", nil) + require.Equal(t, service.ParseOpsQueryMode("raw"), parseOpsQueryMode(c)) + require.Equal(t, service.OpsQueryMode(""), parseOpsQueryMode(nil)) +} + +func TestOpsAlertRuleValidation(t *testing.T) { + raw := map[string]json.RawMessage{ + "name": json.RawMessage(`"High error rate"`), + "metric_type": json.RawMessage(`"error_rate"`), + "operator": json.RawMessage(`">"`), + "threshold": json.RawMessage(`90`), + } + + validated, err := validateOpsAlertRulePayload(raw) + require.NoError(t, err) + require.Equal(t, "High error rate", validated.Name) + + _, err = validateOpsAlertRulePayload(map[string]json.RawMessage{}) + require.Error(t, err) + + require.True(t, isPercentOrRateMetric("error_rate")) + require.False(t, isPercentOrRateMetric("concurrency_queue_depth")) +} + +func TestOpsWSHelpers(t *testing.T) { + prefixes, invalid := parseTrustedProxyList("10.0.0.0/8,invalid") + require.Len(t, prefixes, 1) + require.Len(t, invalid, 1) + + host := hostWithoutPort("example.com:443") + require.Equal(t, "example.com", host) + + addr := netip.MustParseAddr("10.0.0.1") + require.True(t, isAddrInTrustedProxies(addr, prefixes)) + require.False(t, isAddrInTrustedProxies(netip.MustParseAddr("192.168.0.1"), prefixes)) +} diff --git a/backend/internal/handler/admin/admin_service_stub_test.go b/backend/internal/handler/admin/admin_service_stub_test.go new file mode 100644 index 00000000..b820a3fb --- /dev/null +++ b/backend/internal/handler/admin/admin_service_stub_test.go @@ -0,0 +1,294 @@ +package admin + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type stubAdminService struct { + users []service.User + apiKeys []service.APIKey + groups []service.Group + accounts []service.Account + proxies []service.Proxy + proxyCounts []service.ProxyWithAccountCount + redeems []service.RedeemCode +} + +func newStubAdminService() *stubAdminService { + now := time.Now().UTC() + user := service.User{ + ID: 1, + Email: "user@example.com", + Role: service.RoleUser, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + apiKey := service.APIKey{ + ID: 10, + UserID: user.ID, + Key: "sk-test", + Name: "test", + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + group := service.Group{ + ID: 2, + Name: "group", + Platform: service.PlatformAnthropic, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + account := service.Account{ + ID: 3, + Name: "account", + Platform: service.PlatformAnthropic, + Type: service.AccountTypeOAuth, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + proxy := service.Proxy{ + ID: 4, + Name: "proxy", + Protocol: "http", + Host: "127.0.0.1", + Port: 8080, + Status: service.StatusActive, + CreatedAt: now, + UpdatedAt: now, + } + redeem := service.RedeemCode{ + ID: 5, + Code: "R-TEST", + Type: service.RedeemTypeBalance, + Value: 10, + Status: service.StatusUnused, + CreatedAt: now, + } + return &stubAdminService{ + users: []service.User{user}, + apiKeys: []service.APIKey{apiKey}, + groups: []service.Group{group}, + accounts: []service.Account{account}, + proxies: []service.Proxy{proxy}, + proxyCounts: []service.ProxyWithAccountCount{{Proxy: proxy, AccountCount: 1}}, + redeems: []service.RedeemCode{redeem}, + } +} + +func (s 
*stubAdminService) ListUsers(ctx context.Context, page, pageSize int, filters service.UserListFilters) ([]service.User, int64, error) { + return s.users, int64(len(s.users)), nil +} + +func (s *stubAdminService) GetUser(ctx context.Context, id int64) (*service.User, error) { + for i := range s.users { + if s.users[i].ID == id { + return &s.users[i], nil + } + } + user := service.User{ID: id, Email: "user@example.com", Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) CreateUser(ctx context.Context, input *service.CreateUserInput) (*service.User, error) { + user := service.User{ID: 100, Email: input.Email, Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) UpdateUser(ctx context.Context, id int64, input *service.UpdateUserInput) (*service.User, error) { + user := service.User{ID: id, Email: "updated@example.com", Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) DeleteUser(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*service.User, error) { + user := service.User{ID: userID, Balance: balance, Status: service.StatusActive} + return &user, nil +} + +func (s *stubAdminService) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]service.APIKey, int64, error) { + return s.apiKeys, int64(len(s.apiKeys)), nil +} + +func (s *stubAdminService) GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error) { + return map[string]any{"user_id": userID}, nil +} + +func (s *stubAdminService) ListGroups(ctx context.Context, page, pageSize int, platform, status, search string, isExclusive *bool) ([]service.Group, int64, error) { + return s.groups, int64(len(s.groups)), nil +} + +func (s *stubAdminService) GetAllGroups(ctx context.Context) ([]service.Group, error) { + return s.groups, nil +} + +func (s *stubAdminService) GetAllGroupsByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + return s.groups, nil +} + +func (s *stubAdminService) GetGroup(ctx context.Context, id int64) (*service.Group, error) { + group := service.Group{ID: id, Name: "group", Status: service.StatusActive} + return &group, nil +} + +func (s *stubAdminService) CreateGroup(ctx context.Context, input *service.CreateGroupInput) (*service.Group, error) { + group := service.Group{ID: 200, Name: input.Name, Status: service.StatusActive} + return &group, nil +} + +func (s *stubAdminService) UpdateGroup(ctx context.Context, id int64, input *service.UpdateGroupInput) (*service.Group, error) { + group := service.Group{ID: id, Name: input.Name, Status: service.StatusActive} + return &group, nil +} + +func (s *stubAdminService) DeleteGroup(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]service.APIKey, int64, error) { + return s.apiKeys, int64(len(s.apiKeys)), nil +} + +func (s *stubAdminService) ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]service.Account, int64, error) { + return s.accounts, int64(len(s.accounts)), nil +} + +func (s *stubAdminService) GetAccount(ctx context.Context, id int64) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) GetAccountsByIDs(ctx 
context.Context, ids []int64) ([]*service.Account, error) { + out := make([]*service.Account, 0, len(ids)) + for _, id := range ids { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + out = append(out, &account) + } + return out, nil +} + +func (s *stubAdminService) CreateAccount(ctx context.Context, input *service.CreateAccountInput) (*service.Account, error) { + account := service.Account{ID: 300, Name: input.Name, Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) UpdateAccount(ctx context.Context, id int64, input *service.UpdateAccountInput) (*service.Account, error) { + account := service.Account{ID: id, Name: input.Name, Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) DeleteAccount(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) RefreshAccountCredentials(ctx context.Context, id int64) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) ClearAccountError(ctx context.Context, id int64) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive} + return &account, nil +} + +func (s *stubAdminService) SetAccountError(ctx context.Context, id int64, errorMsg string) error { + return nil +} + +func (s *stubAdminService) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*service.Account, error) { + account := service.Account{ID: id, Name: "account", Status: service.StatusActive, Schedulable: schedulable} + return &account, nil +} + +func (s *stubAdminService) BulkUpdateAccounts(ctx context.Context, input *service.BulkUpdateAccountsInput) (*service.BulkUpdateAccountsResult, error) { + return &service.BulkUpdateAccountsResult{Success: 1, Failed: 0, SuccessIDs: []int64{1}}, nil +} + +func (s *stubAdminService) ListProxies(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.Proxy, int64, error) { + return s.proxies, int64(len(s.proxies)), nil +} + +func (s *stubAdminService) ListProxiesWithAccountCount(ctx context.Context, page, pageSize int, protocol, status, search string) ([]service.ProxyWithAccountCount, int64, error) { + return s.proxyCounts, int64(len(s.proxyCounts)), nil +} + +func (s *stubAdminService) GetAllProxies(ctx context.Context) ([]service.Proxy, error) { + return s.proxies, nil +} + +func (s *stubAdminService) GetAllProxiesWithAccountCount(ctx context.Context) ([]service.ProxyWithAccountCount, error) { + return s.proxyCounts, nil +} + +func (s *stubAdminService) GetProxy(ctx context.Context, id int64) (*service.Proxy, error) { + proxy := service.Proxy{ID: id, Name: "proxy", Status: service.StatusActive} + return &proxy, nil +} + +func (s *stubAdminService) CreateProxy(ctx context.Context, input *service.CreateProxyInput) (*service.Proxy, error) { + proxy := service.Proxy{ID: 400, Name: input.Name, Status: service.StatusActive} + return &proxy, nil +} + +func (s *stubAdminService) UpdateProxy(ctx context.Context, id int64, input *service.UpdateProxyInput) (*service.Proxy, error) { + proxy := service.Proxy{ID: id, Name: input.Name, Status: service.StatusActive} + return &proxy, nil +} + +func (s *stubAdminService) DeleteProxy(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) BatchDeleteProxies(ctx context.Context, ids []int64) (*service.ProxyBatchDeleteResult, error) { + return 
&service.ProxyBatchDeleteResult{DeletedIDs: ids}, nil +} + +func (s *stubAdminService) GetProxyAccounts(ctx context.Context, proxyID int64) ([]service.ProxyAccountSummary, error) { + return []service.ProxyAccountSummary{{ID: 1, Name: "account"}}, nil +} + +func (s *stubAdminService) CheckProxyExists(ctx context.Context, host string, port int, username, password string) (bool, error) { + return false, nil +} + +func (s *stubAdminService) TestProxy(ctx context.Context, id int64) (*service.ProxyTestResult, error) { + return &service.ProxyTestResult{Success: true, Message: "ok"}, nil +} + +func (s *stubAdminService) ListRedeemCodes(ctx context.Context, page, pageSize int, codeType, status, search string) ([]service.RedeemCode, int64, error) { + return s.redeems, int64(len(s.redeems)), nil +} + +func (s *stubAdminService) GetRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) { + code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUnused} + return &code, nil +} + +func (s *stubAdminService) GenerateRedeemCodes(ctx context.Context, input *service.GenerateRedeemCodesInput) ([]service.RedeemCode, error) { + return s.redeems, nil +} + +func (s *stubAdminService) DeleteRedeemCode(ctx context.Context, id int64) error { + return nil +} + +func (s *stubAdminService) BatchDeleteRedeemCodes(ctx context.Context, ids []int64) (int64, error) { + return int64(len(ids)), nil +} + +func (s *stubAdminService) ExpireRedeemCode(ctx context.Context, id int64) (*service.RedeemCode, error) { + code := service.RedeemCode{ID: id, Code: "R-TEST", Status: service.StatusUsed} + return &code, nil +} + +// Ensure stub implements interface. +var _ service.AdminService = (*stubAdminService)(nil) diff --git a/backend/internal/handler/admin/dashboard_handler.go b/backend/internal/handler/admin/dashboard_handler.go index 3f07403d..18365186 100644 --- a/backend/internal/handler/admin/dashboard_handler.go +++ b/backend/internal/handler/admin/dashboard_handler.go @@ -186,7 +186,7 @@ func (h *DashboardHandler) GetRealtimeMetrics(c *gin.Context) { // GetUsageTrend handles getting usage trend data // GET /api/v1/admin/dashboard/trend -// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream +// Query params: start_date, end_date (YYYY-MM-DD), granularity (day/hour), user_id, api_key_id, model, account_id, group_id, stream, billing_type func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { startTime, endTime := parseTimeRange(c) granularity := c.DefaultQuery("granularity", "day") @@ -195,6 +195,7 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { var userID, apiKeyID, accountID, groupID int64 var model string var stream *bool + var billingType *int8 if userIDStr := c.Query("user_id"); userIDStr != "" { if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { @@ -224,8 +225,17 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { stream = &streamVal } } + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { + bt := int8(v) + billingType = &bt + } else { + response.BadRequest(c, "Invalid billing_type") + return + } + } - trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) + trend, err := h.dashboardService.GetUsageTrendWithFilters(c.Request.Context(), startTime, endTime, granularity, userID, 
apiKeyID, accountID, groupID, model, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get usage trend") return @@ -241,13 +251,14 @@ func (h *DashboardHandler) GetUsageTrend(c *gin.Context) { // GetModelStats handles getting model usage statistics // GET /api/v1/admin/dashboard/models -// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream +// Query params: start_date, end_date (YYYY-MM-DD), user_id, api_key_id, account_id, group_id, stream, billing_type func (h *DashboardHandler) GetModelStats(c *gin.Context) { startTime, endTime := parseTimeRange(c) // Parse optional filter params var userID, apiKeyID, accountID, groupID int64 var stream *bool + var billingType *int8 if userIDStr := c.Query("user_id"); userIDStr != "" { if id, err := strconv.ParseInt(userIDStr, 10, 64); err == nil { @@ -274,8 +285,17 @@ func (h *DashboardHandler) GetModelStats(c *gin.Context) { stream = &streamVal } } + if billingTypeStr := c.Query("billing_type"); billingTypeStr != "" { + if v, err := strconv.ParseInt(billingTypeStr, 10, 8); err == nil { + bt := int8(v) + billingType = &bt + } else { + response.BadRequest(c, "Invalid billing_type") + return + } + } - stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream) + stats, err := h.dashboardService.GetModelStatsWithFilters(c.Request.Context(), startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) if err != nil { response.Error(c, 500, "Failed to get model statistics") return diff --git a/backend/internal/handler/admin/group_handler.go b/backend/internal/handler/admin/group_handler.go index f6780dee..926624d2 100644 --- a/backend/internal/handler/admin/group_handler.go +++ b/backend/internal/handler/admin/group_handler.go @@ -94,9 +94,9 @@ func (h *GroupHandler) List(c *gin.Context) { return } - outGroups := make([]dto.Group, 0, len(groups)) + outGroups := make([]dto.AdminGroup, 0, len(groups)) for i := range groups { - outGroups = append(outGroups, *dto.GroupFromService(&groups[i])) + outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i])) } response.Paginated(c, outGroups, total, page, pageSize) } @@ -120,9 +120,9 @@ func (h *GroupHandler) GetAll(c *gin.Context) { return } - outGroups := make([]dto.Group, 0, len(groups)) + outGroups := make([]dto.AdminGroup, 0, len(groups)) for i := range groups { - outGroups = append(outGroups, *dto.GroupFromService(&groups[i])) + outGroups = append(outGroups, *dto.GroupFromServiceAdmin(&groups[i])) } response.Success(c, outGroups) } @@ -142,7 +142,7 @@ func (h *GroupHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.GroupFromService(group)) + response.Success(c, dto.GroupFromServiceAdmin(group)) } // Create handles creating a new group @@ -177,7 +177,7 @@ func (h *GroupHandler) Create(c *gin.Context) { return } - response.Success(c, dto.GroupFromService(group)) + response.Success(c, dto.GroupFromServiceAdmin(group)) } // Update handles updating a group @@ -219,7 +219,7 @@ func (h *GroupHandler) Update(c *gin.Context) { return } - response.Success(c, dto.GroupFromService(group)) + response.Success(c, dto.GroupFromServiceAdmin(group)) } // Delete handles deleting a group diff --git a/backend/internal/handler/admin/redeem_handler.go b/backend/internal/handler/admin/redeem_handler.go index 5b3229b6..f1b68334 100644 --- a/backend/internal/handler/admin/redeem_handler.go +++ 
b/backend/internal/handler/admin/redeem_handler.go @@ -54,9 +54,9 @@ func (h *RedeemHandler) List(c *gin.Context) { return } - out := make([]dto.RedeemCode, 0, len(codes)) + out := make([]dto.AdminRedeemCode, 0, len(codes)) for i := range codes { - out = append(out, *dto.RedeemCodeFromService(&codes[i])) + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) } response.Paginated(c, out, total, page, pageSize) } @@ -76,7 +76,7 @@ func (h *RedeemHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.RedeemCodeFromService(code)) + response.Success(c, dto.RedeemCodeFromServiceAdmin(code)) } // Generate handles generating new redeem codes @@ -100,9 +100,9 @@ func (h *RedeemHandler) Generate(c *gin.Context) { return } - out := make([]dto.RedeemCode, 0, len(codes)) + out := make([]dto.AdminRedeemCode, 0, len(codes)) for i := range codes { - out = append(out, *dto.RedeemCodeFromService(&codes[i])) + out = append(out, *dto.RedeemCodeFromServiceAdmin(&codes[i])) } response.Success(c, out) } @@ -163,7 +163,7 @@ func (h *RedeemHandler) Expire(c *gin.Context) { return } - response.Success(c, dto.RedeemCodeFromService(code)) + response.Success(c, dto.RedeemCodeFromServiceAdmin(code)) } // GetStats handles getting redeem code statistics diff --git a/backend/internal/handler/admin/setting_handler.go b/backend/internal/handler/admin/setting_handler.go index 6666ce4e..5a543d6c 100644 --- a/backend/internal/handler/admin/setting_handler.go +++ b/backend/internal/handler/admin/setting_handler.go @@ -68,6 +68,7 @@ func (h *SettingHandler) GetSettings(c *gin.Context) { ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, DefaultConcurrency: settings.DefaultConcurrency, DefaultBalance: settings.DefaultBalance, EnableModelFallback: settings.EnableModelFallback, @@ -111,13 +112,14 @@ type UpdateSettingsRequest struct { LinuxDoConnectRedirectURL string `json:"linuxdo_connect_redirect_url"` // OEM settings - SiteName string `json:"site_name"` - SiteLogo string `json:"site_logo"` - SiteSubtitle string `json:"site_subtitle"` - APIBaseURL string `json:"api_base_url"` - ContactInfo string `json:"contact_info"` - DocURL string `json:"doc_url"` - HomeContent string `json:"home_content"` + SiteName string `json:"site_name"` + SiteLogo string `json:"site_logo"` + SiteSubtitle string `json:"site_subtitle"` + APIBaseURL string `json:"api_base_url"` + ContactInfo string `json:"contact_info"` + DocURL string `json:"doc_url"` + HomeContent string `json:"home_content"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` // Default settings DefaultConcurrency int `json:"default_concurrency"` @@ -259,6 +261,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ContactInfo: req.ContactInfo, DocURL: req.DocURL, HomeContent: req.HomeContent, + HideCcsImportButton: req.HideCcsImportButton, DefaultConcurrency: req.DefaultConcurrency, DefaultBalance: req.DefaultBalance, EnableModelFallback: req.EnableModelFallback, @@ -332,6 +335,7 @@ func (h *SettingHandler) UpdateSettings(c *gin.Context) { ContactInfo: updatedSettings.ContactInfo, DocURL: updatedSettings.DocURL, HomeContent: updatedSettings.HomeContent, + HideCcsImportButton: updatedSettings.HideCcsImportButton, DefaultConcurrency: updatedSettings.DefaultConcurrency, DefaultBalance: updatedSettings.DefaultBalance, EnableModelFallback: updatedSettings.EnableModelFallback, @@ -439,6 +443,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings, 
if before.HomeContent != after.HomeContent { changed = append(changed, "home_content") } + if before.HideCcsImportButton != after.HideCcsImportButton { + changed = append(changed, "hide_ccs_import_button") + } if before.DefaultConcurrency != after.DefaultConcurrency { changed = append(changed, "default_concurrency") } diff --git a/backend/internal/handler/admin/subscription_handler.go b/backend/internal/handler/admin/subscription_handler.go index 08db999a..a0d1456f 100644 --- a/backend/internal/handler/admin/subscription_handler.go +++ b/backend/internal/handler/admin/subscription_handler.go @@ -53,9 +53,9 @@ type BulkAssignSubscriptionRequest struct { Notes string `json:"notes"` } -// ExtendSubscriptionRequest represents extend subscription request -type ExtendSubscriptionRequest struct { - Days int `json:"days" binding:"required,min=1,max=36500"` // max 100 years +// AdjustSubscriptionRequest represents a subscription adjustment request (extend or shorten) +type AdjustSubscriptionRequest struct { + Days int `json:"days" binding:"required,min=-36500,max=36500"` // negative to shorten, positive to extend } // List handles listing all subscriptions with pagination and filters @@ -83,9 +83,9 @@ func (h *SubscriptionHandler) List(c *gin.Context) { return } - out := make([]dto.UserSubscription, 0, len(subscriptions)) + out := make([]dto.AdminUserSubscription, 0, len(subscriptions)) for i := range subscriptions { - out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i])) } response.PaginatedWithResult(c, out, toResponsePagination(pagination)) } @@ -105,7 +105,7 @@ func (h *SubscriptionHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.UserSubscriptionFromService(subscription)) + response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) } // GetProgress handles getting subscription usage progress @@ -150,7 +150,7 @@ func (h *SubscriptionHandler) Assign(c *gin.Context) { return } - response.Success(c, dto.UserSubscriptionFromService(subscription)) + response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) } // BulkAssign handles bulk assigning subscriptions to multiple users @@ -180,7 +180,7 @@ func (h *SubscriptionHandler) BulkAssign(c *gin.Context) { response.Success(c, dto.BulkAssignResultFromService(result)) } -// Extend handles extending a subscription +// Extend handles adjusting a subscription (extend or shorten) // POST /api/v1/admin/subscriptions/:id/extend func (h *SubscriptionHandler) Extend(c *gin.Context) { subscriptionID, err := strconv.ParseInt(c.Param("id"), 10, 64) @@ -189,7 +189,7 @@ return } - var req ExtendSubscriptionRequest + var req AdjustSubscriptionRequest if err := c.ShouldBindJSON(&req); err != nil { response.BadRequest(c, "Invalid request: "+err.Error()) return @@ -201,7 +201,7 @@ return } - response.Success(c, dto.UserSubscriptionFromService(subscription)) + response.Success(c, dto.UserSubscriptionFromServiceAdmin(subscription)) } // Revoke handles revoking a subscription @@ -239,9 +239,9 @@ func (h *SubscriptionHandler) ListByGroup(c *gin.Context) { return } - out := make([]dto.UserSubscription, 0, len(subscriptions)) + out := make([]dto.AdminUserSubscription, 0, len(subscriptions)) for i := range subscriptions { - out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + out = append(out, 
*dto.UserSubscriptionFromServiceAdmin(&subscriptions[i])) } response.PaginatedWithResult(c, out, toResponsePagination(pagination)) } @@ -261,9 +261,9 @@ func (h *SubscriptionHandler) ListByUser(c *gin.Context) { return } - out := make([]dto.UserSubscription, 0, len(subscriptions)) + out := make([]dto.AdminUserSubscription, 0, len(subscriptions)) for i := range subscriptions { - out = append(out, *dto.UserSubscriptionFromService(&subscriptions[i])) + out = append(out, *dto.UserSubscriptionFromServiceAdmin(&subscriptions[i])) } response.Success(c, out) } diff --git a/backend/internal/handler/admin/usage_cleanup_handler_test.go b/backend/internal/handler/admin/usage_cleanup_handler_test.go new file mode 100644 index 00000000..ed1c7cc2 --- /dev/null +++ b/backend/internal/handler/admin/usage_cleanup_handler_test.go @@ -0,0 +1,377 @@ +package admin + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/handler/dto" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/pkg/response" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +type cleanupRepoStub struct { + mu sync.Mutex + created []*service.UsageCleanupTask + listTasks []service.UsageCleanupTask + listResult *pagination.PaginationResult + listErr error + statusByID map[int64]string +} + +func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error { + if task == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if task.ID == 0 { + task.ID = int64(len(s.created) + 1) + } + if task.CreatedAt.IsZero() { + task.CreatedAt = time.Now().UTC() + } + task.UpdatedAt = task.CreatedAt + clone := *task + s.created = append(s.created, &clone) + return nil +} + +func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.listTasks, s.listResult, s.listErr +} + +func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) { + return nil, nil +} + +func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statusByID == nil { + return "", sql.ErrNoRows + } + status, ok := s.statusByID[taskID] + if !ok { + return "", sql.ErrNoRows + } + return status, nil +} + +func (s *cleanupRepoStub) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error { + return nil +} + +func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + status := s.statusByID[taskID] + if status != service.UsageCleanupStatusPending && status != service.UsageCleanupStatusRunning { + return false, nil + } + s.statusByID[taskID] = service.UsageCleanupStatusCanceled + return true, nil +} + +func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error { + return nil +} + +func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + 
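// The stub keeps terminal-state transitions as no-ops; the handler tests below observe task state only through GetTaskStatus and the CancelTask state machine. + 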
return nil +} + +func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) { + return 0, nil +} + +var _ service.UsageCleanupRepository = (*cleanupRepoStub)(nil) + +func setupCleanupRouter(cleanupService *service.UsageCleanupService, userID int64) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + if userID > 0 { + router.Use(func(c *gin.Context) { + c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{UserID: userID}) + c.Next() + }) + } + + handler := NewUsageHandler(nil, nil, nil, cleanupService) + router.POST("/api/v1/admin/usage/cleanup-tasks", handler.CreateCleanupTask) + router.GET("/api/v1/admin/usage/cleanup-tasks", handler.ListCleanupTasks) + router.POST("/api/v1/admin/usage/cleanup-tasks/:id/cancel", handler.CancelCleanupTask) + return router +} + +func TestUsageHandlerCreateCleanupTaskUnauthorized(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 0) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusUnauthorized, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskUnavailable(t *testing.T) { + router := setupCleanupRouter(nil, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString(`{}`)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusServiceUnavailable, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskBindError(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewBufferString("{bad-json")) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskMissingRange(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "timezone": "UTC", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskInvalidDate(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, 
cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-13-01", + "end_date": "2024-01-02", + "timezone": "UTC", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskInvalidEndDate(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 88) + + payload := map[string]any{ + "start_date": "2024-01-01", + "end_date": "2024-02-40", + "timezone": "UTC", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusBadRequest, recorder.Code) +} + +func TestUsageHandlerCreateCleanupTaskSuccess(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 99) + + payload := map[string]any{ + "start_date": " 2024-01-01 ", + "end_date": "2024-01-02", + "timezone": "UTC", + "model": "gpt-4", + } + body, err := json.Marshal(payload) + require.NoError(t, err) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks", bytes.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp response.Response + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.created, 1) + created := repo.created[0] + require.Equal(t, int64(99), created.CreatedBy) + require.NotNil(t, created.Filters.Model) + require.Equal(t, "gpt-4", *created.Filters.Model) + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC).Add(24*time.Hour - time.Nanosecond) + require.True(t, created.Filters.StartTime.Equal(start)) + require.True(t, created.Filters.EndTime.Equal(end)) +} + +func TestUsageHandlerListCleanupTasksUnavailable(t *testing.T) { + router := setupCleanupRouter(nil, 0) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusServiceUnavailable, recorder.Code) +} + +func TestUsageHandlerListCleanupTasksSuccess(t *testing.T) { + repo := &cleanupRepoStub{} + repo.listTasks = []service.UsageCleanupTask{ + { + ID: 7, + Status: service.UsageCleanupStatusSucceeded, + CreatedBy: 4, + }, + } + repo.listResult = &pagination.PaginationResult{Total: 1, Page: 1, PageSize: 20, Pages: 1} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := 
service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusOK, recorder.Code) + + var resp struct { + Code int `json:"code"` + Data struct { + Items []dto.UsageCleanupTask `json:"items"` + Total int64 `json:"total"` + Page int `json:"page"` + } `json:"data"` + } + require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &resp)) + require.Equal(t, 0, resp.Code) + require.Len(t, resp.Data.Items, 1) + require.Equal(t, int64(7), resp.Data.Items[0].ID) + require.Equal(t, int64(1), resp.Data.Total) + require.Equal(t, 1, resp.Data.Page) +} + +func TestUsageHandlerListCleanupTasksError(t *testing.T) { + repo := &cleanupRepoStub{listErr: errors.New("boom")} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodGet, "/api/v1/admin/usage/cleanup-tasks", nil) + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + require.Equal(t, http.StatusInternalServerError, recorder.Code) +} + +func TestUsageHandlerCancelCleanupTaskUnauthorized(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 0) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/1/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusUnauthorized, rec.Code) +} + +func TestUsageHandlerCancelCleanupTaskNotFound(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/999/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusNotFound, rec.Code) +} + +func TestUsageHandlerCancelCleanupTaskConflict(t *testing.T) { + repo := &cleanupRepoStub{statusByID: map[int64]string{2: service.UsageCleanupStatusSucceeded}} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/2/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusConflict, rec.Code) +} + +func TestUsageHandlerCancelCleanupTaskSuccess(t *testing.T) { + repo := &cleanupRepoStub{statusByID: map[int64]string{3: service.UsageCleanupStatusPending}} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + cleanupService := service.NewUsageCleanupService(repo, nil, nil, cfg) + router := setupCleanupRouter(cleanupService, 1) + + req := httptest.NewRequest(http.MethodPost, "/api/v1/admin/usage/cleanup-tasks/3/cancel", nil) + rec := httptest.NewRecorder() + router.ServeHTTP(rec, req) + + require.Equal(t, http.StatusOK, rec.Code) +} diff --git 
a/backend/internal/handler/admin/usage_handler.go b/backend/internal/handler/admin/usage_handler.go index c7b983f1..3f3238dd 100644 --- a/backend/internal/handler/admin/usage_handler.go +++ b/backend/internal/handler/admin/usage_handler.go @@ -1,7 +1,10 @@ package admin import ( + "log" + "net/http" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/handler/dto" @@ -9,6 +12,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/pkg/response" "github.com/Wei-Shaw/sub2api/internal/pkg/timezone" "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" + "github.com/Wei-Shaw/sub2api/internal/server/middleware" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" @@ -16,9 +20,10 @@ // UsageHandler handles admin usage-related requests type UsageHandler struct { - usageService *service.UsageService - apiKeyService *service.APIKeyService - adminService service.AdminService + usageService *service.UsageService + apiKeyService *service.APIKeyService + adminService service.AdminService + cleanupService *service.UsageCleanupService } // NewUsageHandler creates a new admin usage handler @@ -26,14 +31,30 @@ func NewUsageHandler( usageService *service.UsageService, apiKeyService *service.APIKeyService, adminService service.AdminService, + cleanupService *service.UsageCleanupService, ) *UsageHandler { return &UsageHandler{ - usageService: usageService, - apiKeyService: apiKeyService, - adminService: adminService, + usageService: usageService, + apiKeyService: apiKeyService, + adminService: adminService, + cleanupService: cleanupService, } } +// CreateUsageCleanupTaskRequest represents a cleanup task creation request +type CreateUsageCleanupTaskRequest struct { + StartDate string `json:"start_date"` + EndDate string `json:"end_date"` + UserID *int64 `json:"user_id"` + APIKeyID *int64 `json:"api_key_id"` + AccountID *int64 `json:"account_id"` + GroupID *int64 `json:"group_id"` + Model *string `json:"model"` + Stream *bool `json:"stream"` + BillingType *int8 `json:"billing_type"` + Timezone string `json:"timezone"` +} + // List handles listing all usage records with filters // GET /api/v1/admin/usage func (h *UsageHandler) List(c *gin.Context) { @@ -142,7 +163,7 @@ func (h *UsageHandler) List(c *gin.Context) { return } - out := make([]dto.UsageLog, 0, len(records)) + out := make([]dto.AdminUsageLog, 0, len(records)) for i := range records { out = append(out, *dto.UsageLogFromServiceAdmin(&records[i])) } @@ -344,3 +365,162 @@ func (h *UsageHandler) SearchAPIKeys(c *gin.Context) { response.Success(c, result) } + +// ListCleanupTasks handles listing usage cleanup tasks +// GET /api/v1/admin/usage/cleanup-tasks +func (h *UsageHandler) ListCleanupTasks(c *gin.Context) { + if h.cleanupService == nil { + response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") + return + } + operator := int64(0) + if subject, ok := middleware.GetAuthSubjectFromContext(c); ok { + operator = subject.UserID + } + page, pageSize := response.ParsePagination(c) + log.Printf("[UsageCleanup] cleanup task list requested: operator=%d page=%d page_size=%d", operator, page, pageSize) + params := pagination.PaginationParams{Page: page, PageSize: pageSize} + tasks, result, err := h.cleanupService.ListTasks(c.Request.Context(), params) + if err != nil { + log.Printf("[UsageCleanup] failed to list cleanup tasks: operator=%d page=%d page_size=%d err=%v", operator, page, pageSize, err) + response.ErrorFrom(c, err) + return + } + out := make([]dto.UsageCleanupTask, 0, len(tasks)) + for i := range tasks { + out = 
append(out, *dto.UsageCleanupTaskFromService(&tasks[i])) } + log.Printf("[UsageCleanup] returning cleanup task list: operator=%d total=%d items=%d page=%d page_size=%d", operator, result.Total, len(out), page, pageSize) + response.Paginated(c, out, result.Total, page, pageSize) +} + +// CreateCleanupTask handles creating a usage cleanup task +// POST /api/v1/admin/usage/cleanup-tasks +func (h *UsageHandler) CreateCleanupTask(c *gin.Context) { + if h.cleanupService == nil { + response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") + return + } + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + response.Unauthorized(c, "Unauthorized") + return + } + + var req CreateUsageCleanupTaskRequest + if err := c.ShouldBindJSON(&req); err != nil { + response.BadRequest(c, "Invalid request: "+err.Error()) + return + } + req.StartDate = strings.TrimSpace(req.StartDate) + req.EndDate = strings.TrimSpace(req.EndDate) + if req.StartDate == "" || req.EndDate == "" { + response.BadRequest(c, "start_date and end_date are required") + return + } + + startTime, err := timezone.ParseInUserLocation("2006-01-02", req.StartDate, req.Timezone) + if err != nil { + response.BadRequest(c, "Invalid start_date format, use YYYY-MM-DD") + return + } + endTime, err := timezone.ParseInUserLocation("2006-01-02", req.EndDate, req.Timezone) + if err != nil { + response.BadRequest(c, "Invalid end_date format, use YYYY-MM-DD") + return + } + endTime = endTime.Add(24*time.Hour - time.Nanosecond) + + filters := service.UsageCleanupFilters{ + StartTime: startTime, + EndTime: endTime, + UserID: req.UserID, + APIKeyID: req.APIKeyID, + AccountID: req.AccountID, + GroupID: req.GroupID, + Model: req.Model, + Stream: req.Stream, + BillingType: req.BillingType, + } + + var userID any + if filters.UserID != nil { + userID = *filters.UserID + } + var apiKeyID any + if filters.APIKeyID != nil { + apiKeyID = *filters.APIKeyID + } + var accountID any + if filters.AccountID != nil { + accountID = *filters.AccountID + } + var groupID any + if filters.GroupID != nil { + groupID = *filters.GroupID + } + var model any + if filters.Model != nil { + model = *filters.Model + } + var stream any + if filters.Stream != nil { + stream = *filters.Stream + } + var billingType any + if filters.BillingType != nil { + billingType = *filters.BillingType + } + + log.Printf("[UsageCleanup] create cleanup task requested: operator=%d start=%s end=%s user_id=%v api_key_id=%v account_id=%v group_id=%v model=%v stream=%v billing_type=%v tz=%q", + subject.UserID, + filters.StartTime.Format(time.RFC3339), + filters.EndTime.Format(time.RFC3339), + userID, + apiKeyID, + accountID, + groupID, + model, + stream, + billingType, + req.Timezone, + ) + + task, err := h.cleanupService.CreateTask(c.Request.Context(), filters, subject.UserID) + if err != nil { + log.Printf("[UsageCleanup] failed to create cleanup task: operator=%d err=%v", subject.UserID, err) + response.ErrorFrom(c, err) + return + } + + log.Printf("[UsageCleanup] cleanup task created: task=%d operator=%d status=%s", task.ID, subject.UserID, task.Status) + response.Success(c, dto.UsageCleanupTaskFromService(task)) +} + +// CancelCleanupTask handles canceling a usage cleanup task +// POST /api/v1/admin/usage/cleanup-tasks/:id/cancel +func (h *UsageHandler) CancelCleanupTask(c *gin.Context) { + if h.cleanupService == nil { + response.Error(c, http.StatusServiceUnavailable, "Usage cleanup service unavailable") + return + } + subject, ok := middleware.GetAuthSubjectFromContext(c) + if !ok || subject.UserID <= 0 { + 
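// Cancellation is attributed to the calling admin, so an authenticated subject is required before the task is touched. + 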
response.Unauthorized(c, "Unauthorized") + return + } + idStr := strings.TrimSpace(c.Param("id")) + taskID, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || taskID <= 0 { + response.BadRequest(c, "Invalid task id") + return + } + log.Printf("[UsageCleanup] cancel cleanup task requested: task=%d operator=%d", taskID, subject.UserID) + if err := h.cleanupService.CancelTask(c.Request.Context(), taskID, subject.UserID); err != nil { + log.Printf("[UsageCleanup] failed to cancel cleanup task: task=%d operator=%d err=%v", taskID, subject.UserID, err) + response.ErrorFrom(c, err) + return + } + log.Printf("[UsageCleanup] cleanup task canceled: task=%d operator=%d", taskID, subject.UserID) + response.Success(c, gin.H{"id": taskID, "status": service.UsageCleanupStatusCanceled}) +} diff --git a/backend/internal/handler/admin/user_handler.go b/backend/internal/handler/admin/user_handler.go index 38cc8acd..9a5a691f 100644 --- a/backend/internal/handler/admin/user_handler.go +++ b/backend/internal/handler/admin/user_handler.go @@ -84,9 +84,9 @@ func (h *UserHandler) List(c *gin.Context) { return } - out := make([]dto.User, 0, len(users)) + out := make([]dto.AdminUser, 0, len(users)) for i := range users { - out = append(out, *dto.UserFromService(&users[i])) + out = append(out, *dto.UserFromServiceAdmin(&users[i])) } response.Paginated(c, out, total, page, pageSize) } @@ -129,7 +129,7 @@ func (h *UserHandler) GetByID(c *gin.Context) { return } - response.Success(c, dto.UserFromService(user)) + response.Success(c, dto.UserFromServiceAdmin(user)) } // Create handles creating a new user @@ -155,7 +155,7 @@ func (h *UserHandler) Create(c *gin.Context) { return } - response.Success(c, dto.UserFromService(user)) + response.Success(c, dto.UserFromServiceAdmin(user)) } // Update handles updating a user @@ -189,7 +189,7 @@ func (h *UserHandler) Update(c *gin.Context) { return } - response.Success(c, dto.UserFromService(user)) + response.Success(c, dto.UserFromServiceAdmin(user)) } // Delete handles deleting a user @@ -231,7 +231,7 @@ func (h *UserHandler) UpdateBalance(c *gin.Context) { return } - response.Success(c, dto.UserFromService(user)) + response.Success(c, dto.UserFromServiceAdmin(user)) } // GetUserAPIKeys handles getting user's API keys diff --git a/backend/internal/handler/dto/mappers.go b/backend/internal/handler/dto/mappers.go index f5bdd008..d58a8a29 100644 --- a/backend/internal/handler/dto/mappers.go +++ b/backend/internal/handler/dto/mappers.go @@ -15,7 +15,6 @@ func UserFromServiceShallow(u *service.User) *User { ID: u.ID, Email: u.Email, Username: u.Username, - Notes: u.Notes, Role: u.Role, Balance: u.Balance, Concurrency: u.Concurrency, @@ -48,6 +47,22 @@ func UserFromService(u *service.User) *User { return out } +// UserFromServiceAdmin converts a service User to DTO for admin users. +// It includes notes; user-facing endpoints must not use this. 
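+// Because AdminUser embeds the base User DTO, the admin JSON payload is a strict superset of the regular payload; Notes is the only extra field. 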
+func UserFromServiceAdmin(u *service.User) *AdminUser {
+	if u == nil {
+		return nil
+	}
+	base := UserFromService(u)
+	if base == nil {
+		return nil
+	}
+	return &AdminUser{
+		User:  *base,
+		Notes: u.Notes,
+	}
+}
+
 func APIKeyFromService(k *service.APIKey) *APIKey {
 	if k == nil {
 		return nil
@@ -72,36 +87,29 @@ func GroupFromServiceShallow(g *service.Group) *Group {
 	if g == nil {
 		return nil
 	}
-	return &Group{
-		ID:                  g.ID,
-		Name:                g.Name,
-		Description:         g.Description,
-		Platform:            g.Platform,
-		RateMultiplier:      g.RateMultiplier,
-		IsExclusive:         g.IsExclusive,
-		Status:              g.Status,
-		SubscriptionType:    g.SubscriptionType,
-		DailyLimitUSD:       g.DailyLimitUSD,
-		WeeklyLimitUSD:      g.WeeklyLimitUSD,
-		MonthlyLimitUSD:     g.MonthlyLimitUSD,
-		ImagePrice1K:        g.ImagePrice1K,
-		ImagePrice2K:        g.ImagePrice2K,
-		ImagePrice4K:        g.ImagePrice4K,
-		ClaudeCodeOnly:      g.ClaudeCodeOnly,
-		FallbackGroupID:     g.FallbackGroupID,
-		ModelRouting:        g.ModelRouting,
-		ModelRoutingEnabled: g.ModelRoutingEnabled,
-		CreatedAt:           g.CreatedAt,
-		UpdatedAt:           g.UpdatedAt,
-		AccountCount:        g.AccountCount,
-	}
+	out := groupFromServiceBase(g)
+	return &out
 }
 
 func GroupFromService(g *service.Group) *Group {
 	if g == nil {
 		return nil
 	}
-	out := GroupFromServiceShallow(g)
+	return GroupFromServiceShallow(g)
+}
+
+// GroupFromServiceAdmin converts a service Group to DTO for admin users.
+// It includes internal fields like model_routing and account_count.
+func GroupFromServiceAdmin(g *service.Group) *AdminGroup {
+	if g == nil {
+		return nil
+	}
+	out := &AdminGroup{
+		Group:               groupFromServiceBase(g),
+		ModelRouting:        g.ModelRouting,
+		ModelRoutingEnabled: g.ModelRoutingEnabled,
+		AccountCount:        g.AccountCount,
+	}
 	if len(g.AccountGroups) > 0 {
 		out.AccountGroups = make([]AccountGroup, 0, len(g.AccountGroups))
 		for i := range g.AccountGroups {
@@ -112,6 +120,29 @@
 	return out
 }
 
+func groupFromServiceBase(g *service.Group) Group {
+	return Group{
+		ID:               g.ID,
+		Name:             g.Name,
+		Description:      g.Description,
+		Platform:         g.Platform,
+		RateMultiplier:   g.RateMultiplier,
+		IsExclusive:      g.IsExclusive,
+		Status:           g.Status,
+		SubscriptionType: g.SubscriptionType,
+		DailyLimitUSD:    g.DailyLimitUSD,
+		WeeklyLimitUSD:   g.WeeklyLimitUSD,
+		MonthlyLimitUSD:  g.MonthlyLimitUSD,
+		ImagePrice1K:     g.ImagePrice1K,
+		ImagePrice2K:     g.ImagePrice2K,
+		ImagePrice4K:     g.ImagePrice4K,
+		ClaudeCodeOnly:   g.ClaudeCodeOnly,
+		FallbackGroupID:  g.FallbackGroupID,
+		CreatedAt:        g.CreatedAt,
+		UpdatedAt:        g.UpdatedAt,
+	}
+}
+
 func AccountFromServiceShallow(a *service.Account) *Account {
 	if a == nil {
 		return nil
@@ -161,6 +192,16 @@
 	if idleTimeout := a.GetSessionIdleTimeoutMinutes(); idleTimeout > 0 {
 		out.SessionIdleTimeoutMin = &idleTimeout
 	}
+	// TLS fingerprint masquerading toggle
+	if a.IsTLSFingerprintEnabled() {
+		enabled := true
+		out.EnableTLSFingerprint = &enabled
+	}
+	// Session ID masking toggle
+	if a.IsSessionIDMaskingEnabled() {
+		enabled := true
+		out.EnableSessionIDMasking = &enabled
+	}
 	}
 
 	return out
@@ -263,7 +304,24 @@ func RedeemCodeFromService(rc *service.RedeemCode) *RedeemCode {
 	if rc == nil {
 		return nil
 	}
-	return &RedeemCode{
+	out := redeemCodeFromServiceBase(rc)
+	return &out
+}
+
+// RedeemCodeFromServiceAdmin converts a service RedeemCode to DTO for admin users.
+// It includes notes; user-facing endpoints must not use this.
+func RedeemCodeFromServiceAdmin(rc *service.RedeemCode) *AdminRedeemCode {
+	if rc == nil {
+		return nil
+	}
+	return &AdminRedeemCode{
+		RedeemCode: redeemCodeFromServiceBase(rc),
+		Notes:      rc.Notes,
+	}
+}
+
+func redeemCodeFromServiceBase(rc *service.RedeemCode) RedeemCode {
+	return RedeemCode{
 		ID:        rc.ID,
 		Code:      rc.Code,
 		Type:      rc.Type,
@@ -271,7 +329,6 @@
 		Status:    rc.Status,
 		UsedBy:    rc.UsedBy,
 		UsedAt:    rc.UsedAt,
-		Notes:     rc.Notes,
 		CreatedAt: rc.CreatedAt,
 
 		GroupID:      rc.GroupID,
 		ValidityDays: rc.ValidityDays,
@@ -292,14 +349,9 @@ func AccountSummaryFromService(a *service.Account) *AccountSummary {
 	}
 }
 
-// usageLogFromServiceBase is a helper that converts service UsageLog to DTO.
-// The account parameter allows caller to control what Account info is included.
-// The includeIPAddress parameter controls whether to include the IP address (admin-only).
-func usageLogFromServiceBase(l *service.UsageLog, account *AccountSummary, includeIPAddress bool) *UsageLog {
-	if l == nil {
-		return nil
-	}
-	result := &UsageLog{
+func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
+	// Regular-user DTO: never include admin-only fields (e.g. account_rate_multiplier, ip_address, account).
+	return UsageLog{
 		ID:       l.ID,
 		UserID:   l.UserID,
 		APIKeyID: l.APIKeyID,
@@ -321,7 +373,6 @@
 		TotalCost:      l.TotalCost,
 		ActualCost:     l.ActualCost,
 		RateMultiplier: l.RateMultiplier,
-		AccountRateMultiplier: l.AccountRateMultiplier,
 		BillingType: l.BillingType,
 		Stream:      l.Stream,
 		DurationMs:  l.DurationMs,
@@ -332,30 +383,63 @@
 		CreatedAt:    l.CreatedAt,
 		User:         UserFromServiceShallow(l.User),
 		APIKey:       APIKeyFromService(l.APIKey),
-		Account:      account,
 		Group:        GroupFromServiceShallow(l.Group),
 		Subscription: UserSubscriptionFromService(l.Subscription),
 	}
-	// IP 地址仅对管理员可见
-	if includeIPAddress {
-		result.IPAddress = l.IPAddress
-	}
-	return result
 }
 
 // UsageLogFromService converts a service UsageLog to DTO for regular users.
 // It excludes Account details and IP address - users should not see these.
 func UsageLogFromService(l *service.UsageLog) *UsageLog {
-	return usageLogFromServiceBase(l, nil, false)
+	if l == nil {
+		return nil
+	}
+	u := usageLogFromServiceUser(l)
+	return &u
 }
 
 // UsageLogFromServiceAdmin converts a service UsageLog to DTO for admin users.
 // It includes minimal Account info (ID, Name only) and IP address.
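The Admin* mappers above all share one shape: build the user-safe base DTO, then wrap it in an admin struct that embeds it. Because encoding/json flattens embedded structs, the admin payload is a strict superset of the user payload. A small sketch of the pattern with abbreviated stand-in types (not the real DTOs, which carry many more fields):

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // Abbreviated stand-ins for the dto types.
    type User struct {
        ID    int64  `json:"id"`
        Email string `json:"email"`
    }

    type AdminUser struct {
        User // embedded: its fields marshal at the top level

        Notes string `json:"notes"`
    }

    func main() {
        u := User{ID: 1, Email: "a@example.com"}
        userJSON, _ := json.Marshal(u)
        adminJSON, _ := json.Marshal(AdminUser{User: u, Notes: "internal remark"})

        fmt.Println(string(userJSON))  // {"id":1,"email":"a@example.com"}
        fmt.Println(string(adminJSON)) // {"id":1,"email":"a@example.com","notes":"internal remark"}
    }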
-func UsageLogFromServiceAdmin(l *service.UsageLog) *UsageLog { +func UsageLogFromServiceAdmin(l *service.UsageLog) *AdminUsageLog { if l == nil { return nil } - return usageLogFromServiceBase(l, AccountSummaryFromService(l.Account), true) + return &AdminUsageLog{ + UsageLog: usageLogFromServiceUser(l), + AccountRateMultiplier: l.AccountRateMultiplier, + IPAddress: l.IPAddress, + Account: AccountSummaryFromService(l.Account), + } +} + +func UsageCleanupTaskFromService(task *service.UsageCleanupTask) *UsageCleanupTask { + if task == nil { + return nil + } + return &UsageCleanupTask{ + ID: task.ID, + Status: task.Status, + Filters: UsageCleanupFilters{ + StartTime: task.Filters.StartTime, + EndTime: task.Filters.EndTime, + UserID: task.Filters.UserID, + APIKeyID: task.Filters.APIKeyID, + AccountID: task.Filters.AccountID, + GroupID: task.Filters.GroupID, + Model: task.Filters.Model, + Stream: task.Filters.Stream, + BillingType: task.Filters.BillingType, + }, + CreatedBy: task.CreatedBy, + DeletedRows: task.DeletedRows, + ErrorMessage: task.ErrorMsg, + CanceledBy: task.CanceledBy, + CanceledAt: task.CanceledAt, + StartedAt: task.StartedAt, + FinishedAt: task.FinishedAt, + CreatedAt: task.CreatedAt, + UpdatedAt: task.UpdatedAt, + } } func SettingFromService(s *service.Setting) *Setting { @@ -374,7 +458,27 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio if sub == nil { return nil } - return &UserSubscription{ + out := userSubscriptionFromServiceBase(sub) + return &out +} + +// UserSubscriptionFromServiceAdmin converts a service UserSubscription to DTO for admin users. +// It includes assignment metadata and notes. +func UserSubscriptionFromServiceAdmin(sub *service.UserSubscription) *AdminUserSubscription { + if sub == nil { + return nil + } + return &AdminUserSubscription{ + UserSubscription: userSubscriptionFromServiceBase(sub), + AssignedBy: sub.AssignedBy, + AssignedAt: sub.AssignedAt, + Notes: sub.Notes, + AssignedByUser: UserFromServiceShallow(sub.AssignedByUser), + } +} + +func userSubscriptionFromServiceBase(sub *service.UserSubscription) UserSubscription { + return UserSubscription{ ID: sub.ID, UserID: sub.UserID, GroupID: sub.GroupID, @@ -387,14 +491,10 @@ func UserSubscriptionFromService(sub *service.UserSubscription) *UserSubscriptio DailyUsageUSD: sub.DailyUsageUSD, WeeklyUsageUSD: sub.WeeklyUsageUSD, MonthlyUsageUSD: sub.MonthlyUsageUSD, - AssignedBy: sub.AssignedBy, - AssignedAt: sub.AssignedAt, - Notes: sub.Notes, CreatedAt: sub.CreatedAt, UpdatedAt: sub.UpdatedAt, User: UserFromServiceShallow(sub.User), Group: GroupFromServiceShallow(sub.Group), - AssignedByUser: UserFromServiceShallow(sub.AssignedByUser), } } @@ -402,9 +502,9 @@ func BulkAssignResultFromService(r *service.BulkAssignResult) *BulkAssignResult if r == nil { return nil } - subs := make([]UserSubscription, 0, len(r.Subscriptions)) + subs := make([]AdminUserSubscription, 0, len(r.Subscriptions)) for i := range r.Subscriptions { - subs = append(subs, *UserSubscriptionFromService(&r.Subscriptions[i])) + subs = append(subs, *UserSubscriptionFromServiceAdmin(&r.Subscriptions[i])) } return &BulkAssignResult{ SuccessCount: r.SuccessCount, diff --git a/backend/internal/handler/dto/settings.go b/backend/internal/handler/dto/settings.go index 81206def..19356e46 100644 --- a/backend/internal/handler/dto/settings.go +++ b/backend/internal/handler/dto/settings.go @@ -22,13 +22,14 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool 
`json:"linuxdo_connect_client_secret_configured"`
 	LinuxDoConnectRedirectURL            string `json:"linuxdo_connect_redirect_url"`
 
-	SiteName     string `json:"site_name"`
-	SiteLogo     string `json:"site_logo"`
-	SiteSubtitle string `json:"site_subtitle"`
-	APIBaseURL   string `json:"api_base_url"`
-	ContactInfo  string `json:"contact_info"`
-	DocURL       string `json:"doc_url"`
-	HomeContent  string `json:"home_content"`
+	SiteName            string `json:"site_name"`
+	SiteLogo            string `json:"site_logo"`
+	SiteSubtitle        string `json:"site_subtitle"`
+	APIBaseURL          string `json:"api_base_url"`
+	ContactInfo         string `json:"contact_info"`
+	DocURL              string `json:"doc_url"`
+	HomeContent         string `json:"home_content"`
+	HideCcsImportButton bool   `json:"hide_ccs_import_button"`
 
 	DefaultConcurrency int     `json:"default_concurrency"`
 	DefaultBalance     float64 `json:"default_balance"`
@@ -63,6 +64,7 @@ type PublicSettings struct {
 	ContactInfo         string `json:"contact_info"`
 	DocURL              string `json:"doc_url"`
 	HomeContent         string `json:"home_content"`
+	HideCcsImportButton bool   `json:"hide_ccs_import_button"`
 	LinuxDoOAuthEnabled bool   `json:"linuxdo_oauth_enabled"`
 	Version             string `json:"version"`
 }
diff --git a/backend/internal/handler/dto/types.go b/backend/internal/handler/dto/types.go
index 4519143c..938d707c 100644
--- a/backend/internal/handler/dto/types.go
+++ b/backend/internal/handler/dto/types.go
@@ -6,7 +6,6 @@ type User struct {
 	ID          int64   `json:"id"`
 	Email       string  `json:"email"`
 	Username    string  `json:"username"`
-	Notes       string  `json:"notes"`
 	Role        string  `json:"role"`
 	Balance     float64 `json:"balance"`
 	Concurrency int     `json:"concurrency"`
@@ -19,6 +18,14 @@ type User struct {
 	Subscriptions []UserSubscription `json:"subscriptions,omitempty"`
 }
 
+// AdminUser is the user DTO for admin endpoints (includes sensitive/internal fields).
+// Note: regular user endpoints must not return admin remarks such as notes.
+type AdminUser struct {
+	User
+
+	Notes string `json:"notes"`
+}
+
 type APIKey struct {
 	ID     int64  `json:"id"`
 	UserID int64  `json:"user_id"`
@@ -58,13 +65,19 @@ type Group struct {
 	ClaudeCodeOnly  bool   `json:"claude_code_only"`
 	FallbackGroupID *int64 `json:"fallback_group_id"`
 
+	CreatedAt time.Time `json:"created_at"`
+	UpdatedAt time.Time `json:"updated_at"`
+}
+
+// AdminGroup is the group DTO for admin endpoints (includes sensitive/internal fields).
+// Note: regular user endpoints must not return internal data such as model_routing/account_count/account_groups.
+type AdminGroup struct {
+	Group
+
 	// 模型路由配置(仅 anthropic 平台使用)
 	ModelRouting        map[string][]int64 `json:"model_routing"`
 	ModelRoutingEnabled bool               `json:"model_routing_enabled"`
 
-	CreatedAt time.Time `json:"created_at"`
-	UpdatedAt time.Time `json:"updated_at"`
-
 	AccountGroups []AccountGroup `json:"account_groups,omitempty"`
 	AccountCount  int64          `json:"account_count,omitempty"`
 }
@@ -112,6 +125,15 @@ type Account struct {
 	MaxSessions           *int `json:"max_sessions,omitempty"`
 	SessionIdleTimeoutMin *int `json:"session_idle_timeout_minutes,omitempty"`
 
+	// TLS fingerprint masquerading (effective only for Anthropic OAuth/SetupToken accounts)
+	// Extracted from the extra field so the frontend can display and edit it
+	EnableTLSFingerprint *bool `json:"enable_tls_fingerprint,omitempty"`
+
+	// Session ID masking (effective only for Anthropic OAuth/SetupToken accounts)
+	// When enabled, the session ID inside metadata.user_id is pinned for 15 minutes
+	// Extracted from the extra field so the frontend can display and edit it
+	EnableSessionIDMasking *bool `json:"session_id_masking_enabled,omitempty"`
+
 	Proxy *Proxy `json:"proxy,omitempty"`
 
 	AccountGroups []AccountGroup `json:"account_groups,omitempty"`
@@ -171,7 +193,6 @@ type RedeemCode struct {
 	Status    string     `json:"status"`
 	UsedBy    *int64     `json:"used_by"`
 	UsedAt    *time.Time `json:"used_at"`
-	Notes     string     `json:"notes"`
 	CreatedAt time.Time  `json:"created_at"`
 
 	GroupID      *int64 `json:"group_id"`
@@ -181,6 +202,15 @@ type RedeemCode struct {
 	Group *Group `json:"group,omitempty"`
 }
 
+// AdminRedeemCode is the redeem code DTO for admin endpoints (includes notes and similar fields).
+// Note: regular user endpoints must not return internal data such as notes.
+type AdminRedeemCode struct {
+	RedeemCode
+
+	Notes string `json:"notes"`
+}
+
+// UsageLog is the usage log DTO for regular user endpoints (contains no admin-only fields).
 type UsageLog struct {
 	ID       int64 `json:"id"`
 	UserID   int64 `json:"user_id"`
@@ -200,14 +230,13 @@ type UsageLog struct {
 	CacheCreation5mTokens int `json:"cache_creation_5m_tokens"`
 	CacheCreation1hTokens int `json:"cache_creation_1h_tokens"`
 
-	InputCost             float64  `json:"input_cost"`
-	OutputCost            float64  `json:"output_cost"`
-	CacheCreationCost     float64  `json:"cache_creation_cost"`
-	CacheReadCost         float64  `json:"cache_read_cost"`
-	TotalCost             float64  `json:"total_cost"`
-	ActualCost            float64  `json:"actual_cost"`
-	RateMultiplier        float64  `json:"rate_multiplier"`
-	AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
+	InputCost         float64 `json:"input_cost"`
+	OutputCost        float64 `json:"output_cost"`
+	CacheCreationCost float64 `json:"cache_creation_cost"`
+	CacheReadCost     float64 `json:"cache_read_cost"`
+	TotalCost         float64 `json:"total_cost"`
+	ActualCost        float64 `json:"actual_cost"`
+	RateMultiplier    float64 `json:"rate_multiplier"`
 
 	BillingType int8 `json:"billing_type"`
 	Stream      bool `json:"stream"`
@@ -221,18 +250,55 @@ type UsageLog struct {
 	// User-Agent
 	UserAgent *string `json:"user_agent"`
 
-	// IP 地址(仅管理员可见)
-	IPAddress *string `json:"ip_address,omitempty"`
-
 	CreatedAt time.Time `json:"created_at"`
 
 	User         *User             `json:"user,omitempty"`
 	APIKey       *APIKey           `json:"api_key,omitempty"`
-	Account      *AccountSummary   `json:"account,omitempty"` // Use minimal AccountSummary to prevent data leakage
 	Group        *Group            `json:"group,omitempty"`
 	Subscription *UserSubscription `json:"subscription,omitempty"`
 }
 
+// AdminUsageLog is the usage log DTO for admin endpoints (includes admin-only fields).
+type AdminUsageLog struct {
+	UsageLog
+
+	// AccountRateMultiplier is a snapshot of the account billing rate multiplier (nil is treated as 1.0)
+	AccountRateMultiplier *float64 `json:"account_rate_multiplier"`
+
+	// IPAddress is the requesting client's IP (visible to admins only)
+	IPAddress *string `json:"ip_address,omitempty"`
+
+	// Account is minimal account info (avoids leaking sensitive fields)
+	Account *AccountSummary `json:"account,omitempty"`
+}
+
+type UsageCleanupFilters struct {
+	StartTime   time.Time `json:"start_time"`
+	EndTime     time.Time `json:"end_time"`
+	UserID      *int64    `json:"user_id,omitempty"`
+	APIKeyID    *int64    `json:"api_key_id,omitempty"`
+	AccountID   *int64    `json:"account_id,omitempty"`
+	GroupID     *int64    `json:"group_id,omitempty"`
+	Model       *string   `json:"model,omitempty"`
+	Stream      *bool     `json:"stream,omitempty"`
+	BillingType *int8     `json:"billing_type,omitempty"`
+}
+
+type UsageCleanupTask struct {
+	ID           int64               `json:"id"`
+	Status       string              `json:"status"`
+	Filters      UsageCleanupFilters `json:"filters"`
+	CreatedBy    int64               `json:"created_by"`
+	DeletedRows  int64               `json:"deleted_rows"`
+	ErrorMessage *string             `json:"error_message,omitempty"`
+	CanceledBy   *int64              `json:"canceled_by,omitempty"`
+	CanceledAt   *time.Time          `json:"canceled_at,omitempty"`
+	StartedAt    *time.Time          `json:"started_at,omitempty"`
+	FinishedAt   *time.Time          `json:"finished_at,omitempty"`
+	CreatedAt    time.Time           `json:"created_at"`
+	UpdatedAt    time.Time           `json:"updated_at"`
+}
+
 // AccountSummary is a minimal account info for usage log display.
 // It intentionally excludes sensitive fields like Credentials, Proxy, etc.
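Before AccountSummary below, one property of the cleanup filter DTO worth making explicit: every optional filter is a pointer with omitempty, so unset filters disappear from the serialized task rather than appearing as zero values. A trimmed, runnable illustration (only a few of the fields):

    package main

    import (
        "encoding/json"
        "fmt"
        "time"
    )

    // Trimmed copy of the UsageCleanupFilters DTO above.
    type filters struct {
        StartTime time.Time `json:"start_time"`
        EndTime   time.Time `json:"end_time"`
        UserID    *int64    `json:"user_id,omitempty"`
        Model     *string   `json:"model,omitempty"`
    }

    func main() {
        uid := int64(42)
        f := filters{
            StartTime: time.Date(2024, 5, 1, 0, 0, 0, 0, time.UTC),
            EndTime:   time.Date(2024, 5, 2, 23, 59, 59, 0, time.UTC),
            UserID:    &uid, // Model stays nil and is omitted entirely
        }
        b, _ := json.Marshal(f)
        fmt.Println(string(b))
        // {"start_time":"2024-05-01T00:00:00Z","end_time":"2024-05-02T23:59:59Z","user_id":42}
    }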
 type AccountSummary struct {
@@ -264,23 +330,30 @@ type UserSubscription struct {
 	WeeklyUsageUSD  float64 `json:"weekly_usage_usd"`
 	MonthlyUsageUSD float64 `json:"monthly_usage_usd"`
 
+	CreatedAt time.Time `json:"created_at"`
+	UpdatedAt time.Time `json:"updated_at"`
+
+	User  *User  `json:"user,omitempty"`
+	Group *Group `json:"group,omitempty"`
+}
+
+// AdminUserSubscription is the subscription DTO for admin endpoints (includes assignment info, notes, etc.).
+// Note: regular user endpoints must not return admin fields such as assigned_by/assigned_at/notes/assigned_by_user.
+type AdminUserSubscription struct {
+	UserSubscription
+
 	AssignedBy *int64    `json:"assigned_by"`
 	AssignedAt time.Time `json:"assigned_at"`
 	Notes      string    `json:"notes"`
 
-	CreatedAt time.Time `json:"created_at"`
-	UpdatedAt time.Time `json:"updated_at"`
-
-	User           *User  `json:"user,omitempty"`
-	Group          *Group `json:"group,omitempty"`
-	AssignedByUser *User  `json:"assigned_by_user,omitempty"`
+	AssignedByUser *User `json:"assigned_by_user,omitempty"`
 }
 
 type BulkAssignResult struct {
-	SuccessCount  int                `json:"success_count"`
-	FailedCount   int                `json:"failed_count"`
-	Subscriptions []UserSubscription `json:"subscriptions"`
-	Errors        []string           `json:"errors"`
+	SuccessCount  int                     `json:"success_count"`
+	FailedCount   int                     `json:"failed_count"`
+	Subscriptions []AdminUserSubscription `json:"subscriptions"`
+	Errors        []string                `json:"errors"`
 }
 
 // PromoCode 注册优惠码
diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go
index 7605805a..f38fea39 100644
--- a/backend/internal/handler/gateway_handler.go
+++ b/backend/internal/handler/gateway_handler.go
@@ -31,6 +31,8 @@ type GatewayHandler struct {
 	userService         *service.UserService
 	billingCacheService *service.BillingCacheService
 	concurrencyHelper   *ConcurrencyHelper
+	maxAccountSwitches       int
+	maxAccountSwitchesGemini int
 }
 
 // NewGatewayHandler creates a new GatewayHandler
@@ -44,8 +46,16 @@ func NewGatewayHandler(
 	cfg *config.Config,
 ) *GatewayHandler {
 	pingInterval := time.Duration(0)
+	maxAccountSwitches := 10
+	maxAccountSwitchesGemini := 3
 	if cfg != nil {
 		pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
+		if cfg.Gateway.MaxAccountSwitches > 0 {
+			maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
+		}
+		if cfg.Gateway.MaxAccountSwitchesGemini > 0 {
+			maxAccountSwitchesGemini = cfg.Gateway.MaxAccountSwitchesGemini
+		}
 	}
 	return &GatewayHandler{
 		gatewayService:      gatewayService,
@@ -54,6 +64,8 @@ func NewGatewayHandler(
 		userService:         userService,
 		billingCacheService: billingCacheService,
 		concurrencyHelper:   NewConcurrencyHelper(concurrencyService, SSEPingFormatClaude, pingInterval),
+		maxAccountSwitches:       maxAccountSwitches,
+		maxAccountSwitchesGemini: maxAccountSwitchesGemini,
 	}
 }
 
@@ -179,7 +191,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
 	}
 
 	if platform == service.PlatformGemini {
-		const maxAccountSwitches = 3
+		maxAccountSwitches := h.maxAccountSwitchesGemini
 		switchCount := 0
 		failedAccountIDs := make(map[int64]struct{})
 		lastFailoverStatus := 0
@@ -313,7 +325,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
 		}
 	}
 
-	const maxAccountSwitches = 10
+	maxAccountSwitches := h.maxAccountSwitches
 	switchCount := 0
 	failedAccountIDs := make(map[int64]struct{})
 	lastFailoverStatus := 0
diff --git a/backend/internal/handler/gemini_v1beta_handler.go b/backend/internal/handler/gemini_v1beta_handler.go
index ec943e61..c7646b38 100644
--- a/backend/internal/handler/gemini_v1beta_handler.go
+++ b/backend/internal/handler/gemini_v1beta_handler.go
@@ -220,7 +220,7 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
 	if
sessionHash != "" { sessionKey = "gemini:" + sessionHash } - const maxAccountSwitches = 3 + maxAccountSwitches := h.maxAccountSwitchesGemini switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index c4cfabc3..4c9dd8b9 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -25,6 +25,7 @@ type OpenAIGatewayHandler struct { gatewayService *service.OpenAIGatewayService billingCacheService *service.BillingCacheService concurrencyHelper *ConcurrencyHelper + maxAccountSwitches int } // NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler @@ -35,13 +36,18 @@ func NewOpenAIGatewayHandler( cfg *config.Config, ) *OpenAIGatewayHandler { pingInterval := time.Duration(0) + maxAccountSwitches := 3 if cfg != nil { pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second + if cfg.Gateway.MaxAccountSwitches > 0 { + maxAccountSwitches = cfg.Gateway.MaxAccountSwitches + } } return &OpenAIGatewayHandler{ gatewayService: gatewayService, billingCacheService: billingCacheService, concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval), + maxAccountSwitches: maxAccountSwitches, } } @@ -186,10 +192,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { return } - // Generate session hash (from header for OpenAI) - sessionHash := h.gatewayService.GenerateSessionHash(c) + // Generate session hash (header first; fallback to prompt_cache_key) + sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody) - const maxAccountSwitches = 3 + maxAccountSwitches := h.maxAccountSwitches switchCount := 0 failedAccountIDs := make(map[int64]struct{}) lastFailoverStatus := 0 diff --git a/backend/internal/handler/setting_handler.go b/backend/internal/handler/setting_handler.go index cac79e9c..0fc61144 100644 --- a/backend/internal/handler/setting_handler.go +++ b/backend/internal/handler/setting_handler.go @@ -43,6 +43,7 @@ func (h *SettingHandler) GetPublicSettings(c *gin.Context) { ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: h.version, }) diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index d968951c..35862f1c 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -47,9 +47,6 @@ func (h *UserHandler) GetProfile(c *gin.Context) { return } - // 清空notes字段,普通用户不应看到备注 - userData.Notes = "" - response.Success(c, dto.UserFromService(userData)) } @@ -105,8 +102,5 @@ func (h *UserHandler) UpdateProfile(c *gin.Context) { return } - // 清空notes字段,普通用户不应看到备注 - updatedUser.Notes = "" - response.Success(c, dto.UserFromService(updatedUser)) } diff --git a/backend/internal/middleware/rate_limiter_integration_test.go b/backend/internal/middleware/rate_limiter_integration_test.go index 4759a988..1161364b 100644 --- a/backend/internal/middleware/rate_limiter_integration_test.go +++ b/backend/internal/middleware/rate_limiter_integration_test.go @@ -7,6 +7,9 @@ import ( "fmt" "net/http" "net/http/httptest" + "os" + "path/filepath" + "strconv" "testing" "time" @@ -88,6 +91,7 @@ func performRequest(router *gin.Engine) *httptest.ResponseRecorder { func startRedis(t *testing.T, ctx context.Context) 
*redis.Client { t.Helper() + ensureDockerAvailable(t) redisContainer, err := tcredis.Run(ctx, redisImageTag) require.NoError(t, err) @@ -112,3 +116,43 @@ func startRedis(t *testing.T, ctx context.Context) *redis.Client { return rdb } + +func ensureDockerAvailable(t *testing.T) { + t.Helper() + if dockerAvailable() { + return + } + t.Skip("Docker 未启用,跳过依赖 testcontainers 的集成测试") +} + +func dockerAvailable() bool { + if os.Getenv("DOCKER_HOST") != "" { + return true + } + + socketCandidates := []string{ + "/var/run/docker.sock", + filepath.Join(os.Getenv("XDG_RUNTIME_DIR"), "docker.sock"), + filepath.Join(userHomeDir(), ".docker", "run", "docker.sock"), + filepath.Join(userHomeDir(), ".docker", "desktop", "docker.sock"), + filepath.Join("/run/user", strconv.Itoa(os.Getuid()), "docker.sock"), + } + + for _, socket := range socketCandidates { + if socket == "" { + continue + } + if _, err := os.Stat(socket); err == nil { + return true + } + } + return false +} + +func userHomeDir() string { + home, err := os.UserHomeDir() + if err != nil { + return "" + } + return home +} diff --git a/backend/internal/pkg/antigravity/client.go b/backend/internal/pkg/antigravity/client.go index 1248be95..a6279b11 100644 --- a/backend/internal/pkg/antigravity/client.go +++ b/backend/internal/pkg/antigravity/client.go @@ -16,15 +16,6 @@ import ( "time" ) -// resolveHost 从 URL 解析 host -func resolveHost(urlStr string) string { - parsed, err := url.Parse(urlStr) - if err != nil { - return "" - } - return parsed.Host -} - // NewAPIRequestWithURL 使用指定的 base URL 创建 Antigravity API 请求(v1internal 端点) func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken string, body []byte) (*http.Request, error) { // 构建 URL,流式请求添加 ?alt=sse 参数 @@ -39,23 +30,11 @@ func NewAPIRequestWithURL(ctx context.Context, baseURL, action, accessToken stri return nil, err } - // 基础 Headers + // 基础 Headers(与 Antigravity-Manager 保持一致,只设置这 3 个) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("User-Agent", UserAgent) - // Accept Header 根据请求类型设置 - if isStream { - req.Header.Set("Accept", "text/event-stream") - } else { - req.Header.Set("Accept", "application/json") - } - - // 显式设置 Host Header - if host := resolveHost(apiURL); host != "" { - req.Host = host - } - return req, nil } @@ -195,12 +174,15 @@ func isConnectionError(err error) bool { } // shouldFallbackToNextURL 判断是否应切换到下一个 URL -// 仅连接错误和 HTTP 429 触发 URL 降级 +// 与 Antigravity-Manager 保持一致:连接错误、429、408、404、5xx 触发 URL 降级 func shouldFallbackToNextURL(err error, statusCode int) bool { if isConnectionError(err) { return true } - return statusCode == http.StatusTooManyRequests + return statusCode == http.StatusTooManyRequests || + statusCode == http.StatusRequestTimeout || + statusCode == http.StatusNotFound || + statusCode >= 500 } // ExchangeCode 用 authorization code 交换 token @@ -321,11 +303,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - // 获取可用的 URL 列表 - availableURLs := DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 - } + // 固定顺序:prod -> daily + availableURLs := BaseURLs var lastErr error for urlIdx, baseURL := range availableURLs { @@ -343,7 +322,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC if err != nil { lastErr = fmt.Errorf("loadCodeAssist 请求失败: %w", err) if shouldFallbackToNextURL(err, 0) && urlIdx < 
len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] loadCodeAssist URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } @@ -358,7 +336,6 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC // 检查是否需要 URL 降级 if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] loadCodeAssist URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -376,6 +353,8 @@ func (c *Client) LoadCodeAssist(ctx context.Context, accessToken string) (*LoadC var rawResp map[string]any _ = json.Unmarshal(respBodyBytes, &rawResp) + // 标记成功的 URL,下次优先使用 + DefaultURLAvailability.MarkSuccess(baseURL) return &loadResp, rawResp, nil } @@ -412,11 +391,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI return nil, nil, fmt.Errorf("序列化请求失败: %w", err) } - // 获取可用的 URL 列表 - availableURLs := DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = BaseURLs // 所有 URL 都不可用时,重试所有 - } + // 固定顺序:prod -> daily + availableURLs := BaseURLs var lastErr error for urlIdx, baseURL := range availableURLs { @@ -434,7 +410,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI if err != nil { lastErr = fmt.Errorf("fetchAvailableModels 请求失败: %w", err) if shouldFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] fetchAvailableModels URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1]) continue } @@ -449,7 +424,6 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI // 检查是否需要 URL 降级 if shouldFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 { - DefaultURLAvailability.MarkUnavailable(baseURL) log.Printf("[antigravity] fetchAvailableModels URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1]) continue } @@ -467,6 +441,8 @@ func (c *Client) FetchAvailableModels(ctx context.Context, accessToken, projectI var rawResp map[string]any _ = json.Unmarshal(respBodyBytes, &rawResp) + // 标记成功的 URL,下次优先使用 + DefaultURLAvailability.MarkSuccess(baseURL) return &modelsResp, rawResp, nil } diff --git a/backend/internal/pkg/antigravity/gemini_types.go b/backend/internal/pkg/antigravity/gemini_types.go index f688332f..c1cc998c 100644 --- a/backend/internal/pkg/antigravity/gemini_types.go +++ b/backend/internal/pkg/antigravity/gemini_types.go @@ -143,9 +143,10 @@ type GeminiResponse struct { // GeminiCandidate Gemini 候选响应 type GeminiCandidate struct { - Content *GeminiContent `json:"content,omitempty"` - FinishReason string `json:"finishReason,omitempty"` - Index int `json:"index,omitempty"` + Content *GeminiContent `json:"content,omitempty"` + FinishReason string `json:"finishReason,omitempty"` + Index int `json:"index,omitempty"` + GroundingMetadata *GeminiGroundingMetadata `json:"groundingMetadata,omitempty"` } // GeminiUsageMetadata Gemini 用量元数据 @@ -156,6 +157,23 @@ type GeminiUsageMetadata struct { TotalTokenCount int `json:"totalTokenCount,omitempty"` } +// GeminiGroundingMetadata Gemini grounding 元数据(Web Search) +type GeminiGroundingMetadata struct { + WebSearchQueries []string `json:"webSearchQueries,omitempty"` + GroundingChunks []GeminiGroundingChunk `json:"groundingChunks,omitempty"` +} + +// GeminiGroundingChunk Gemini grounding chunk +type 
GeminiGroundingChunk struct { + Web *GeminiGroundingWeb `json:"web,omitempty"` +} + +// GeminiGroundingWeb Gemini grounding web 信息 +type GeminiGroundingWeb struct { + Title string `json:"title,omitempty"` + URI string `json:"uri,omitempty"` +} + // DefaultSafetySettings 默认安全设置(关闭所有过滤) var DefaultSafetySettings = []GeminiSafetySetting{ {Category: "HARM_CATEGORY_HARASSMENT", Threshold: "OFF"}, diff --git a/backend/internal/pkg/antigravity/oauth.go b/backend/internal/pkg/antigravity/oauth.go index 736c45df..ee2a6c1a 100644 --- a/backend/internal/pkg/antigravity/oauth.go +++ b/backend/internal/pkg/antigravity/oauth.go @@ -32,8 +32,8 @@ const ( "https://www.googleapis.com/auth/cclog " + "https://www.googleapis.com/auth/experimentsandconfigs" - // User-Agent(模拟官方客户端) - UserAgent = "antigravity/1.104.0 darwin/arm64" + // User-Agent(与 Antigravity-Manager 保持一致) + UserAgent = "antigravity/1.11.9 windows/amd64" // Session 过期时间 SessionTTL = 30 * time.Minute @@ -42,22 +42,21 @@ const ( URLAvailabilityTTL = 5 * time.Minute ) -// BaseURLs 定义 Antigravity API 端点,按优先级排序 -// fallback 顺序: sandbox → daily → prod +// BaseURLs 定义 Antigravity API 端点(与 Antigravity-Manager 保持一致) var BaseURLs = []string{ - "https://daily-cloudcode-pa.sandbox.googleapis.com", // sandbox - "https://daily-cloudcode-pa.googleapis.com", // daily - "https://cloudcode-pa.googleapis.com", // prod + "https://cloudcode-pa.googleapis.com", // prod (优先) + "https://daily-cloudcode-pa.sandbox.googleapis.com", // daily sandbox (备用) } // BaseURL 默认 URL(保持向后兼容) var BaseURL = BaseURLs[0] -// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复) +// URLAvailability 管理 URL 可用性状态(带 TTL 自动恢复和动态优先级) type URLAvailability struct { mu sync.RWMutex unavailable map[string]time.Time // URL -> 恢复时间 ttl time.Duration + lastSuccess string // 最近成功请求的 URL,优先使用 } // DefaultURLAvailability 全局 URL 可用性管理器 @@ -78,6 +77,15 @@ func (u *URLAvailability) MarkUnavailable(url string) { u.unavailable[url] = time.Now().Add(u.ttl) } +// MarkSuccess 标记 URL 请求成功,将其设为优先使用 +func (u *URLAvailability) MarkSuccess(url string) { + u.mu.Lock() + defer u.mu.Unlock() + u.lastSuccess = url + // 成功后清除该 URL 的不可用标记 + delete(u.unavailable, url) +} + // IsAvailable 检查 URL 是否可用 func (u *URLAvailability) IsAvailable(url string) bool { u.mu.RLock() @@ -89,14 +97,29 @@ func (u *URLAvailability) IsAvailable(url string) bool { return time.Now().After(expiry) } -// GetAvailableURLs 返回可用的 URL 列表(保持优先级顺序) +// GetAvailableURLs 返回可用的 URL 列表 +// 最近成功的 URL 优先,其他按默认顺序 func (u *URLAvailability) GetAvailableURLs() []string { u.mu.RLock() defer u.mu.RUnlock() now := time.Now() result := make([]string, 0, len(BaseURLs)) + + // 如果有最近成功的 URL 且可用,放在最前面 + if u.lastSuccess != "" { + expiry, exists := u.unavailable[u.lastSuccess] + if !exists || now.After(expiry) { + result = append(result, u.lastSuccess) + } + } + + // 添加其他可用的 URL(按默认顺序) for _, url := range BaseURLs { + // 跳过已添加的 lastSuccess + if url == u.lastSuccess { + continue + } expiry, exists := u.unavailable[url] if !exists || now.After(expiry) { result = append(result, url) @@ -240,24 +263,3 @@ func BuildAuthorizationURL(state, codeChallenge string) string { return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) } - -// GenerateMockProjectID 生成随机 project_id(当 API 不返回时使用) -// 格式:{形容词}-{名词}-{5位随机字符} -func GenerateMockProjectID() string { - adjectives := []string{"useful", "bright", "swift", "calm", "bold"} - nouns := []string{"fuze", "wave", "spark", "flow", "core"} - - randBytes, _ := GenerateRandomBytes(7) - - adj := adjectives[int(randBytes[0])%len(adjectives)] - noun 
:= nouns[int(randBytes[1])%len(nouns)] - - // 生成 5 位随机字符(a-z0-9) - const charset = "abcdefghijklmnopqrstuvwxyz0123456789" - suffix := make([]byte, 5) - for i := 0; i < 5; i++ { - suffix[i] = charset[int(randBytes[i+2])%len(charset)] - } - - return fmt.Sprintf("%s-%s-%s", adj, noun, string(suffix)) -} diff --git a/backend/internal/pkg/antigravity/request_transformer.go b/backend/internal/pkg/antigravity/request_transformer.go index a8474576..637a4ea8 100644 --- a/backend/internal/pkg/antigravity/request_transformer.go +++ b/backend/internal/pkg/antigravity/request_transformer.go @@ -54,6 +54,9 @@ func DefaultTransformOptions() TransformOptions { } } +// webSearchFallbackModel web_search 请求使用的降级模型 +const webSearchFallbackModel = "gemini-2.5-flash" + // TransformClaudeToGemini 将 Claude 请求转换为 v1internal Gemini 格式 func TransformClaudeToGemini(claudeReq *ClaudeRequest, projectID, mappedModel string) ([]byte, error) { return TransformClaudeToGeminiWithOptions(claudeReq, projectID, mappedModel, DefaultTransformOptions()) @@ -64,12 +67,23 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map // 用于存储 tool_use id -> name 映射 toolIDToName := make(map[string]string) + // 检测是否有 web_search 工具 + hasWebSearchTool := hasWebSearchTool(claudeReq.Tools) + requestType := "agent" + targetModel := mappedModel + if hasWebSearchTool { + requestType = "web_search" + if targetModel != webSearchFallbackModel { + targetModel = webSearchFallbackModel + } + } + // 检测是否启用 thinking isThinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled" // 只有 Gemini 模型支持 dummy thought workaround // Claude 模型通过 Vertex/Google API 需要有效的 thought signatures - allowDummyThought := strings.HasPrefix(mappedModel, "gemini-") + allowDummyThought := strings.HasPrefix(targetModel, "gemini-") // 1. 构建 contents contents, strippedThinking, err := buildContents(claudeReq.Messages, toolIDToName, isThinkingEnabled, allowDummyThought) @@ -78,7 +92,7 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map } // 2. 构建 systemInstruction - systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts) + systemInstruction := buildSystemInstruction(claudeReq.System, claudeReq.Model, opts, claudeReq.Tools) // 3. 构建 generationConfig reqForConfig := claudeReq @@ -89,6 +103,11 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map reqCopy.Thinking = nil reqForConfig = &reqCopy } + if targetModel != "" && targetModel != reqForConfig.Model { + reqCopy := *reqForConfig + reqCopy.Model = targetModel + reqForConfig = &reqCopy + } generationConfig := buildGenerationConfig(reqForConfig) // 4. 
构建 tools @@ -127,8 +146,8 @@ func TransformClaudeToGeminiWithOptions(claudeReq *ClaudeRequest, projectID, map Project: projectID, RequestID: "agent-" + uuid.New().String(), UserAgent: "antigravity", // 固定值,与官方客户端一致 - RequestType: "agent", - Model: mappedModel, + RequestType: requestType, + Model: targetModel, Request: innerRequest, } @@ -154,8 +173,40 @@ func GetDefaultIdentityPatch() string { return antigravityIdentity } -// buildSystemInstruction 构建 systemInstruction -func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions) *GeminiContent { +// mcpXMLProtocol MCP XML 工具调用协议(与 Antigravity-Manager 保持一致) +const mcpXMLProtocol = ` +==== MCP XML 工具调用协议 (Workaround) ==== +当你需要调用名称以 ` + "`mcp__`" + ` 开头的 MCP 工具时: +1) 优先尝试 XML 格式调用:输出 ` + "`{\"arg\":\"value\"}`" + `。 +2) 必须直接输出 XML 块,无需 markdown 包装,内容为 JSON 格式的入参。 +3) 这种方式具有更高的连通性和容错性,适用于大型结果返回场景。 +===========================================` + +// hasMCPTools 检测是否有 mcp__ 前缀的工具 +func hasMCPTools(tools []ClaudeTool) bool { + for _, tool := range tools { + if strings.HasPrefix(tool.Name, "mcp__") { + return true + } + } + return false +} + +// filterOpenCodePrompt 过滤 OpenCode 默认提示词,只保留用户自定义指令 +func filterOpenCodePrompt(text string) string { + if !strings.Contains(text, "You are an interactive CLI tool") { + return text + } + // 提取 "Instructions from:" 及之后的部分 + if idx := strings.Index(text, "Instructions from:"); idx >= 0 { + return text[idx:] + } + // 如果没有自定义指令,返回空 + return "" +} + +// buildSystemInstruction 构建 systemInstruction(与 Antigravity-Manager 保持一致) +func buildSystemInstruction(system json.RawMessage, modelName string, opts TransformOptions, tools []ClaudeTool) *GeminiContent { var parts []GeminiPart // 先解析用户的 system prompt,检测是否已包含 Antigravity identity @@ -167,10 +218,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans var sysStr string if err := json.Unmarshal(system, &sysStr); err == nil { if strings.TrimSpace(sysStr) != "" { - userSystemParts = append(userSystemParts, GeminiPart{Text: sysStr}) if strings.Contains(sysStr, "You are Antigravity") { userHasAntigravityIdentity = true } + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(sysStr) + if filtered != "" { + userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) + } } } else { // 尝试解析为数组 @@ -178,10 +233,14 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans if err := json.Unmarshal(system, &sysBlocks); err == nil { for _, block := range sysBlocks { if block.Type == "text" && strings.TrimSpace(block.Text) != "" { - userSystemParts = append(userSystemParts, GeminiPart{Text: block.Text}) if strings.Contains(block.Text, "You are Antigravity") { userHasAntigravityIdentity = true } + // 过滤 OpenCode 默认提示词 + filtered := filterOpenCodePrompt(block.Text) + if filtered != "" { + userSystemParts = append(userSystemParts, GeminiPart{Text: filtered}) + } } } } @@ -200,6 +259,16 @@ func buildSystemInstruction(system json.RawMessage, modelName string, opts Trans // 添加用户的 system prompt parts = append(parts, userSystemParts...) 
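To make the OpenCode filtering just introduced concrete, here is filterOpenCodePrompt lifted into a standalone runnable snippet; the example inputs are invented:

    package main

    import (
        "fmt"
        "strings"
    )

    // Verbatim copy of filterOpenCodePrompt from the patch above.
    func filterOpenCodePrompt(text string) string {
        if !strings.Contains(text, "You are an interactive CLI tool") {
            return text
        }
        if idx := strings.Index(text, "Instructions from:"); idx >= 0 {
            return text[idx:]
        }
        return ""
    }

    func main() {
        // Non-OpenCode prompts pass through untouched.
        fmt.Printf("%q\n", filterOpenCodePrompt("You are Antigravity."))
        // OpenCode boilerplate is dropped; user instructions are kept.
        fmt.Printf("%q\n", filterOpenCodePrompt("You are an interactive CLI tool ... Instructions from: project config"))
        // Boilerplate with no custom section collapses to the empty string.
        fmt.Printf("%q\n", filterOpenCodePrompt("You are an interactive CLI tool ..."))
    }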
+ // 检测是否有 MCP 工具,如有则注入 XML 调用协议 + if hasMCPTools(tools) { + parts = append(parts, GeminiPart{Text: mcpXMLProtocol}) + } + + // 如果用户没有提供 Antigravity 身份,添加结束标记 + if !userHasAntigravityIdentity { + parts = append(parts, GeminiPart{Text: "\n--- [SYSTEM_PROMPT_END] ---"}) + } + if len(parts) == 0 { return nil } @@ -429,6 +498,11 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { StopSequences: DefaultStopSequences, } + // 如果请求中指定了 MaxTokens,使用请求值 + if req.MaxTokens > 0 { + config.MaxOutputTokens = req.MaxTokens + } + // Thinking 配置 if req.Thinking != nil && req.Thinking.Type == "enabled" { config.ThinkingConfig = &GeminiThinkingConfig{ @@ -458,37 +532,43 @@ func buildGenerationConfig(req *ClaudeRequest) *GeminiGenerationConfig { return config } +func hasWebSearchTool(tools []ClaudeTool) bool { + for _, tool := range tools { + if isWebSearchTool(tool) { + return true + } + } + return false +} + +func isWebSearchTool(tool ClaudeTool) bool { + if strings.HasPrefix(tool.Type, "web_search") || tool.Type == "google_search" { + return true + } + + name := strings.TrimSpace(tool.Name) + switch name { + case "web_search", "google_search", "web_search_20250305": + return true + default: + return false + } +} + // buildTools 构建 tools func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { if len(tools) == 0 { return nil } - // 检查是否有 web_search 工具 - hasWebSearch := false - for _, tool := range tools { - if tool.Name == "web_search" { - hasWebSearch = true - break - } - } - - if hasWebSearch { - // Web Search 工具映射 - return []GeminiToolDeclaration{{ - GoogleSearch: &GeminiGoogleSearch{ - EnhancedContent: &GeminiEnhancedContent{ - ImageSearch: &GeminiImageSearch{ - MaxResultCount: 5, - }, - }, - }, - }} - } + hasWebSearch := hasWebSearchTool(tools) // 普通工具 var funcDecls []GeminiFunctionDecl for _, tool := range tools { + if isWebSearchTool(tool) { + continue + } // 跳过无效工具名称 if strings.TrimSpace(tool.Name) == "" { log.Printf("Warning: skipping tool with empty name") @@ -531,7 +611,20 @@ func buildTools(tools []ClaudeTool) []GeminiToolDeclaration { } if len(funcDecls) == 0 { - return nil + if !hasWebSearch { + return nil + } + + // Web Search 工具映射 + return []GeminiToolDeclaration{{ + GoogleSearch: &GeminiGoogleSearch{ + EnhancedContent: &GeminiEnhancedContent{ + ImageSearch: &GeminiImageSearch{ + MaxResultCount: 5, + }, + }, + }, + }} } return []GeminiToolDeclaration{{ diff --git a/backend/internal/pkg/antigravity/response_transformer.go b/backend/internal/pkg/antigravity/response_transformer.go index cd7f5f80..04424c03 100644 --- a/backend/internal/pkg/antigravity/response_transformer.go +++ b/backend/internal/pkg/antigravity/response_transformer.go @@ -3,6 +3,7 @@ package antigravity import ( "encoding/json" "fmt" + "strings" ) // TransformGeminiToClaude 将 Gemini 响应转换为 Claude 格式(非流式) @@ -63,6 +64,12 @@ func (p *NonStreamingProcessor) Process(geminiResp *GeminiResponse, responseID, p.processPart(&part) } + if len(geminiResp.Candidates) > 0 { + if grounding := geminiResp.Candidates[0].GroundingMetadata; grounding != nil { + p.processGrounding(grounding) + } + } + // 刷新剩余内容 p.flushThinking() p.flushText() @@ -190,6 +197,18 @@ func (p *NonStreamingProcessor) processPart(part *GeminiPart) { } } +func (p *NonStreamingProcessor) processGrounding(grounding *GeminiGroundingMetadata) { + groundingText := buildGroundingText(grounding) + if groundingText == "" { + return + } + + p.flushThinking() + p.flushText() + p.textBuilder += groundingText + p.flushText() +} + // flushText 刷新 text 
builder func (p *NonStreamingProcessor) flushText() { if p.textBuilder == "" { @@ -262,6 +281,44 @@ func (p *NonStreamingProcessor) buildResponse(geminiResp *GeminiResponse, respon } } +func buildGroundingText(grounding *GeminiGroundingMetadata) string { + if grounding == nil { + return "" + } + + var builder strings.Builder + + if len(grounding.WebSearchQueries) > 0 { + _, _ = builder.WriteString("\n\n---\nWeb search queries: ") + _, _ = builder.WriteString(strings.Join(grounding.WebSearchQueries, ", ")) + } + + if len(grounding.GroundingChunks) > 0 { + var links []string + for i, chunk := range grounding.GroundingChunks { + if chunk.Web == nil { + continue + } + title := strings.TrimSpace(chunk.Web.Title) + if title == "" { + title = "Source" + } + uri := strings.TrimSpace(chunk.Web.URI) + if uri == "" { + uri = "#" + } + links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, title, uri)) + } + + if len(links) > 0 { + _, _ = builder.WriteString("\n\nSources:\n") + _, _ = builder.WriteString(strings.Join(links, "\n")) + } + } + + return builder.String() +} + // generateRandomID 生成随机 ID func generateRandomID() string { const chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" diff --git a/backend/internal/pkg/antigravity/stream_transformer.go b/backend/internal/pkg/antigravity/stream_transformer.go index 9fe68a11..da0c6f97 100644 --- a/backend/internal/pkg/antigravity/stream_transformer.go +++ b/backend/internal/pkg/antigravity/stream_transformer.go @@ -27,6 +27,8 @@ type StreamingProcessor struct { pendingSignature string trailingSignature string originalModel string + webSearchQueries []string + groundingChunks []GeminiGroundingChunk // 累计 usage inputTokens int @@ -93,6 +95,10 @@ func (p *StreamingProcessor) ProcessLine(line string) []byte { } } + if len(geminiResp.Candidates) > 0 { + p.captureGrounding(geminiResp.Candidates[0].GroundingMetadata) + } + // 检查是否结束 if len(geminiResp.Candidates) > 0 { finishReason := geminiResp.Candidates[0].FinishReason @@ -200,6 +206,20 @@ func (p *StreamingProcessor) processPart(part *GeminiPart) []byte { return result.Bytes() } +func (p *StreamingProcessor) captureGrounding(grounding *GeminiGroundingMetadata) { + if grounding == nil { + return + } + + if len(grounding.WebSearchQueries) > 0 && len(p.webSearchQueries) == 0 { + p.webSearchQueries = append([]string(nil), grounding.WebSearchQueries...) + } + + if len(grounding.GroundingChunks) > 0 && len(p.groundingChunks) == 0 { + p.groundingChunks = append([]GeminiGroundingChunk(nil), grounding.GroundingChunks...) 
+ } +} + // processThinking 处理 thinking func (p *StreamingProcessor) processThinking(text, signature string) []byte { var result bytes.Buffer @@ -417,6 +437,23 @@ func (p *StreamingProcessor) emitFinish(finishReason string) []byte { p.trailingSignature = "" } + if len(p.webSearchQueries) > 0 || len(p.groundingChunks) > 0 { + groundingText := buildGroundingText(&GeminiGroundingMetadata{ + WebSearchQueries: p.webSearchQueries, + GroundingChunks: p.groundingChunks, + }) + if groundingText != "" { + _, _ = result.Write(p.startBlock(BlockTypeText, map[string]any{ + "type": "text", + "text": "", + })) + _, _ = result.Write(p.emitDelta("text_delta", map[string]any{ + "text": groundingText, + })) + _, _ = result.Write(p.endBlock()) + } + } + // 确定 stop_reason stopReason := "end_turn" if p.usedTool { diff --git a/backend/internal/pkg/gemini/models.go b/backend/internal/pkg/gemini/models.go index e251c8d8..424e8ddb 100644 --- a/backend/internal/pkg/gemini/models.go +++ b/backend/internal/pkg/gemini/models.go @@ -16,14 +16,11 @@ type ModelsListResponse struct { func DefaultModels() []Model { methods := []string{"generateContent", "streamGenerateContent"} return []Model{ - {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, - {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, - {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, - {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, {Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods}, - {Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods}, - {Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods}, - {Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods}, + {Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods}, + {Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods}, } } diff --git a/backend/internal/pkg/geminicli/models.go b/backend/internal/pkg/geminicli/models.go index 922988c7..08e69886 100644 --- a/backend/internal/pkg/geminicli/models.go +++ b/backend/internal/pkg/geminicli/models.go @@ -12,10 +12,10 @@ type Model struct { // DefaultModels is the curated Gemini model list used by the admin UI "test account" flow. var DefaultModels = []Model{ {ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""}, - {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""}, - {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, + {ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""}, {ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""}, + {ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""}, } // DefaultTestModel is the default model to preselect in test flows. 
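Looking back at the grounding support added in response_transformer.go and stream_transformer.go above: both paths funnel web-search metadata through buildGroundingText, which renders it as a markdown tail on the final text block. A condensed runnable sketch of that rendering, with the Gemini types reduced to the fields actually used (the real function also substitutes fallbacks for empty titles and URIs):

    package main

    import (
        "fmt"
        "strings"
    )

    // Reduced stand-ins for GeminiGroundingWeb / GeminiGroundingChunk.
    type web struct{ Title, URI string }
    type chunk struct{ Web *web }

    // Mirrors the output shape of buildGroundingText above.
    func groundingText(queries []string, chunks []chunk) string {
        var b strings.Builder
        if len(queries) > 0 {
            b.WriteString("\n\n---\nWeb search queries: ")
            b.WriteString(strings.Join(queries, ", "))
        }
        var links []string
        for i, c := range chunks {
            if c.Web == nil {
                continue
            }
            links = append(links, fmt.Sprintf("[%d] [%s](%s)", i+1, c.Web.Title, c.Web.URI))
        }
        if len(links) > 0 {
            b.WriteString("\n\nSources:\n")
            b.WriteString(strings.Join(links, "\n"))
        }
        return b.String()
    }

    func main() {
        fmt.Println(groundingText(
            []string{"golang utls fingerprint"},
            []chunk{{Web: &web{Title: "Example", URI: "https://example.com"}}},
        ))
    }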
diff --git a/backend/internal/pkg/oauth/oauth.go b/backend/internal/pkg/oauth/oauth.go index d29c2422..0a607dfb 100644 --- a/backend/internal/pkg/oauth/oauth.go +++ b/backend/internal/pkg/oauth/oauth.go @@ -13,20 +13,26 @@ import ( "time" ) -// Claude OAuth Constants (from CRS project) +// Claude OAuth Constants const ( // OAuth Client ID for Claude ClientID = "9d1c250a-e61b-44d9-88ed-5944d1962f5e" // OAuth endpoints AuthorizeURL = "https://claude.ai/oauth/authorize" - TokenURL = "https://console.anthropic.com/v1/oauth/token" - RedirectURI = "https://console.anthropic.com/oauth/code/callback" + TokenURL = "https://platform.claude.com/v1/oauth/token" + RedirectURI = "https://platform.claude.com/oauth/code/callback" - // Scopes - ScopeProfile = "user:profile" + // Scopes - Browser URL (includes org:create_api_key for user authorization) + ScopeOAuth = "org:create_api_key user:profile user:inference user:sessions:claude_code" + // Scopes - Internal API call (org:create_api_key not supported in API) + ScopeAPI = "user:profile user:inference user:sessions:claude_code" + // Scopes - Setup token (inference only) ScopeInference = "user:inference" + // Code Verifier character set (RFC 7636 compliant) + codeVerifierCharset = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-._~" + // Session TTL SessionTTL = 30 * time.Minute ) @@ -53,7 +59,6 @@ func NewSessionStore() *SessionStore { sessions: make(map[string]*OAuthSession), stopCh: make(chan struct{}), } - // Start cleanup goroutine go store.cleanup() return store } @@ -78,7 +83,6 @@ func (s *SessionStore) Get(sessionID string) (*OAuthSession, bool) { if !ok { return nil, false } - // Check if expired if time.Since(session.CreatedAt) > SessionTTL { return nil, false } @@ -122,13 +126,13 @@ func GenerateRandomBytes(n int) ([]byte, error) { return b, nil } -// GenerateState generates a random state string for OAuth +// GenerateState generates a random state string for OAuth (base64url encoded) func GenerateState() (string, error) { bytes, err := GenerateRandomBytes(32) if err != nil { return "", err } - return hex.EncodeToString(bytes), nil + return base64URLEncode(bytes), nil } // GenerateSessionID generates a unique session ID @@ -140,13 +144,30 @@ func GenerateSessionID() (string, error) { return hex.EncodeToString(bytes), nil } -// GenerateCodeVerifier generates a PKCE code verifier (32 bytes -> base64url) +// GenerateCodeVerifier generates a PKCE code verifier using character set method func GenerateCodeVerifier() (string, error) { - bytes, err := GenerateRandomBytes(32) - if err != nil { - return "", err + const targetLen = 32 + charsetLen := len(codeVerifierCharset) + limit := 256 - (256 % charsetLen) + + result := make([]byte, 0, targetLen) + randBuf := make([]byte, targetLen*2) + + for len(result) < targetLen { + if _, err := rand.Read(randBuf); err != nil { + return "", err + } + for _, b := range randBuf { + if int(b) < limit { + result = append(result, codeVerifierCharset[int(b)%charsetLen]) + if len(result) >= targetLen { + break + } + } + } } - return base64URLEncode(bytes), nil + + return base64URLEncode(result), nil } // GenerateCodeChallenge generates a PKCE code challenge using S256 method @@ -158,42 +179,31 @@ func GenerateCodeChallenge(verifier string) string { // base64URLEncode encodes bytes to base64url without padding func base64URLEncode(data []byte) string { encoded := base64.URLEncoding.EncodeToString(data) - // Remove padding return strings.TrimRight(encoded, "=") } -// BuildAuthorizationURL builds the OAuth 
authorization URL +// BuildAuthorizationURL builds the OAuth authorization URL with correct parameter order func BuildAuthorizationURL(state, codeChallenge, scope string) string { - params := url.Values{} - params.Set("response_type", "code") - params.Set("client_id", ClientID) - params.Set("redirect_uri", RedirectURI) - params.Set("scope", scope) - params.Set("state", state) - params.Set("code_challenge", codeChallenge) - params.Set("code_challenge_method", "S256") + encodedRedirectURI := url.QueryEscape(RedirectURI) + encodedScope := strings.ReplaceAll(url.QueryEscape(scope), "%20", "+") - return fmt.Sprintf("%s?%s", AuthorizeURL, params.Encode()) -} - -// TokenRequest represents the token exchange request body -type TokenRequest struct { - GrantType string `json:"grant_type"` - ClientID string `json:"client_id"` - Code string `json:"code"` - RedirectURI string `json:"redirect_uri"` - CodeVerifier string `json:"code_verifier"` - State string `json:"state"` + return fmt.Sprintf("%s?code=true&client_id=%s&response_type=code&redirect_uri=%s&scope=%s&code_challenge=%s&code_challenge_method=S256&state=%s", + AuthorizeURL, + ClientID, + encodedRedirectURI, + encodedScope, + codeChallenge, + state, + ) } // TokenResponse represents the token response from OAuth provider type TokenResponse struct { - AccessToken string `json:"access_token"` - TokenType string `json:"token_type"` - ExpiresIn int64 `json:"expires_in"` - RefreshToken string `json:"refresh_token,omitempty"` - Scope string `json:"scope,omitempty"` - // Organization and Account info from OAuth response + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int64 `json:"expires_in"` + RefreshToken string `json:"refresh_token,omitempty"` + Scope string `json:"scope,omitempty"` Organization *OrgInfo `json:"organization,omitempty"` Account *AccountInfo `json:"account,omitempty"` } @@ -207,31 +217,3 @@ type OrgInfo struct { type AccountInfo struct { UUID string `json:"uuid"` } - -// RefreshTokenRequest represents the refresh token request -type RefreshTokenRequest struct { - GrantType string `json:"grant_type"` - RefreshToken string `json:"refresh_token"` - ClientID string `json:"client_id"` -} - -// BuildTokenRequest creates a token exchange request -func BuildTokenRequest(code, codeVerifier, state string) *TokenRequest { - return &TokenRequest{ - GrantType: "authorization_code", - ClientID: ClientID, - Code: code, - RedirectURI: RedirectURI, - CodeVerifier: codeVerifier, - State: state, - } -} - -// BuildRefreshTokenRequest creates a refresh token request -func BuildRefreshTokenRequest(refreshToken string) *RefreshTokenRequest { - return &RefreshTokenRequest{ - GrantType: "refresh_token", - RefreshToken: refreshToken, - ClientID: ClientID, - } -} diff --git a/backend/internal/pkg/response/response.go b/backend/internal/pkg/response/response.go index a92ff9e8..43fe12d4 100644 --- a/backend/internal/pkg/response/response.go +++ b/backend/internal/pkg/response/response.go @@ -162,11 +162,11 @@ func ParsePagination(c *gin.Context) (page, pageSize int) { // 支持 page_size 和 limit 两种参数名 if ps := c.Query("page_size"); ps != "" { - if val, err := parseInt(ps); err == nil && val > 0 && val <= 100 { + if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 { pageSize = val } } else if l := c.Query("limit"); l != "" { - if val, err := parseInt(l); err == nil && val > 0 && val <= 100 { + if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 { pageSize = val } } diff --git 
a/backend/internal/pkg/tlsfingerprint/dialer.go b/backend/internal/pkg/tlsfingerprint/dialer.go new file mode 100644 index 00000000..42510986 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/dialer.go @@ -0,0 +1,568 @@ +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +// It uses the utls library to create TLS connections that mimic Node.js/Claude Code clients. +package tlsfingerprint + +import ( + "bufio" + "context" + "encoding/base64" + "fmt" + "log/slog" + "net" + "net/http" + "net/url" + + utls "github.com/refraction-networking/utls" + "golang.org/x/net/proxy" +) + +// Profile contains TLS fingerprint configuration. +type Profile struct { + Name string // Profile name for identification + CipherSuites []uint16 + Curves []uint16 + PointFormats []uint8 + EnableGREASE bool +} + +// Dialer creates TLS connections with custom fingerprints. +type Dialer struct { + profile *Profile + baseDialer func(ctx context.Context, network, addr string) (net.Conn, error) +} + +// HTTPProxyDialer creates TLS connections through HTTP/HTTPS proxies with custom fingerprints. +// It handles the CONNECT tunnel establishment before performing TLS handshake. +type HTTPProxyDialer struct { + profile *Profile + proxyURL *url.URL +} + +// SOCKS5ProxyDialer creates TLS connections through SOCKS5 proxies with custom fingerprints. +// It uses golang.org/x/net/proxy to establish the SOCKS5 tunnel. +type SOCKS5ProxyDialer struct { + profile *Profile + proxyURL *url.URL +} + +// Default TLS fingerprint values captured from Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x) +// Captured using: tshark -i lo -f "tcp port 8443" -Y "tls.handshake.type == 1" -V +// JA3 Hash: 1a28e69016765d92e3b381168d68922c +// +// Note: JA3/JA4 may have slight variations due to: +// - Session ticket presence/absence +// - Extension negotiation state +var ( + // defaultCipherSuites contains all 59 cipher suites from Claude CLI + // Order is critical for JA3 fingerprint matching + defaultCipherSuites = []uint16{ + // TLS 1.3 cipher suites (MUST be first) + 0x1302, // TLS_AES_256_GCM_SHA384 + 0x1303, // TLS_CHACHA20_POLY1305_SHA256 + 0x1301, // TLS_AES_128_GCM_SHA256 + + // ECDHE + AES-GCM + 0xc02f, // TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + 0xc02b, // TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + 0xc030, // TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + 0xc02c, // TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + + // DHE + AES-GCM + 0x009e, // TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 + + // ECDHE/DHE + AES-CBC-SHA256/384 + 0xc027, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 + 0x0067, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 + 0xc028, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 + 0x006b, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 + + // DHE-DSS/RSA + AES-GCM + 0x00a3, // TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 + 0x009f, // TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 + + // ChaCha20-Poly1305 + 0xcca9, // TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + 0xcca8, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + 0xccaa, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + + // AES-CCM (256-bit) + 0xc0af, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM_8 + 0xc0ad, // TLS_ECDHE_ECDSA_WITH_AES_256_CCM + 0xc0a3, // TLS_DHE_RSA_WITH_AES_256_CCM_8 + 0xc09f, // TLS_DHE_RSA_WITH_AES_256_CCM + + // ARIA (256-bit) + 0xc05d, // TLS_ECDHE_ECDSA_WITH_ARIA_256_GCM_SHA384 + 0xc061, // TLS_ECDHE_RSA_WITH_ARIA_256_GCM_SHA384 + 0xc057, // TLS_DHE_DSS_WITH_ARIA_256_GCM_SHA384 + 0xc053, // TLS_DHE_RSA_WITH_ARIA_256_GCM_SHA384 + + // DHE-DSS + AES-GCM (128-bit) + 0x00a2, // 
TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 + + // AES-CCM (128-bit) + 0xc0ae, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM_8 + 0xc0ac, // TLS_ECDHE_ECDSA_WITH_AES_128_CCM + 0xc0a2, // TLS_DHE_RSA_WITH_AES_128_CCM_8 + 0xc09e, // TLS_DHE_RSA_WITH_AES_128_CCM + + // ARIA (128-bit) + 0xc05c, // TLS_ECDHE_ECDSA_WITH_ARIA_128_GCM_SHA256 + 0xc060, // TLS_ECDHE_RSA_WITH_ARIA_128_GCM_SHA256 + 0xc056, // TLS_DHE_DSS_WITH_ARIA_128_GCM_SHA256 + 0xc052, // TLS_DHE_RSA_WITH_ARIA_128_GCM_SHA256 + + // ECDHE/DHE + AES-CBC-SHA384/256 (more) + 0xc024, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 + 0x006a, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 + 0xc023, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 + 0x0040, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 + + // ECDHE/DHE + AES-CBC-SHA (legacy) + 0xc00a, // TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + 0xc014, // TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + 0x0039, // TLS_DHE_RSA_WITH_AES_256_CBC_SHA + 0x0038, // TLS_DHE_DSS_WITH_AES_256_CBC_SHA + 0xc009, // TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA + 0xc013, // TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA + 0x0033, // TLS_DHE_RSA_WITH_AES_128_CBC_SHA + 0x0032, // TLS_DHE_DSS_WITH_AES_128_CBC_SHA + + // RSA + AES-GCM/CCM/ARIA (non-PFS, 256-bit) + 0x009d, // TLS_RSA_WITH_AES_256_GCM_SHA384 + 0xc0a1, // TLS_RSA_WITH_AES_256_CCM_8 + 0xc09d, // TLS_RSA_WITH_AES_256_CCM + 0xc051, // TLS_RSA_WITH_ARIA_256_GCM_SHA384 + + // RSA + AES-GCM/CCM/ARIA (non-PFS, 128-bit) + 0x009c, // TLS_RSA_WITH_AES_128_GCM_SHA256 + 0xc0a0, // TLS_RSA_WITH_AES_128_CCM_8 + 0xc09c, // TLS_RSA_WITH_AES_128_CCM + 0xc050, // TLS_RSA_WITH_ARIA_128_GCM_SHA256 + + // RSA + AES-CBC (non-PFS, legacy) + 0x003d, // TLS_RSA_WITH_AES_256_CBC_SHA256 + 0x003c, // TLS_RSA_WITH_AES_128_CBC_SHA256 + 0x0035, // TLS_RSA_WITH_AES_256_CBC_SHA + 0x002f, // TLS_RSA_WITH_AES_128_CBC_SHA + + // Renegotiation indication + 0x00ff, // TLS_EMPTY_RENEGOTIATION_INFO_SCSV + } + + // defaultCurves contains the 10 supported groups from Claude CLI (including FFDHE) + defaultCurves = []utls.CurveID{ + utls.X25519, // 0x001d + utls.CurveP256, // 0x0017 (secp256r1) + utls.CurveID(0x001e), // x448 + utls.CurveP521, // 0x0019 (secp521r1) + utls.CurveP384, // 0x0018 (secp384r1) + utls.CurveID(0x0100), // ffdhe2048 + utls.CurveID(0x0101), // ffdhe3072 + utls.CurveID(0x0102), // ffdhe4096 + utls.CurveID(0x0103), // ffdhe6144 + utls.CurveID(0x0104), // ffdhe8192 + } + + // defaultPointFormats contains all 3 point formats from Claude CLI + defaultPointFormats = []uint8{ + 0, // uncompressed + 1, // ansiX962_compressed_prime + 2, // ansiX962_compressed_char2 + } + + // defaultSignatureAlgorithms contains the 20 signature algorithms from Claude CLI + defaultSignatureAlgorithms = []utls.SignatureScheme{ + 0x0403, // ecdsa_secp256r1_sha256 + 0x0503, // ecdsa_secp384r1_sha384 + 0x0603, // ecdsa_secp521r1_sha512 + 0x0807, // ed25519 + 0x0808, // ed448 + 0x0809, // rsa_pss_pss_sha256 + 0x080a, // rsa_pss_pss_sha384 + 0x080b, // rsa_pss_pss_sha512 + 0x0804, // rsa_pss_rsae_sha256 + 0x0805, // rsa_pss_rsae_sha384 + 0x0806, // rsa_pss_rsae_sha512 + 0x0401, // rsa_pkcs1_sha256 + 0x0501, // rsa_pkcs1_sha384 + 0x0601, // rsa_pkcs1_sha512 + 0x0303, // ecdsa_sha224 + 0x0301, // rsa_pkcs1_sha224 + 0x0302, // dsa_sha224 + 0x0402, // dsa_sha256 + 0x0502, // dsa_sha384 + 0x0602, // dsa_sha512 + } +) + +// NewDialer creates a new TLS fingerprint dialer. +// baseDialer is used for TCP connection establishment (supports proxy scenarios). +// If baseDialer is nil, direct TCP dial is used. 
+func NewDialer(profile *Profile, baseDialer func(ctx context.Context, network, addr string) (net.Conn, error)) *Dialer { + if baseDialer == nil { + baseDialer = (&net.Dialer{}).DialContext + } + return &Dialer{profile: profile, baseDialer: baseDialer} +} + +// NewHTTPProxyDialer creates a new TLS fingerprint dialer that works through HTTP/HTTPS proxies. +// It establishes a CONNECT tunnel before performing TLS handshake with custom fingerprint. +func NewHTTPProxyDialer(profile *Profile, proxyURL *url.URL) *HTTPProxyDialer { + return &HTTPProxyDialer{profile: profile, proxyURL: proxyURL} +} + +// NewSOCKS5ProxyDialer creates a new TLS fingerprint dialer that works through SOCKS5 proxies. +// It establishes a SOCKS5 tunnel before performing TLS handshake with custom fingerprint. +func NewSOCKS5ProxyDialer(profile *Profile, proxyURL *url.URL) *SOCKS5ProxyDialer { + return &SOCKS5ProxyDialer{profile: profile, proxyURL: proxyURL} +} + +// DialTLSContext establishes a TLS connection through SOCKS5 proxy with the configured fingerprint. +// Flow: SOCKS5 CONNECT to target -> TLS handshake with utls on the tunnel +func (d *SOCKS5ProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + slog.Debug("tls_fingerprint_socks5_connecting", "proxy", d.proxyURL.Host, "target", addr) + + // Step 1: Create SOCKS5 dialer + var auth *proxy.Auth + if d.proxyURL.User != nil { + username := d.proxyURL.User.Username() + password, _ := d.proxyURL.User.Password() + auth = &proxy.Auth{ + User: username, + Password: password, + } + } + + // Determine proxy address + proxyAddr := d.proxyURL.Host + if d.proxyURL.Port() == "" { + proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "1080") // Default SOCKS5 port + } + + socksDialer, err := proxy.SOCKS5("tcp", proxyAddr, auth, proxy.Direct) + if err != nil { + slog.Debug("tls_fingerprint_socks5_dialer_failed", "error", err) + return nil, fmt.Errorf("create SOCKS5 dialer: %w", err) + } + + // Step 2: Establish SOCKS5 tunnel to target + slog.Debug("tls_fingerprint_socks5_establishing_tunnel", "target", addr) + conn, err := socksDialer.Dial("tcp", addr) + if err != nil { + slog.Debug("tls_fingerprint_socks5_connect_failed", "error", err) + return nil, fmt.Errorf("SOCKS5 connect: %w", err) + } + slog.Debug("tls_fingerprint_socks5_tunnel_established") + + // Step 3: Perform TLS handshake on the tunnel with utls fingerprint + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + slog.Debug("tls_fingerprint_socks5_starting_handshake", "host", host) + + // Build ClientHello specification from profile (Node.js/Claude CLI fingerprint) + spec := buildClientHelloSpecFromProfile(d.profile) + slog.Debug("tls_fingerprint_socks5_clienthello_spec", + "cipher_suites", len(spec.CipherSuites), + "extensions", len(spec.Extensions), + "compression_methods", spec.CompressionMethods, + "tls_vers_max", fmt.Sprintf("0x%04x", spec.TLSVersMax), + "tls_vers_min", fmt.Sprintf("0x%04x", spec.TLSVersMin)) + + if d.profile != nil { + slog.Debug("tls_fingerprint_socks5_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) + } + + // Create uTLS connection on the tunnel + tlsConn := utls.UClient(conn, &utls.Config{ + ServerName: host, + }, utls.HelloCustom) + + if err := tlsConn.ApplyPreset(spec); err != nil { + slog.Debug("tls_fingerprint_socks5_apply_preset_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("apply TLS preset: %w", err) + } + + if err := tlsConn.Handshake(); err != nil { + 
slog.Debug("tls_fingerprint_socks5_handshake_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + state := tlsConn.ConnectionState() + slog.Debug("tls_fingerprint_socks5_handshake_success", + "version", fmt.Sprintf("0x%04x", state.Version), + "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "alpn", state.NegotiatedProtocol) + + return tlsConn, nil +} + +// DialTLSContext establishes a TLS connection through HTTP proxy with the configured fingerprint. +// Flow: TCP connect to proxy -> CONNECT tunnel -> TLS handshake with utls +func (d *HTTPProxyDialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + slog.Debug("tls_fingerprint_http_proxy_connecting", "proxy", d.proxyURL.Host, "target", addr) + + // Step 1: TCP connect to proxy server + var proxyAddr string + if d.proxyURL.Port() != "" { + proxyAddr = d.proxyURL.Host + } else { + // Default ports + if d.proxyURL.Scheme == "https" { + proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "443") + } else { + proxyAddr = net.JoinHostPort(d.proxyURL.Hostname(), "80") + } + } + + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, "tcp", proxyAddr) + if err != nil { + slog.Debug("tls_fingerprint_http_proxy_connect_failed", "error", err) + return nil, fmt.Errorf("connect to proxy: %w", err) + } + slog.Debug("tls_fingerprint_http_proxy_connected", "proxy_addr", proxyAddr) + + // Step 2: Send CONNECT request to establish tunnel + req := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: addr}, + Host: addr, + Header: make(http.Header), + } + + // Add proxy authentication if present + if d.proxyURL.User != nil { + username := d.proxyURL.User.Username() + password, _ := d.proxyURL.User.Password() + auth := base64.StdEncoding.EncodeToString([]byte(username + ":" + password)) + req.Header.Set("Proxy-Authorization", "Basic "+auth) + } + + slog.Debug("tls_fingerprint_http_proxy_sending_connect", "target", addr) + if err := req.Write(conn); err != nil { + _ = conn.Close() + slog.Debug("tls_fingerprint_http_proxy_write_failed", "error", err) + return nil, fmt.Errorf("write CONNECT request: %w", err) + } + + // Step 3: Read CONNECT response + br := bufio.NewReader(conn) + resp, err := http.ReadResponse(br, req) + if err != nil { + _ = conn.Close() + slog.Debug("tls_fingerprint_http_proxy_read_response_failed", "error", err) + return nil, fmt.Errorf("read CONNECT response: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + _ = conn.Close() + slog.Debug("tls_fingerprint_http_proxy_connect_failed_status", "status_code", resp.StatusCode, "status", resp.Status) + return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status) + } + slog.Debug("tls_fingerprint_http_proxy_tunnel_established") + + // Step 4: Perform TLS handshake on the tunnel with utls fingerprint + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + slog.Debug("tls_fingerprint_http_proxy_starting_handshake", "host", host) + + // Build ClientHello specification (reuse the shared method) + spec := buildClientHelloSpecFromProfile(d.profile) + slog.Debug("tls_fingerprint_http_proxy_clienthello_spec", + "cipher_suites", len(spec.CipherSuites), + "extensions", len(spec.Extensions)) + + if d.profile != nil { + slog.Debug("tls_fingerprint_http_proxy_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) + } + + // Create uTLS connection on the tunnel + // Note: TLS 1.3 cipher suites are handled 
automatically by utls when TLS 1.3 is in SupportedVersions + tlsConn := utls.UClient(conn, &utls.Config{ + ServerName: host, + }, utls.HelloCustom) + + if err := tlsConn.ApplyPreset(spec); err != nil { + slog.Debug("tls_fingerprint_http_proxy_apply_preset_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("apply TLS preset: %w", err) + } + + if err := tlsConn.HandshakeContext(ctx); err != nil { + slog.Debug("tls_fingerprint_http_proxy_handshake_failed", "error", err) + _ = conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + state := tlsConn.ConnectionState() + slog.Debug("tls_fingerprint_http_proxy_handshake_success", + "version", fmt.Sprintf("0x%04x", state.Version), + "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "alpn", state.NegotiatedProtocol) + + return tlsConn, nil +} + +// DialTLSContext establishes a TLS connection with the configured fingerprint. +// This method is designed to be used as http.Transport.DialTLSContext. +func (d *Dialer) DialTLSContext(ctx context.Context, network, addr string) (net.Conn, error) { + // Establish TCP connection using base dialer (supports proxy) + slog.Debug("tls_fingerprint_dialing_tcp", "addr", addr) + conn, err := d.baseDialer(ctx, network, addr) + if err != nil { + slog.Debug("tls_fingerprint_tcp_dial_failed", "error", err) + return nil, err + } + slog.Debug("tls_fingerprint_tcp_connected", "addr", addr) + + // Extract hostname for SNI + host, _, err := net.SplitHostPort(addr) + if err != nil { + host = addr + } + slog.Debug("tls_fingerprint_sni_hostname", "host", host) + + // Build ClientHello specification + spec := d.buildClientHelloSpec() + slog.Debug("tls_fingerprint_clienthello_spec", + "cipher_suites", len(spec.CipherSuites), + "extensions", len(spec.Extensions)) + + // Log profile info + if d.profile != nil { + slog.Debug("tls_fingerprint_using_profile", "name", d.profile.Name, "grease", d.profile.EnableGREASE) + } else { + slog.Debug("tls_fingerprint_using_default_profile") + } + + // Create uTLS connection + // Note: TLS 1.3 cipher suites are handled automatically by utls when TLS 1.3 is in SupportedVersions + tlsConn := utls.UClient(conn, &utls.Config{ + ServerName: host, + }, utls.HelloCustom) + + // Apply fingerprint + if err := tlsConn.ApplyPreset(spec); err != nil { + slog.Debug("tls_fingerprint_apply_preset_failed", "error", err) + _ = conn.Close() + return nil, err + } + slog.Debug("tls_fingerprint_preset_applied") + + // Perform TLS handshake + if err := tlsConn.HandshakeContext(ctx); err != nil { + slog.Debug("tls_fingerprint_handshake_failed", + "error", err, + "local_addr", conn.LocalAddr(), + "remote_addr", conn.RemoteAddr()) + _ = conn.Close() + return nil, fmt.Errorf("TLS handshake failed: %w", err) + } + + // Log successful handshake details + state := tlsConn.ConnectionState() + slog.Debug("tls_fingerprint_handshake_success", + "version", fmt.Sprintf("0x%04x", state.Version), + "cipher_suite", fmt.Sprintf("0x%04x", state.CipherSuite), + "alpn", state.NegotiatedProtocol) + + return tlsConn, nil +} + +// buildClientHelloSpec constructs the ClientHello specification based on the profile. +func (d *Dialer) buildClientHelloSpec() *utls.ClientHelloSpec { + return buildClientHelloSpecFromProfile(d.profile) +} + +// toUTLSCurves converts uint16 slice to utls.CurveID slice. 
+func toUTLSCurves(curves []uint16) []utls.CurveID { + result := make([]utls.CurveID, len(curves)) + for i, c := range curves { + result[i] = utls.CurveID(c) + } + return result +} + +// buildClientHelloSpecFromProfile constructs ClientHelloSpec from a Profile. +// This is a standalone function that can be used by both Dialer and HTTPProxyDialer. +func buildClientHelloSpecFromProfile(profile *Profile) *utls.ClientHelloSpec { + // Get cipher suites + var cipherSuites []uint16 + if profile != nil && len(profile.CipherSuites) > 0 { + cipherSuites = profile.CipherSuites + } else { + cipherSuites = defaultCipherSuites + } + + // Get curves + var curves []utls.CurveID + if profile != nil && len(profile.Curves) > 0 { + curves = toUTLSCurves(profile.Curves) + } else { + curves = defaultCurves + } + + // Get point formats + var pointFormats []uint8 + if profile != nil && len(profile.PointFormats) > 0 { + pointFormats = profile.PointFormats + } else { + pointFormats = defaultPointFormats + } + + // Check if GREASE is enabled + enableGREASE := profile != nil && profile.EnableGREASE + + extensions := make([]utls.TLSExtension, 0, 16) + + if enableGREASE { + extensions = append(extensions, &utls.UtlsGREASEExtension{}) + } + + // SNI extension - MUST be explicitly added for HelloCustom mode + // utls will populate the server name from Config.ServerName + extensions = append(extensions, &utls.SNIExtension{}) + + // Claude CLI extension order (captured from tshark): + // server_name(0), ec_point_formats(11), supported_groups(10), session_ticket(35), + // alpn(16), encrypt_then_mac(22), extended_master_secret(23), + // signature_algorithms(13), supported_versions(43), + // psk_key_exchange_modes(45), key_share(51) + extensions = append(extensions, + &utls.SupportedPointsExtension{SupportedPoints: pointFormats}, + &utls.SupportedCurvesExtension{Curves: curves}, + &utls.SessionTicketExtension{}, + &utls.ALPNExtension{AlpnProtocols: []string{"http/1.1"}}, + &utls.GenericExtension{Id: 22}, + &utls.ExtendedMasterSecretExtension{}, + &utls.SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: defaultSignatureAlgorithms}, + &utls.SupportedVersionsExtension{Versions: []uint16{ + utls.VersionTLS13, + utls.VersionTLS12, + }}, + &utls.PSKKeyExchangeModesExtension{Modes: []uint8{utls.PskModeDHE}}, + &utls.KeyShareExtension{KeyShares: []utls.KeyShare{ + {Group: utls.X25519}, + }}, + ) + + if enableGREASE { + extensions = append(extensions, &utls.UtlsGREASEExtension{}) + } + + return &utls.ClientHelloSpec{ + CipherSuites: cipherSuites, + CompressionMethods: []uint8{0}, // null compression only (standard) + Extensions: extensions, + TLSVersMax: utls.VersionTLS13, + TLSVersMin: utls.VersionTLS10, + } +} diff --git a/backend/internal/pkg/tlsfingerprint/dialer_test.go b/backend/internal/pkg/tlsfingerprint/dialer_test.go new file mode 100644 index 00000000..2aed1287 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/dialer_test.go @@ -0,0 +1,307 @@ +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +// +// Integration tests for verifying TLS fingerprint correctness. +// These tests make actual network requests and should be run manually. +// +// Run with: go test -v ./internal/pkg/tlsfingerprint/... +// Run integration tests: go test -v -run TestJA3 ./internal/pkg/tlsfingerprint/... 
+package tlsfingerprint + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/url" + "strings" + "testing" + "time" +) + +// FingerprintResponse represents the response from tls.peet.ws/api/all. +type FingerprintResponse struct { + IP string `json:"ip"` + TLS TLSInfo `json:"tls"` + HTTP2 any `json:"http2"` +} + +// TLSInfo contains TLS fingerprint details. +type TLSInfo struct { + JA3 string `json:"ja3"` + JA3Hash string `json:"ja3_hash"` + JA4 string `json:"ja4"` + PeetPrint string `json:"peetprint"` + PeetPrintHash string `json:"peetprint_hash"` + ClientRandom string `json:"client_random"` + SessionID string `json:"session_id"` +} + +// TestDialerBasicConnection tests that the dialer can establish TLS connections. +func TestDialerBasicConnection(t *testing.T) { + if testing.Short() { + t.Skip("skipping network test in short mode") + } + + // Create a dialer with default profile + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + // Create HTTP client with custom TLS dialer + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Make a request to a known HTTPS endpoint + resp, err := client.Get("https://www.google.com") + if err != nil { + t.Fatalf("failed to connect: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected status 200, got %d", resp.StatusCode) + } +} + +// TestJA3Fingerprint verifies the JA3/JA4 fingerprint matches expected value. +// This test uses tls.peet.ws to verify the fingerprint. +// Expected JA3 hash: 1a28e69016765d92e3b381168d68922c (Claude CLI / Node.js 20.x) +// Expected JA4: t13d5911h1_a33745022dd6_1f22a2ca17c4 (d=domain) or t13i5911h1_... 
(i=IP) +func TestJA3Fingerprint(t *testing.T) { + // Skip if network is unavailable or if running in short mode + if testing.Short() { + t.Skip("skipping integration test in short mode") + } + + profile := &Profile{ + Name: "Claude CLI Test", + EnableGREASE: false, + } + dialer := NewDialer(profile, nil) + + client := &http.Client{ + Transport: &http.Transport{ + DialTLSContext: dialer.DialTLSContext, + }, + Timeout: 30 * time.Second, + } + + // Use tls.peet.ws fingerprint detection API + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://tls.peet.ws/api/all", nil) + if err != nil { + t.Fatalf("failed to create request: %v", err) + } + req.Header.Set("User-Agent", "Claude Code/2.0.0 Node.js/20.0.0") + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("failed to get fingerprint: %v", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("failed to read response: %v", err) + } + + var fpResp FingerprintResponse + if err := json.Unmarshal(body, &fpResp); err != nil { + t.Logf("Response body: %s", string(body)) + t.Fatalf("failed to parse fingerprint response: %v", err) + } + + // Log all fingerprint information + t.Logf("JA3: %s", fpResp.TLS.JA3) + t.Logf("JA3 Hash: %s", fpResp.TLS.JA3Hash) + t.Logf("JA4: %s", fpResp.TLS.JA4) + t.Logf("PeetPrint: %s", fpResp.TLS.PeetPrint) + t.Logf("PeetPrint Hash: %s", fpResp.TLS.PeetPrintHash) + + // Verify JA3 hash matches expected value + expectedJA3Hash := "1a28e69016765d92e3b381168d68922c" + if fpResp.TLS.JA3Hash == expectedJA3Hash { + t.Logf("✓ JA3 hash matches expected value: %s", expectedJA3Hash) + } else { + t.Errorf("✗ JA3 hash mismatch: got %s, expected %s", fpResp.TLS.JA3Hash, expectedJA3Hash) + } + + // Verify JA4 fingerprint + // JA4 format: t[version][sni][cipher_count][ext_count][alpn]_[cipher_hash]_[ext_hash] + // Expected: t13d5911h1 (d=domain) or t13i5911h1 (i=IP) + // The suffix _a33745022dd6_1f22a2ca17c4 should match + expectedJA4Suffix := "_a33745022dd6_1f22a2ca17c4" + if strings.HasSuffix(fpResp.TLS.JA4, expectedJA4Suffix) { + t.Logf("✓ JA4 suffix matches expected value: %s", expectedJA4Suffix) + } else { + t.Errorf("✗ JA4 suffix mismatch: got %s, expected suffix %s", fpResp.TLS.JA4, expectedJA4Suffix) + } + + // Verify JA4 prefix (t13d5911h1 or t13i5911h1) + // d = domain (SNI present), i = IP (no SNI) + // Since we connect to tls.peet.ws (domain), we expect 'd' + expectedJA4Prefix := "t13d5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, expectedJA4Prefix) { + t.Logf("✓ JA4 prefix matches: %s (t13=TLS1.3, d=domain, 59=ciphers, 11=extensions, h1=HTTP/1.1)", expectedJA4Prefix) + } else { + // Also accept 'i' variant for IP connections + altPrefix := "t13i5911h1" + if strings.HasPrefix(fpResp.TLS.JA4, altPrefix) { + t.Logf("✓ JA4 prefix matches (IP variant): %s", altPrefix) + } else { + t.Errorf("✗ JA4 prefix mismatch: got %s, expected %s or %s", fpResp.TLS.JA4, expectedJA4Prefix, altPrefix) + } + } + + // Verify JA3 contains expected cipher suites (TLS 1.3 ciphers at the beginning) + if strings.Contains(fpResp.TLS.JA3, "4866-4867-4865") { + t.Logf("✓ JA3 contains expected TLS 1.3 cipher suites") + } else { + t.Logf("Warning: JA3 does not contain expected TLS 1.3 cipher suites") + } + + // Verify extension list (should be 11 extensions including SNI) + // Expected: 0-11-10-35-16-22-23-13-43-45-51 + expectedExtensions := "0-11-10-35-16-22-23-13-43-45-51" + if
strings.Contains(fpResp.TLS.JA3, expectedExtensions) { + t.Logf("✓ JA3 contains expected extension list: %s", expectedExtensions) + } else { + t.Logf("Warning: JA3 extension list may differ") + } +} + +// TestDialerWithProfile tests that different profiles produce different fingerprints. +func TestDialerWithProfile(t *testing.T) { + // Create two dialers with different profiles + profile1 := &Profile{ + Name: "Profile 1 - No GREASE", + EnableGREASE: false, + } + profile2 := &Profile{ + Name: "Profile 2 - With GREASE", + EnableGREASE: true, + } + + dialer1 := NewDialer(profile1, nil) + dialer2 := NewDialer(profile2, nil) + + // Build specs and compare + // Note: We can't directly compare JA3 without making network requests + // but we can verify the specs are different + spec1 := dialer1.buildClientHelloSpec() + spec2 := dialer2.buildClientHelloSpec() + + // Profile with GREASE should have more extensions + if len(spec2.Extensions) <= len(spec1.Extensions) { + t.Error("expected GREASE profile to have more extensions") + } +} + +// TestHTTPProxyDialerBasic tests HTTP proxy dialer creation. +// Note: This is a unit test - actual proxy testing requires a proxy server. +func TestHTTPProxyDialerBasic(t *testing.T) { + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + + // Test that dialer is created without panic + proxyURL := mustParseURL("http://proxy.example.com:8080") + dialer := NewHTTPProxyDialer(profile, proxyURL) + + if dialer == nil { + t.Fatal("expected dialer to be created") + } + if dialer.profile != profile { + t.Error("expected profile to be set") + } + if dialer.proxyURL != proxyURL { + t.Error("expected proxyURL to be set") + } +} + +// TestSOCKS5ProxyDialerBasic tests SOCKS5 proxy dialer creation. +// Note: This is a unit test - actual proxy testing requires a proxy server. +func TestSOCKS5ProxyDialerBasic(t *testing.T) { + profile := &Profile{ + Name: "Test Profile", + EnableGREASE: false, + } + + // Test that dialer is created without panic + proxyURL := mustParseURL("socks5://proxy.example.com:1080") + dialer := NewSOCKS5ProxyDialer(profile, proxyURL) + + if dialer == nil { + t.Fatal("expected dialer to be created") + } + if dialer.profile != profile { + t.Error("expected profile to be set") + } + if dialer.proxyURL != proxyURL { + t.Error("expected proxyURL to be set") + } +} + +// TestBuildClientHelloSpec tests ClientHello spec construction. +func TestBuildClientHelloSpec(t *testing.T) { + // Test with nil profile (should use defaults) + spec := buildClientHelloSpecFromProfile(nil) + + if len(spec.CipherSuites) == 0 { + t.Error("expected cipher suites to be set") + } + if len(spec.Extensions) == 0 { + t.Error("expected extensions to be set") + } + + // Verify default cipher suites are used + if len(spec.CipherSuites) != len(defaultCipherSuites) { + t.Errorf("expected %d cipher suites, got %d", len(defaultCipherSuites), len(spec.CipherSuites)) + } + + // Test with custom profile + customProfile := &Profile{ + Name: "Custom", + EnableGREASE: false, + CipherSuites: []uint16{0x1301, 0x1302}, + } + spec = buildClientHelloSpecFromProfile(customProfile) + + if len(spec.CipherSuites) != 2 { + t.Errorf("expected 2 cipher suites, got %d", len(spec.CipherSuites)) + } +} + +// TestToUTLSCurves tests curve ID conversion. 
+func TestToUTLSCurves(t *testing.T) { + input := []uint16{0x001d, 0x0017, 0x0018} + result := toUTLSCurves(input) + + if len(result) != len(input) { + t.Errorf("expected %d curves, got %d", len(input), len(result)) + } + + for i, curve := range result { + if uint16(curve) != input[i] { + t.Errorf("curve %d: expected 0x%04x, got 0x%04x", i, input[i], uint16(curve)) + } + } +} + +// Helper function to parse URL without error handling. +func mustParseURL(rawURL string) *url.URL { + u, err := url.Parse(rawURL) + if err != nil { + panic(err) + } + return u +} diff --git a/backend/internal/pkg/tlsfingerprint/registry.go b/backend/internal/pkg/tlsfingerprint/registry.go new file mode 100644 index 00000000..6e9dc539 --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/registry.go @@ -0,0 +1,171 @@ +// Package tlsfingerprint provides TLS fingerprint simulation for HTTP clients. +package tlsfingerprint + +import ( + "log/slog" + "sort" + "sync" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +// DefaultProfileName is the name of the built-in Claude CLI profile. +const DefaultProfileName = "claude_cli_v2" + +// Registry manages TLS fingerprint profiles. +// It holds a collection of profiles that can be used for TLS fingerprint simulation. +// Profiles are selected based on account ID using modulo operation. +type Registry struct { + mu sync.RWMutex + profiles map[string]*Profile + profileNames []string // Sorted list of profile names for deterministic selection +} + +// NewRegistry creates a new TLS fingerprint profile registry. +// It initializes with the built-in default profile. +func NewRegistry() *Registry { + r := &Registry{ + profiles: make(map[string]*Profile), + profileNames: make([]string, 0), + } + + // Register the built-in default profile + r.registerBuiltinProfile() + + return r +} + +// NewRegistryFromConfig creates a new registry and loads profiles from config. +// If the config has custom profiles defined, they will be merged with the built-in default. +func NewRegistryFromConfig(cfg *config.TLSFingerprintConfig) *Registry { + r := NewRegistry() + + if cfg == nil || !cfg.Enabled { + slog.Debug("tls_registry_disabled", "reason", "disabled or no config") + return r + } + + // Load custom profiles from config + for name, profileCfg := range cfg.Profiles { + profile := &Profile{ + Name: profileCfg.Name, + EnableGREASE: profileCfg.EnableGREASE, + CipherSuites: profileCfg.CipherSuites, + Curves: profileCfg.Curves, + PointFormats: profileCfg.PointFormats, + } + + // If the profile has empty values, they will use defaults in dialer + r.RegisterProfile(name, profile) + slog.Debug("tls_registry_loaded_profile", "key", name, "name", profileCfg.Name) + } + + slog.Debug("tls_registry_initialized", "profile_count", len(r.profileNames), "profiles", r.profileNames) + return r +} + +// registerBuiltinProfile adds the default Claude CLI profile to the registry. +func (r *Registry) registerBuiltinProfile() { + defaultProfile := &Profile{ + Name: "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)", + EnableGREASE: false, // Node.js does not use GREASE + // Empty slices will cause dialer to use built-in defaults + CipherSuites: nil, + Curves: nil, + PointFormats: nil, + } + r.RegisterProfile(DefaultProfileName, defaultProfile) +} + +// RegisterProfile adds or updates a profile in the registry. 
+func (r *Registry) RegisterProfile(name string, profile *Profile) { + r.mu.Lock() + defer r.mu.Unlock() + + // Check if this is a new profile + _, exists := r.profiles[name] + r.profiles[name] = profile + + if !exists { + r.profileNames = append(r.profileNames, name) + // Keep names sorted for deterministic selection + sort.Strings(r.profileNames) + } +} + +// GetProfile returns a profile by name. +// Returns nil if the profile does not exist. +func (r *Registry) GetProfile(name string) *Profile { + r.mu.RLock() + defer r.mu.RUnlock() + return r.profiles[name] +} + +// GetDefaultProfile returns the built-in default profile. +func (r *Registry) GetDefaultProfile() *Profile { + return r.GetProfile(DefaultProfileName) +} + +// GetProfileByAccountID returns a profile for the given account ID. +// The profile is selected using: profileNames[accountID % len(profiles)] +// This ensures deterministic profile assignment for each account. +func (r *Registry) GetProfileByAccountID(accountID int64) *Profile { + r.mu.RLock() + defer r.mu.RUnlock() + + if len(r.profileNames) == 0 { + return nil + } + + // Use modulo to select profile index + // Use absolute value to handle negative IDs (though unlikely) + idx := accountID + if idx < 0 { + idx = -idx + } + selectedIndex := int(idx % int64(len(r.profileNames))) + selectedName := r.profileNames[selectedIndex] + + return r.profiles[selectedName] +} + +// ProfileCount returns the number of registered profiles. +func (r *Registry) ProfileCount() int { + r.mu.RLock() + defer r.mu.RUnlock() + return len(r.profiles) +} + +// ProfileNames returns a sorted list of all registered profile names. +func (r *Registry) ProfileNames() []string { + r.mu.RLock() + defer r.mu.RUnlock() + + // Return a copy to prevent modification + names := make([]string, len(r.profileNames)) + copy(names, r.profileNames) + return names +} + +// Global registry instance for convenience +var globalRegistry *Registry +var globalRegistryOnce sync.Once + +// GlobalRegistry returns the global TLS fingerprint registry. +// The registry is lazily initialized with the default profile. +func GlobalRegistry() *Registry { + globalRegistryOnce.Do(func() { + globalRegistry = NewRegistry() + }) + return globalRegistry +} + +// InitGlobalRegistry initializes the global registry with configuration. +// This should be called during application startup. +// It is safe to call multiple times, but initialization happens only once: after the first call (or an earlier GlobalRegistry call) the existing registry is returned and any new configuration is ignored.
+func InitGlobalRegistry(cfg *config.TLSFingerprintConfig) *Registry { + globalRegistryOnce.Do(func() { + globalRegistry = NewRegistryFromConfig(cfg) + }) + return globalRegistry +} diff --git a/backend/internal/pkg/tlsfingerprint/registry_test.go b/backend/internal/pkg/tlsfingerprint/registry_test.go new file mode 100644 index 00000000..752ba0cc --- /dev/null +++ b/backend/internal/pkg/tlsfingerprint/registry_test.go @@ -0,0 +1,243 @@ +package tlsfingerprint + +import ( + "testing" + + "github.com/Wei-Shaw/sub2api/internal/config" +) + +func TestNewRegistry(t *testing.T) { + r := NewRegistry() + + // Should have exactly one profile (the default) + if r.ProfileCount() != 1 { + t.Errorf("expected 1 profile, got %d", r.ProfileCount()) + } + + // Should have the default profile + profile := r.GetDefaultProfile() + if profile == nil { + t.Error("expected default profile to exist") + } + + // Default profile name should be in the list + names := r.ProfileNames() + if len(names) != 1 || names[0] != DefaultProfileName { + t.Errorf("expected profile names to be [%s], got %v", DefaultProfileName, names) + } +} + +func TestRegisterProfile(t *testing.T) { + r := NewRegistry() + + // Register a new profile + customProfile := &Profile{ + Name: "Custom Profile", + EnableGREASE: true, + } + r.RegisterProfile("custom", customProfile) + + // Should now have 2 profiles + if r.ProfileCount() != 2 { + t.Errorf("expected 2 profiles, got %d", r.ProfileCount()) + } + + // Should be able to retrieve the custom profile + retrieved := r.GetProfile("custom") + if retrieved == nil { + t.Fatal("expected custom profile to exist") + } + if retrieved.Name != "Custom Profile" { + t.Errorf("expected profile name 'Custom Profile', got '%s'", retrieved.Name) + } + if !retrieved.EnableGREASE { + t.Error("expected EnableGREASE to be true") + } +} + +func TestGetProfile(t *testing.T) { + r := NewRegistry() + + // Get existing profile + profile := r.GetProfile(DefaultProfileName) + if profile == nil { + t.Error("expected default profile to exist") + } + + // Get non-existing profile + nonExistent := r.GetProfile("nonexistent") + if nonExistent != nil { + t.Error("expected nil for non-existent profile") + } +} + +func TestGetProfileByAccountID(t *testing.T) { + r := NewRegistry() + + // With only default profile, all account IDs should return the same profile + for i := int64(0); i < 10; i++ { + profile := r.GetProfileByAccountID(i) + if profile == nil { + t.Errorf("expected profile for account %d, got nil", i) + } + } + + // Add more profiles + r.RegisterProfile("profile_a", &Profile{Name: "Profile A"}) + r.RegisterProfile("profile_b", &Profile{Name: "Profile B"}) + + // Now we have 3 profiles: claude_cli_v2, profile_a, profile_b + // Names are sorted, so order is: claude_cli_v2, profile_a, profile_b + expectedOrder := []string{DefaultProfileName, "profile_a", "profile_b"} + names := r.ProfileNames() + for i, name := range expectedOrder { + if names[i] != name { + t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i]) + } + } + + // Test modulo selection + // Account ID 0 % 3 = 0 -> claude_cli_v2 + // Account ID 1 % 3 = 1 -> profile_a + // Account ID 2 % 3 = 2 -> profile_b + // Account ID 3 % 3 = 0 -> claude_cli_v2 + testCases := []struct { + accountID int64 + expectedName string + }{ + {0, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, + {1, "Profile A"}, + {2, "Profile B"}, + {3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, + {4, "Profile A"}, + {5, "Profile B"}, + {100, "Profile A"}, // 100 % 3 = 1 + 
{-1, "Profile A"}, // |-1| % 3 = 1 + {-3, "Claude CLI 2.x (Node.js 20.x + OpenSSL 3.x)"}, // |-3| % 3 = 0 + } + + for _, tc := range testCases { + profile := r.GetProfileByAccountID(tc.accountID) + if profile == nil { + t.Errorf("expected profile for account %d, got nil", tc.accountID) + continue + } + if profile.Name != tc.expectedName { + t.Errorf("account %d: expected profile name '%s', got '%s'", tc.accountID, tc.expectedName, profile.Name) + } + } +} + +func TestNewRegistryFromConfig(t *testing.T) { + // Test with nil config + r := NewRegistryFromConfig(nil) + if r.ProfileCount() != 1 { + t.Errorf("expected 1 profile with nil config, got %d", r.ProfileCount()) + } + + // Test with disabled config + disabledCfg := &config.TLSFingerprintConfig{ + Enabled: false, + } + r = NewRegistryFromConfig(disabledCfg) + if r.ProfileCount() != 1 { + t.Errorf("expected 1 profile with disabled config, got %d", r.ProfileCount()) + } + + // Test with enabled config and custom profiles + enabledCfg := &config.TLSFingerprintConfig{ + Enabled: true, + Profiles: map[string]config.TLSProfileConfig{ + "custom1": { + Name: "Custom Profile 1", + EnableGREASE: true, + }, + "custom2": { + Name: "Custom Profile 2", + EnableGREASE: false, + }, + }, + } + r = NewRegistryFromConfig(enabledCfg) + + // Should have 3 profiles: default + 2 custom + if r.ProfileCount() != 3 { + t.Errorf("expected 3 profiles, got %d", r.ProfileCount()) + } + + // Check custom profiles exist + custom1 := r.GetProfile("custom1") + if custom1 == nil || custom1.Name != "Custom Profile 1" { + t.Error("expected custom1 profile to exist with correct name") + } + custom2 := r.GetProfile("custom2") + if custom2 == nil || custom2.Name != "Custom Profile 2" { + t.Error("expected custom2 profile to exist with correct name") + } +} + +func TestProfileNames(t *testing.T) { + r := NewRegistry() + + // Add profiles in non-alphabetical order + r.RegisterProfile("zebra", &Profile{Name: "Zebra"}) + r.RegisterProfile("alpha", &Profile{Name: "Alpha"}) + r.RegisterProfile("beta", &Profile{Name: "Beta"}) + + names := r.ProfileNames() + + // Should be sorted alphabetically + expected := []string{"alpha", "beta", DefaultProfileName, "zebra"} + if len(names) != len(expected) { + t.Errorf("expected %d names, got %d", len(expected), len(names)) + } + for i, name := range expected { + if names[i] != name { + t.Errorf("expected name at index %d to be %s, got %s", i, name, names[i]) + } + } + + // Test that returned slice is a copy (modifying it shouldn't affect registry) + names[0] = "modified" + originalNames := r.ProfileNames() + if originalNames[0] == "modified" { + t.Error("modifying returned slice should not affect registry") + } +} + +func TestConcurrentAccess(t *testing.T) { + r := NewRegistry() + + // Run concurrent reads and writes + done := make(chan bool) + + // Writers + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + r.RegisterProfile("concurrent"+string(rune('0'+id)), &Profile{Name: "Concurrent"}) + } + done <- true + }(i) + } + + // Readers + for i := 0; i < 10; i++ { + go func(id int) { + for j := 0; j < 100; j++ { + _ = r.ProfileCount() + _ = r.ProfileNames() + _ = r.GetProfileByAccountID(int64(id * j)) + _ = r.GetProfile(DefaultProfileName) + } + done <- true + }(i) + } + + // Wait for all goroutines + for i := 0; i < 20; i++ { + <-done + } + + // Test should pass without data races (run with -race flag) +} diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index 
diff --git a/backend/internal/repository/account_repo.go b/backend/internal/repository/account_repo.go index f7725820..c2673ad3 100644 --- a/backend/internal/repository/account_repo.go +++ b/backend/internal/repository/account_repo.go @@ -543,6 +543,15 @@ func (r *accountRepository) SetError(ctx context.Context, id int64, errorMsg str return nil } +func (r *accountRepository) ClearError(ctx context.Context, id int64) error { + _, err := r.client.Account.Update(). + Where(dbaccount.IDEQ(id)). + SetStatus(service.StatusActive). + SetErrorMessage(""). + Save(ctx) + return err +} + func (r *accountRepository) AddToGroup(ctx context.Context, accountID, groupID int64, priority int) error { + _, err := r.client.AccountGroup.Create(). SetAccountID(accountID). @@ -960,7 +969,16 @@ func (r *accountRepository) UpdateSessionWindow(ctx context.Context, id int64, s builder.SetSessionWindowEnd(*end) } _, err := builder.Save(ctx) - return err + if err != nil { + return err + } + // Trigger a scheduler cache update (only when the window times actually changed) + if start != nil || end != nil { + if err := enqueueSchedulerOutbox(ctx, r.sql, service.SchedulerOutboxEventAccountChanged, &id, nil, nil); err != nil { + log.Printf("[SchedulerOutbox] enqueue session window update failed: account=%d err=%v", id, err) + } + } + return nil } func (r *accountRepository) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { diff --git a/backend/internal/repository/api_key_cache.go b/backend/internal/repository/api_key_cache.go index 6d834b40..a1072057 100644 --- a/backend/internal/repository/api_key_cache.go +++ b/backend/internal/repository/api_key_cache.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "log" "time" "github.com/Wei-Shaw/sub2api/internal/service" ) @@ -12,9 +13,10 @@ const ( - apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" - apiKeyRateLimitDuration = 24 * time.Hour - apiKeyAuthCachePrefix = "apikey:auth:" + apiKeyRateLimitKeyPrefix = "apikey:ratelimit:" + apiKeyRateLimitDuration = 24 * time.Hour + apiKeyAuthCachePrefix = "apikey:auth:" + authCacheInvalidateChannel = "auth:cache:invalidate" ) // apiKeyRateLimitKey generates the Redis key for API key creation rate limiting.
@@ -91,3 +93,45 @@ func (c *apiKeyCache) SetAuthCache(ctx context.Context, key string, entry *servi func (c *apiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { return c.rdb.Del(ctx, apiKeyAuthCacheKey(key)).Err() } + +// PublishAuthCacheInvalidation publishes a cache invalidation message to all instances +func (c *apiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return c.rdb.Publish(ctx, authCacheInvalidateChannel, cacheKey).Err() +} + +// SubscribeAuthCacheInvalidation subscribes to cache invalidation messages +func (c *apiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + pubsub := c.rdb.Subscribe(ctx, authCacheInvalidateChannel) + + // Verify subscription is working + _, err := pubsub.Receive(ctx) + if err != nil { + _ = pubsub.Close() + return fmt.Errorf("subscribe to auth cache invalidation: %w", err) + } + + go func() { + defer func() { + if err := pubsub.Close(); err != nil { + log.Printf("Warning: failed to close auth cache invalidation pubsub: %v", err) + } + }() + + ch := pubsub.Channel() + for { + select { + case <-ctx.Done(): + return + case msg, ok := <-ch: + if !ok { + return + } + if msg != nil { + handler(msg.Payload) + } + } + } + }() + + return nil +}
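For reference, this subscribe/publish pair is meant to be wired once per instance at startup; a minimal sketch (the interface and the in-process cache are illustrative, only the two method signatures come from this diff):

package example

import (
	"context"
	"sync"
)

// authCacheInvalidator mirrors the two methods added above; the interface
// name itself is illustrative, not from the codebase.
type authCacheInvalidator interface {
	PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
	SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
}

// wireAuthInvalidation subscribes once at startup so this instance evicts
// its in-process copy whenever any instance publishes an invalidation.
func wireAuthInvalidation(ctx context.Context, c authCacheInvalidator, local *sync.Map) error {
	return c.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
		local.Delete(cacheKey)
	})
}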
diff --git a/backend/internal/repository/claude_oauth_service.go b/backend/internal/repository/claude_oauth_service.go index 677fce52..1f1db553 100644 --- a/backend/internal/repository/claude_oauth_service.go +++ b/backend/internal/repository/claude_oauth_service.go @@ -182,7 +182,9 @@ func (s *claudeOAuthService) ExchangeCodeForToken(ctx context.Context, code, cod resp, err := client.R(). SetContext(ctx). + SetHeader("Accept", "application/json, text/plain, */*"). SetHeader("Content-Type", "application/json"). + SetHeader("User-Agent", "axios/1.8.4"). SetBody(reqBody). SetSuccessResult(&tokenResp). Post(s.tokenURL) @@ -205,8 +207,6 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro client := s.clientFactory(proxyURL) - // Use the JSON format (consistent with ExchangeCodeForToken) - // The Anthropic OAuth API expects a JSON request body reqBody := map[string]any{ "grant_type": "refresh_token", "refresh_token": refreshToken, @@ -217,7 +217,9 @@ func (s *claudeOAuthService) RefreshToken(ctx context.Context, refreshToken, pro resp, err := client.R(). SetContext(ctx). + SetHeader("Accept", "application/json, text/plain, */*"). SetHeader("Content-Type", "application/json"). + SetHeader("User-Agent", "axios/1.8.4"). SetBody(reqBody). SetSuccessResult(&tokenResp). Post(s.tokenURL) diff --git a/backend/internal/repository/claude_oauth_service_test.go b/backend/internal/repository/claude_oauth_service_test.go index a7f76056..7395c6d8 100644 --- a/backend/internal/repository/claude_oauth_service_test.go +++ b/backend/internal/repository/claude_oauth_service_test.go @@ -171,7 +171,7 @@ func (s *ClaudeOAuthServiceSuite) TestGetAuthorizationCode() { s.client.baseURL = "http://in-process" s.client.clientFactory = func(string) *req.Client { return newTestReqClient(rt) } - code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeProfile, "cc", "st", "") + code, err := s.client.GetAuthorizationCode(context.Background(), "sess", "org-1", oauth.ScopeInference, "cc", "st", "") if tt.wantErr { require.Error(s.T(), err) diff --git a/backend/internal/repository/claude_usage_service.go b/backend/internal/repository/claude_usage_service.go index 4c87b2de..1198f472 100644 --- a/backend/internal/repository/claude_usage_service.go +++ b/backend/internal/repository/claude_usage_service.go @@ -14,37 +14,82 @@ import ( const defaultClaudeUsageURL = "https://api.anthropic.com/api/oauth/usage" +// Default User-Agent, matching the request captured from the real client +const defaultUsageUserAgent = "claude-code/2.1.7" + type claudeUsageService struct { usageURL string allowPrivateHosts bool + httpUpstream service.HTTPUpstream } -func NewClaudeUsageFetcher() service.ClaudeUsageFetcher { - return &claudeUsageService{usageURL: defaultClaudeUsageURL} +// NewClaudeUsageFetcher creates the Claude usage fetching service. +// httpUpstream: optional; when provided, TLS fingerprint simulation is supported +func NewClaudeUsageFetcher(httpUpstream service.HTTPUpstream) service.ClaudeUsageFetcher { + return &claudeUsageService{ + usageURL: defaultClaudeUsageURL, + httpUpstream: httpUpstream, + } } +// FetchUsage is the simple variant without TLS fingerprinting (kept for backward compatibility) func (s *claudeUsageService) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*service.ClaudeUsageResponse, error) { - client, err := httpclient.GetClient(httpclient.Options{ - ProxyURL: proxyURL, - Timeout: 30 * time.Second, - ValidateResolvedIP: true, - AllowPrivateHosts: s.allowPrivateHosts, + return s.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{ + AccessToken: accessToken, + ProxyURL: proxyURL, }) - if err != nil { - client = &http.Client{Timeout: 30 * time.Second} +}
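A sketch of the options-based call path, assuming FetchUsageWithOptions is exposed on the service.ClaudeUsageFetcher interface (the literal values are placeholders):

package example

import (
	"context"

	"github.com/Wei-Shaw/sub2api/internal/service"
)

// fetchWithFingerprint exercises the option fields this diff reads:
// AccessToken, ProxyURL, AccountID, EnableTLSFingerprint (Fingerprint can
// additionally override the User-Agent).
func fetchWithFingerprint(ctx context.Context, f service.ClaudeUsageFetcher) (*service.ClaudeUsageResponse, error) {
	return f.FetchUsageWithOptions(ctx, &service.ClaudeUsageFetchOptions{
		AccessToken:          "oauth-access-token",
		ProxyURL:             "socks5://127.0.0.1:1080",
		AccountID:            42,   // selects the TLS profile via accountID % profile count
		EnableTLSFingerprint: true, // route through DoWithTLS instead of the plain client
	})
}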
+ +// FetchUsageWithOptions is the full variant, supporting TLS fingerprinting and a custom User-Agent +func (s *claudeUsageService) FetchUsageWithOptions(ctx context.Context, opts *service.ClaudeUsageFetchOptions) (*service.ClaudeUsageResponse, error) { + if opts == nil { + return nil, fmt.Errorf("options is nil") + } + + // Build the request + req, err := http.NewRequestWithContext(ctx, "GET", s.usageURL, nil) + if err != nil { + return nil, fmt.Errorf("create request failed: %w", err) + } - req.Header.Set("Authorization", "Bearer "+accessToken) + // Set request headers (matching the capture; Accept-Encoding is left unset so Go handles compression itself) + req.Header.Set("Accept", "application/json, text/plain, */*") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+opts.AccessToken) req.Header.Set("anthropic-beta", "oauth-2025-04-20") - resp, err := client.Do(req) - if err != nil { - return nil, fmt.Errorf("request failed: %w", err) + // Set the User-Agent (prefer the cached fingerprint, otherwise use the default) + userAgent := defaultUsageUserAgent + if opts.Fingerprint != nil && opts.Fingerprint.UserAgent != "" { + userAgent = opts.Fingerprint.UserAgent + } + req.Header.Set("User-Agent", userAgent) + + var resp *http.Response + + // If TLS fingerprinting is enabled and an HTTPUpstream is available, use DoWithTLS + if opts.EnableTLSFingerprint && s.httpUpstream != nil { + // Pass accountConcurrency=0 for the default pool settings; usage requests need no special concurrency + resp, err = s.httpUpstream.DoWithTLS(req, opts.ProxyURL, opts.AccountID, 0, true) + if err != nil { + return nil, fmt.Errorf("request with TLS fingerprint failed: %w", err) + } + } else { + // TLS fingerprinting disabled: use a plain HTTP client + client, err := httpclient.GetClient(httpclient.Options{ + ProxyURL: opts.ProxyURL, + Timeout: 30 * time.Second, + ValidateResolvedIP: true, + AllowPrivateHosts: s.allowPrivateHosts, + }) + if err != nil { + client = &http.Client{Timeout: 30 * time.Second} + } + + resp, err = client.Do(req) + if err != nil { + return nil, fmt.Errorf("request failed: %w", err) + } } defer func() { _ = resp.Body.Close() }() diff --git a/backend/internal/repository/dashboard_aggregation_repo.go b/backend/internal/repository/dashboard_aggregation_repo.go index 3543e061..59bbd6a3 100644 --- a/backend/internal/repository/dashboard_aggregation_repo.go +++ b/backend/internal/repository/dashboard_aggregation_repo.go @@ -77,6 +77,75 @@ func (r *dashboardAggregationRepository) AggregateRange(ctx context.Context, sta return nil } +func (r *dashboardAggregationRepository) RecomputeRange(ctx context.Context, start, end time.Time) error { + if r == nil || r.sql == nil { + return nil + } + loc := timezone.Location() + startLocal := start.In(loc) + endLocal := end.In(loc) + if !endLocal.After(startLocal) { + return nil + } + + hourStart := startLocal.Truncate(time.Hour) + hourEnd := endLocal.Truncate(time.Hour) + if endLocal.After(hourEnd) { + hourEnd = hourEnd.Add(time.Hour) + } + + dayStart := truncateToDay(startLocal) + dayEnd := truncateToDay(endLocal) + if endLocal.After(dayEnd) { + dayEnd = dayEnd.Add(24 * time.Hour) + } + + // Prefer a transaction for consistency across the range (fall back to non-transactional execution when the driver is not a *sql.DB). + if db, ok := r.sql.(*sql.DB); ok { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return err + } + txRepo := newDashboardAggregationRepositoryWithSQL(tx) + if err := txRepo.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd); err != nil { + _ = tx.Rollback() + return err + } + return tx.Commit() + } + return r.recomputeRangeInTx(ctx, hourStart, hourEnd, dayStart, dayEnd) +} + +func (r *dashboardAggregationRepository) recomputeRangeInTx(ctx context.Context, hourStart, hourEnd, dayStart, dayEnd time.Time) error { + // Clear all buckets in the range first, then rebuild (incremental upserts alone cannot roll back metrics such as active users). + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil { + return err + } + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_hourly_users WHERE bucket_start >= $1 AND bucket_start < $2", hourStart, hourEnd); err != nil { + return err + } + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil { + return err + } + if _, err := r.sql.ExecContext(ctx, "DELETE FROM usage_dashboard_daily_users WHERE bucket_date >= $1::date AND bucket_date < $2::date", dayStart, dayEnd); err != nil { + return err + } + + if err := r.insertHourlyActiveUsers(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.insertDailyActiveUsers(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertHourlyAggregates(ctx, hourStart, hourEnd); err != nil { + return err + } + if err := r.upsertDailyAggregates(ctx, dayStart, dayEnd); err != nil { + return err + } + return nil +} +
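A sketch of how a caller, for example a cleanup job, might rebuild one day's buckets after deleting usage rows (the interface shape is lifted from the method above; the day alignment here is simplified to UTC, the repository re-aligns to its configured timezone):

package example

import (
	"context"
	"time"
)

// dashboardRecomputer mirrors the method added above; the name is illustrative.
type dashboardRecomputer interface {
	RecomputeRange(ctx context.Context, start, end time.Time) error
}

// recomputeDay rebuilds every hourly/daily bucket touching one calendar day,
// e.g. after a cleanup job deleted usage rows inside that window.
func recomputeDay(ctx context.Context, repo dashboardRecomputer, day time.Time) error {
	start := day.Truncate(24 * time.Hour) // simplified: UTC midnight
	return repo.RecomputeRange(ctx, start, start.Add(24*time.Hour))
}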
func (r *dashboardAggregationRepository) GetAggregationWatermark(ctx context.Context) (time.Time, error) { var ts time.Time query := "SELECT last_aggregated_at FROM usage_dashboard_aggregation_watermark WHERE id = 1" diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index 8005f114..d7d574e8 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { // Create the Ent client, bound to the configured database driver. client := ent.NewClient(ent.Driver(drv)) + + // SIMPLE mode: seed the default groups for each platform at startup. + // - anthropic/openai/gemini: ensure a -default group exists + // - antigravity: only require >= 2 non-soft-deleted groups (for mixed claude/gemini scheduling) + if cfg.RunMode == config.RunModeSimple { + seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer seedCancel() + if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil { + _ = client.Close() + return nil, nil, err + } + } + return client, drv.DB(), nil } diff --git a/backend/internal/repository/http_upstream.go b/backend/internal/repository/http_upstream.go index feb32541..b0f15f19 100644 --- a/backend/internal/repository/http_upstream.go +++ b/backend/internal/repository/http_upstream.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "io" + "log/slog" "net" "net/http" "net/url" @@ -14,6 +15,7 @@ import ( "github.com/Wei-Shaw/sub2api/internal/config" "github.com/Wei-Shaw/sub2api/internal/pkg/proxyutil" + "github.com/Wei-Shaw/sub2api/internal/pkg/tlsfingerprint" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/util/urlvalidator" ) @@ -150,6 +152,172 @@ func (s *httpUpstreamService) Do(req *http.Request, proxyURL string, accountID i return resp, nil } +// DoWithTLS executes an HTTP request with TLS fingerprint simulation. +// The enableTLSFingerprint parameter decides whether the fingerprint is applied. +// +// Parameters: +// - req: the HTTP request +// - proxyURL: proxy address; an empty string means a direct connection +// - accountID: account ID, used for per-account isolation and fingerprint profile selection +// - accountConcurrency: account concurrency limit, used to size the connection pool dynamically +// - enableTLSFingerprint: whether to enable TLS fingerprint simulation +// +// TLS fingerprint notes: +// - When enableTLSFingerprint=true, the utls library mimics the Claude CLI TLS fingerprint +// - The fingerprint profile is selected automatically via accountID % len(profiles) +// - Direct connections, HTTP/HTTPS proxies, and SOCKS5 proxies are all supported +func (s *httpUpstreamService) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + // TLS fingerprinting disabled: use the standard request path + if !enableTLSFingerprint { + return s.Do(req, proxyURL, accountID, accountConcurrency) + } + + // TLS fingerprinting enabled; emit debug logs + targetHost := "" + if req != nil && req.URL != nil { + targetHost = req.URL.Host + } + proxyInfo := "direct" + if proxyURL != "" { + proxyInfo = proxyURL + } + slog.Debug("tls_fingerprint_enabled", "account_id", accountID, "target", targetHost, "proxy", proxyInfo) + + if err := s.validateRequestHost(req); err != nil { + return nil, err + } + + // Look up the TLS fingerprint profile + registry := tlsfingerprint.GlobalRegistry() + profile := registry.GetProfileByAccountID(accountID) + if profile == nil { + // No profile available: fall back to a standard request + slog.Debug("tls_fingerprint_no_profile", "account_id", accountID, "fallback", "standard_request") + return s.Do(req, proxyURL, accountID, accountConcurrency) + } + + slog.Debug("tls_fingerprint_using_profile", "account_id", accountID, "profile", profile.Name, "grease", profile.EnableGREASE) + + // Get or create a client with the TLS fingerprint applied + entry, err := s.acquireClientWithTLS(proxyURL, accountID, accountConcurrency, profile) + if err != nil { + slog.Debug("tls_fingerprint_acquire_client_failed", "account_id", accountID, "error", err) + return nil, err + } + + // Execute the request + resp, err := entry.client.Do(req) + if err != nil { + // Request failed: decrement the in-flight count immediately +
atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + slog.Debug("tls_fingerprint_request_failed", "account_id", accountID, "error", err) + return nil, err + } + + slog.Debug("tls_fingerprint_request_success", "account_id", accountID, "status", resp.StatusCode) + + // Wrap the response body so closing it decrements the count and refreshes the timestamp + resp.Body = wrapTrackedBody(resp.Body, func() { + atomic.AddInt64(&entry.inFlight, -1) + atomic.StoreInt64(&entry.lastUsed, time.Now().UnixNano()) + }) + + return resp, nil +} + +// acquireClientWithTLS gets or creates a client with TLS fingerprinting +func (s *httpUpstreamService) acquireClientWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile) (*upstreamClientEntry, error) { + return s.getClientEntryWithTLS(proxyURL, accountID, accountConcurrency, profile, true, true) +} + +// getClientEntryWithTLS gets or creates a client entry with TLS fingerprinting. +// Fingerprinted clients use a dedicated cache key, isolated from plain clients +func (s *httpUpstreamService) getClientEntryWithTLS(proxyURL string, accountID int64, accountConcurrency int, profile *tlsfingerprint.Profile, markInFlight bool, enforceLimit bool) (*upstreamClientEntry, error) { + isolation := s.getIsolationMode() + proxyKey, parsedProxy := normalizeProxyURL(proxyURL) + // Fingerprinted clients get their own cache key, prefixed with "tls:" + cacheKey := "tls:" + buildCacheKey(isolation, proxyKey, accountID) + poolKey := s.buildPoolKey(isolation, accountConcurrency) + ":tls" + + now := time.Now() + nowUnix := now.UnixNano() + + // Fast path under the read lock + s.mu.RLock() + if entry, ok := s.clients[cacheKey]; ok && s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.RUnlock() + slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey) + return entry, nil + } + s.mu.RUnlock() + + // Slow path under the write lock + s.mu.Lock() + if entry, ok := s.clients[cacheKey]; ok { + if s.shouldReuseEntry(entry, isolation, proxyKey, poolKey) { + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.AddInt64(&entry.inFlight, 1) + } + s.mu.Unlock() + slog.Debug("tls_fingerprint_reusing_client", "account_id", accountID, "cache_key", cacheKey) + return entry, nil + } + slog.Debug("tls_fingerprint_evicting_stale_client", + "account_id", accountID, + "cache_key", cacheKey, + "proxy_changed", entry.proxyKey != proxyKey, + "pool_changed", entry.poolKey != poolKey) + s.removeClientLocked(cacheKey, entry) + } + + // Try eviction when the cache limit is exceeded + if enforceLimit && s.maxUpstreamClients() > 0 { + s.evictIdleLocked(now) + if len(s.clients) >= s.maxUpstreamClients() { + if !s.evictOldestIdleLocked() { + s.mu.Unlock() + return nil, errUpstreamClientLimitReached + } + } + } + + // Create a Transport with the TLS fingerprint + slog.Debug("tls_fingerprint_creating_new_client", "account_id", accountID, "cache_key", cacheKey, "proxy", proxyKey) + settings := s.resolvePoolSettings(isolation, accountConcurrency) + transport, err := buildUpstreamTransportWithTLSFingerprint(settings, parsedProxy, profile) + if err != nil { + s.mu.Unlock() + return nil, fmt.Errorf("build TLS fingerprint transport: %w", err) + } + + client := &http.Client{Transport: transport} + if s.shouldValidateResolvedIP() { + client.CheckRedirect = s.redirectChecker + } + + entry := &upstreamClientEntry{ + client: client, + proxyKey: proxyKey, + poolKey: poolKey, + } + atomic.StoreInt64(&entry.lastUsed, nowUnix) + if markInFlight { + atomic.StoreInt64(&entry.inFlight, 1) + } + s.clients[cacheKey] = entry + + s.evictIdleLocked(now) + s.evictOverLimitLocked() + s.mu.Unlock() + return entry, nil +}
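A sketch of the caller side of DoWithTLS, using the signature added above (URL and flag values are placeholders):

package example

import (
	"net/http"

	"github.com/Wei-Shaw/sub2api/internal/service"
)

// callUpstream issues one request through the fingerprint-aware path; with
// the flag false (or no profile for the account) DoWithTLS degrades to the
// standard Do path, so callers can toggle this per account.
func callUpstream(up service.HTTPUpstream, accountID int64) (*http.Response, error) {
	req, err := http.NewRequest(http.MethodGet, "https://api.anthropic.com/api/oauth/usage", nil)
	if err != nil {
		return nil, err
	}
	// "" = direct connection; accountConcurrency 0 = default pool settings.
	return up.DoWithTLS(req, "", accountID, 0, true)
}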
+	return entry, nil
+}
+
 func (s *httpUpstreamService) shouldValidateResolvedIP() bool {
 	if s.cfg == nil {
 		return false
 	}
@@ -618,6 +786,64 @@ func buildUpstreamTransport(settings poolSettings, proxyURL *url.URL) (*http.Tra
 	return transport, nil
 }
 
+// buildUpstreamTransportWithTLSFingerprint builds a Transport with TLS fingerprint camouflage,
+// using the utls library to imitate the Claude CLI's TLS fingerprint.
+//
+// Parameters:
+//   - settings: connection pool configuration
+//   - proxyURL: proxy URL (nil means a direct connection)
+//   - profile: TLS fingerprint profile
+//
+// Returns:
+//   - *http.Transport: the configured Transport instance
+//   - error: configuration error
+//
+// Proxy type handling:
+//   - nil/empty: direct connection via TLSFingerprintDialer
+//   - http/https: HTTP proxy via HTTPProxyDialer (CONNECT tunnel + utls handshake)
+//   - socks5: SOCKS5 proxy via SOCKS5ProxyDialer (SOCKS5 tunnel + utls handshake)
+func buildUpstreamTransportWithTLSFingerprint(settings poolSettings, proxyURL *url.URL, profile *tlsfingerprint.Profile) (*http.Transport, error) {
+	transport := &http.Transport{
+		MaxIdleConns:          settings.maxIdleConns,
+		MaxIdleConnsPerHost:   settings.maxIdleConnsPerHost,
+		MaxConnsPerHost:       settings.maxConnsPerHost,
+		IdleConnTimeout:       settings.idleConnTimeout,
+		ResponseHeaderTimeout: settings.responseHeaderTimeout,
+		// Disable the default TLS path; a custom DialTLSContext is installed below.
+		ForceAttemptHTTP2: false,
+	}
+
+	// Pick the TLS fingerprint dialer that matches the proxy type.
+	if proxyURL == nil {
+		// Direct connection: TLSFingerprintDialer.
+		slog.Debug("tls_fingerprint_transport_direct")
+		dialer := tlsfingerprint.NewDialer(profile, nil)
+		transport.DialTLSContext = dialer.DialTLSContext
+	} else {
+		scheme := strings.ToLower(proxyURL.Scheme)
+		switch scheme {
+		case "socks5", "socks5h":
+			// SOCKS5 proxy: SOCKS5ProxyDialer.
+			slog.Debug("tls_fingerprint_transport_socks5", "proxy", proxyURL.Host)
+			socks5Dialer := tlsfingerprint.NewSOCKS5ProxyDialer(profile, proxyURL)
+			transport.DialTLSContext = socks5Dialer.DialTLSContext
+		case "http", "https":
+			// HTTP/HTTPS proxy: HTTPProxyDialer (CONNECT tunnel).
+			slog.Debug("tls_fingerprint_transport_http_connect", "proxy", proxyURL.Host)
+			httpDialer := tlsfingerprint.NewHTTPProxyDialer(profile, proxyURL)
+			transport.DialTLSContext = httpDialer.DialTLSContext
+		default:
+			// Unknown proxy scheme: fall back to plain proxy configuration (no TLS fingerprint).
+			slog.Debug("tls_fingerprint_transport_unknown_scheme_fallback", "scheme", scheme)
+			if err := proxyutil.ConfigureTransportProxy(transport, proxyURL); err != nil {
+				return nil, err
+			}
+		}
+	}
+
+	return transport, nil
+}
+
 // trackedBody wraps a response body with tracking:
 // it runs a callback on Close to update the request counters.
 type trackedBody struct {
diff --git a/backend/internal/repository/identity_cache.go b/backend/internal/repository/identity_cache.go
index d28477b7..c4986547 100644
--- a/backend/internal/repository/identity_cache.go
+++ b/backend/internal/repository/identity_cache.go
@@ -11,8 +11,10 @@ import (
 )
 
 const (
-	fingerprintKeyPrefix = "fingerprint:"
-	fingerprintTTL       = 24 * time.Hour
+	fingerprintKeyPrefix   = "fingerprint:"
+	fingerprintTTL         = 24 * time.Hour
+	maskedSessionKeyPrefix = "masked_session:"
+	maskedSessionTTL       = 15 * time.Minute
 )
 
 // fingerprintKey generates the Redis key for account fingerprint cache.
@@ -20,6 +22,11 @@ func fingerprintKey(accountID int64) string {
 	return fmt.Sprintf("%s%d", fingerprintKeyPrefix, accountID)
 }
 
+// maskedSessionKey generates the Redis key for masked session ID cache.
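+// Values written under this key expire after maskedSessionTTL (15 minutes).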
+func maskedSessionKey(accountID int64) string {
+	return fmt.Sprintf("%s%d", maskedSessionKeyPrefix, accountID)
+}
+
 type identityCache struct {
 	rdb *redis.Client
 }
@@ -49,3 +56,20 @@ func (c *identityCache) SetFingerprint(ctx context.Context, accountID int64, fp
 	}
 	return c.rdb.Set(ctx, key, val, fingerprintTTL).Err()
 }
+
+func (c *identityCache) GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) {
+	key := maskedSessionKey(accountID)
+	val, err := c.rdb.Get(ctx, key).Result()
+	if err != nil {
+		if err == redis.Nil {
+			return "", nil
+		}
+		return "", err
+	}
+	return val, nil
+}
+
+func (c *identityCache) SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error {
+	key := maskedSessionKey(accountID)
+	return c.rdb.Set(ctx, key, sessionID, maskedSessionTTL).Err()
+}
diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go
index 613c5bd5..b04154b7 100644
--- a/backend/internal/repository/ops_repo.go
+++ b/backend/internal/repository/ops_repo.go
@@ -992,7 +992,8 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
 	}
 
 	// View filter: errors vs excluded vs all.
-	// Excluded = upstream 429/529 and business-limited (quota/concurrency/billing) errors.
+	// Excluded = business-limited errors (quota/concurrency/billing).
+	// Upstream 429/529 are included in the errors view to match the SLA calculation.
 	view := ""
 	if filter != nil {
 		view = strings.ToLower(strings.TrimSpace(filter.View))
@@ -1000,15 +1001,13 @@ func buildOpsErrorLogsWhere(filter *service.OpsErrorLogFilter) (string, []any) {
 	switch view {
 	case "", "errors":
 		clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
-		clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
 	case "excluded":
-		clauses = append(clauses, "(COALESCE(is_business_limited,false) = true OR COALESCE(upstream_status_code, status_code, 0) IN (429, 529))")
+		clauses = append(clauses, "COALESCE(is_business_limited,false) = true")
 	case "all":
 		// no-op
 	default:
 		// treat unknown as default 'errors'
 		clauses = append(clauses, "COALESCE(is_business_limited,false) = false")
-		clauses = append(clauses, "COALESCE(upstream_status_code, status_code, 0) NOT IN (429, 529)")
 	}
 	if len(filter.StatusCodes) > 0 {
 		args = append(args, pq.Array(filter.StatusCodes))
diff --git a/backend/internal/repository/session_limit_cache.go b/backend/internal/repository/session_limit_cache.go
index 16f2a69c..3dc89f87 100644
--- a/backend/internal/repository/session_limit_cache.go
+++ b/backend/internal/repository/session_limit_cache.go
@@ -217,7 +217,7 @@ func (c *sessionLimitCache) GetActiveSessionCount(ctx context.Context, accountID
 }
 
 // GetActiveSessionCountBatch fetches active session counts for a batch of accounts.
-func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
+func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) {
 	if len(accountIDs) == 0 {
 		return make(map[int64]int), nil
 	}
@@ -226,11 +226,18 @@ func (c *sessionLimitCache) GetActiveSessionCountBatch(ctx context.Context, acco
 
 	// Run the lookups in one batch through a pipeline.
 	pipe := c.rdb.Pipeline()
-	idleTimeoutSeconds := int(c.defaultIdleTimeout.Seconds())
 	cmds := make(map[int64]*redis.Cmd, len(accountIDs))
 	for _, accountID := range accountIDs {
 		key := sessionLimitKey(accountID)
+		// Use each account's own idleTimeout, falling back to the default when absent.
+		idleTimeout := c.defaultIdleTimeout
+		if idleTimeouts != nil {
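+			// Zero or negative overrides are ignored, so a bad per-account value cannot disable idle expiry.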
+ if t, ok := idleTimeouts[accountID]; ok && t > 0 { + idleTimeout = t + } + } + idleTimeoutSeconds := int(idleTimeout.Seconds()) cmds[accountID] = getActiveSessionCountScript.Run(ctx, pipe, []string{key}, idleTimeoutSeconds) } diff --git a/backend/internal/repository/simple_mode_default_groups.go b/backend/internal/repository/simple_mode_default_groups.go new file mode 100644 index 00000000..56309184 --- /dev/null +++ b/backend/internal/repository/simple_mode_default_groups.go @@ -0,0 +1,82 @@ +package repository + +import ( + "context" + "fmt" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + + requiredByPlatform := map[string]int{ + service.PlatformAnthropic: 1, + service.PlatformOpenAI: 1, + service.PlatformGemini: 1, + service.PlatformAntigravity: 2, + } + + for platform, minCount := range requiredByPlatform { + count, err := client.Group.Query(). + Where(group.PlatformEQ(platform), group.DeletedAtIsNil()). + Count(ctx) + if err != nil { + return fmt.Errorf("count groups for platform %s: %w", platform, err) + } + + if platform == service.PlatformAntigravity { + if count < minCount { + for i := count; i < minCount; i++ { + name := fmt.Sprintf("%s-default-%d", platform, i+1) + if err := createGroupIfNotExists(ctx, client, name, platform); err != nil { + return err + } + } + } + continue + } + + // Non-antigravity platforms: ensure -default exists. + name := platform + "-default" + if err := createGroupIfNotExists(ctx, client, name, platform); err != nil { + return err + } + } + + return nil +} + +func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error { + exists, err := client.Group.Query(). + Where(group.NameEQ(name), group.DeletedAtIsNil()). + Exist(ctx) + if err != nil { + return fmt.Errorf("check group exists %s: %w", name, err) + } + if exists { + return nil + } + + _, err = client.Group.Create(). + SetName(name). + SetDescription("Auto-created default group"). + SetPlatform(platform). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1.0). + SetIsExclusive(false). + Save(ctx) + if err != nil { + if dbent.IsConstraintError(err) { + // Concurrent server startups may race on creation; treat as success. 
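+			// A uniqueness constraint on the group name is assumed to surface the race as a constraint error here.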
+ return nil + } + return fmt.Errorf("create default group %s: %w", name, err) + } + return nil +} diff --git a/backend/internal/repository/simple_mode_default_groups_integration_test.go b/backend/internal/repository/simple_mode_default_groups_integration_test.go new file mode 100644 index 00000000..3327257b --- /dev/null +++ b/backend/internal/repository/simple_mode_default_groups_integration_test.go @@ -0,0 +1,84 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + assertGroupExists := func(name string) { + exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx) + require.NoError(t, err) + require.True(t, exists, "expected group %s to exist", name) + } + + assertGroupExists(service.PlatformAnthropic + "-default") + assertGroupExists(service.PlatformOpenAI + "-default") + assertGroupExists(service.PlatformGemini + "-default") + assertGroupExists(service.PlatformAntigravity + "-default-1") + assertGroupExists(service.PlatformAntigravity + "-default-2") +} + +func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Create and then soft-delete an anthropic default group. + g, err := client.Group.Create(). + SetName(service.PlatformAnthropic + "-default"). + SetPlatform(service.PlatformAnthropic). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1.0). + SetIsExclusive(false). + Save(seedCtx) + require.NoError(t, err) + + _, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx) + require.NoError(t, err) + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + // New active one should exist. 
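+	// Exactly one active row proves the seeder created a fresh group rather than resurrecting the soft-deleted one.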
+ count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity}) + mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity}) + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx) + require.NoError(t, err) + require.GreaterOrEqual(t, count, 2) +} diff --git a/backend/internal/repository/usage_cleanup_repo.go b/backend/internal/repository/usage_cleanup_repo.go new file mode 100644 index 00000000..9c021357 --- /dev/null +++ b/backend/internal/repository/usage_cleanup_repo.go @@ -0,0 +1,551 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +type usageCleanupRepository struct { + client *dbent.Client + sql sqlExecutor +} + +func NewUsageCleanupRepository(client *dbent.Client, sqlDB *sql.DB) service.UsageCleanupRepository { + return newUsageCleanupRepositoryWithSQL(client, sqlDB) +} + +func newUsageCleanupRepositoryWithSQL(client *dbent.Client, sqlq sqlExecutor) *usageCleanupRepository { + return &usageCleanupRepository{client: client, sql: sqlq} +} + +func (r *usageCleanupRepository) CreateTask(ctx context.Context, task *service.UsageCleanupTask) error { + if task == nil { + return nil + } + if r.client != nil { + return r.createTaskWithEnt(ctx, task) + } + return r.createTaskWithSQL(ctx, task) +} + +func (r *usageCleanupRepository) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) { + if r.client != nil { + return r.listTasksWithEnt(ctx, params) + } + var total int64 + if err := scanSingleRow(ctx, r.sql, "SELECT COUNT(*) FROM usage_cleanup_tasks", nil, &total); err != nil { + return nil, nil, err + } + if total == 0 { + return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil + } + + query := ` + SELECT id, status, filters, created_by, deleted_rows, error_message, + canceled_by, canceled_at, + started_at, finished_at, created_at, updated_at + FROM usage_cleanup_tasks + ORDER BY created_at DESC, id DESC + LIMIT $1 OFFSET $2 + ` + rows, err := r.sql.QueryContext(ctx, query, params.Limit(), params.Offset()) + if err != nil { + return nil, nil, err + } + defer func() { _ = rows.Close() }() + + tasks := make([]service.UsageCleanupTask, 0) + for rows.Next() { + var task service.UsageCleanupTask + var filtersJSON []byte + var errMsg sql.NullString + var canceledBy sql.NullInt64 + var canceledAt sql.NullTime + var startedAt sql.NullTime + var finishedAt sql.NullTime + if err := rows.Scan( + &task.ID, + &task.Status, + &filtersJSON, + &task.CreatedBy, + 
&task.DeletedRows, + &errMsg, + &canceledBy, + &canceledAt, + &startedAt, + &finishedAt, + &task.CreatedAt, + &task.UpdatedAt, + ); err != nil { + return nil, nil, err + } + if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil { + return nil, nil, fmt.Errorf("parse cleanup filters: %w", err) + } + if errMsg.Valid { + task.ErrorMsg = &errMsg.String + } + if canceledBy.Valid { + v := canceledBy.Int64 + task.CanceledBy = &v + } + if canceledAt.Valid { + task.CanceledAt = &canceledAt.Time + } + if startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if finishedAt.Valid { + task.FinishedAt = &finishedAt.Time + } + tasks = append(tasks, task) + } + if err := rows.Err(); err != nil { + return nil, nil, err + } + return tasks, paginationResultFromTotal(total, params), nil +} + +func (r *usageCleanupRepository) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*service.UsageCleanupTask, error) { + if staleRunningAfterSeconds <= 0 { + staleRunningAfterSeconds = 1800 + } + query := ` + WITH next AS ( + SELECT id + FROM usage_cleanup_tasks + WHERE status = $1 + OR ( + status = $2 + AND started_at IS NOT NULL + AND started_at < NOW() - ($3 * interval '1 second') + ) + ORDER BY created_at ASC + LIMIT 1 + FOR UPDATE SKIP LOCKED + ) + UPDATE usage_cleanup_tasks AS tasks + SET status = $4, + started_at = NOW(), + finished_at = NULL, + error_message = NULL, + updated_at = NOW() + FROM next + WHERE tasks.id = next.id + RETURNING tasks.id, tasks.status, tasks.filters, tasks.created_by, tasks.deleted_rows, tasks.error_message, + tasks.started_at, tasks.finished_at, tasks.created_at, tasks.updated_at + ` + var task service.UsageCleanupTask + var filtersJSON []byte + var errMsg sql.NullString + var startedAt sql.NullTime + var finishedAt sql.NullTime + if err := scanSingleRow( + ctx, + r.sql, + query, + []any{ + service.UsageCleanupStatusPending, + service.UsageCleanupStatusRunning, + staleRunningAfterSeconds, + service.UsageCleanupStatusRunning, + }, + &task.ID, + &task.Status, + &filtersJSON, + &task.CreatedBy, + &task.DeletedRows, + &errMsg, + &startedAt, + &finishedAt, + &task.CreatedAt, + &task.UpdatedAt, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + if err := json.Unmarshal(filtersJSON, &task.Filters); err != nil { + return nil, fmt.Errorf("parse cleanup filters: %w", err) + } + if errMsg.Valid { + task.ErrorMsg = &errMsg.String + } + if startedAt.Valid { + task.StartedAt = &startedAt.Time + } + if finishedAt.Valid { + task.FinishedAt = &finishedAt.Time + } + return &task, nil +} + +func (r *usageCleanupRepository) GetTaskStatus(ctx context.Context, taskID int64) (string, error) { + if r.client != nil { + return r.getTaskStatusWithEnt(ctx, taskID) + } + var status string + if err := scanSingleRow(ctx, r.sql, "SELECT status FROM usage_cleanup_tasks WHERE id = $1", []any{taskID}, &status); err != nil { + return "", err + } + return status, nil +} + +func (r *usageCleanupRepository) UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error { + if r.client != nil { + return r.updateTaskProgressWithEnt(ctx, taskID, deletedRows) + } + query := ` + UPDATE usage_cleanup_tasks + SET deleted_rows = $1, + updated_at = NOW() + WHERE id = $2 + ` + _, err := r.sql.ExecContext(ctx, query, deletedRows, taskID) + return err +} + +func (r *usageCleanupRepository) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + if r.client != nil { + return r.cancelTaskWithEnt(ctx, taskID, 
canceledBy) + } + query := ` + UPDATE usage_cleanup_tasks + SET status = $1, + canceled_by = $3, + canceled_at = NOW(), + finished_at = NOW(), + error_message = NULL, + updated_at = NOW() + WHERE id = $2 + AND status IN ($4, $5) + RETURNING id + ` + var id int64 + err := scanSingleRow(ctx, r.sql, query, []any{ + service.UsageCleanupStatusCanceled, + taskID, + canceledBy, + service.UsageCleanupStatusPending, + service.UsageCleanupStatusRunning, + }, &id) + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + if err != nil { + return false, err + } + return true, nil +} + +func (r *usageCleanupRepository) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error { + if r.client != nil { + return r.markTaskSucceededWithEnt(ctx, taskID, deletedRows) + } + query := ` + UPDATE usage_cleanup_tasks + SET status = $1, + deleted_rows = $2, + finished_at = NOW(), + updated_at = NOW() + WHERE id = $3 + ` + _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusSucceeded, deletedRows, taskID) + return err +} + +func (r *usageCleanupRepository) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + if r.client != nil { + return r.markTaskFailedWithEnt(ctx, taskID, deletedRows, errorMsg) + } + query := ` + UPDATE usage_cleanup_tasks + SET status = $1, + deleted_rows = $2, + error_message = $3, + finished_at = NOW(), + updated_at = NOW() + WHERE id = $4 + ` + _, err := r.sql.ExecContext(ctx, query, service.UsageCleanupStatusFailed, deletedRows, errorMsg, taskID) + return err +} + +func (r *usageCleanupRepository) DeleteUsageLogsBatch(ctx context.Context, filters service.UsageCleanupFilters, limit int) (int64, error) { + if filters.StartTime.IsZero() || filters.EndTime.IsZero() { + return 0, fmt.Errorf("cleanup filters missing time range") + } + whereClause, args := buildUsageCleanupWhere(filters) + if whereClause == "" { + return 0, fmt.Errorf("cleanup filters missing time range") + } + args = append(args, limit) + query := fmt.Sprintf(` + WITH target AS ( + SELECT id + FROM usage_logs + WHERE %s + ORDER BY created_at ASC, id ASC + LIMIT $%d + ) + DELETE FROM usage_logs + WHERE id IN (SELECT id FROM target) + RETURNING id + `, whereClause, len(args)) + + rows, err := r.sql.QueryContext(ctx, query, args...) 
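+	// The DELETE uses RETURNING, so it runs through QueryContext; the returned ids are counted below to report per-batch progress.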
+ if err != nil { + return 0, err + } + defer func() { _ = rows.Close() }() + + var deleted int64 + for rows.Next() { + deleted++ + } + if err := rows.Err(); err != nil { + return 0, err + } + return deleted, nil +} + +func buildUsageCleanupWhere(filters service.UsageCleanupFilters) (string, []any) { + conditions := make([]string, 0, 8) + args := make([]any, 0, 8) + idx := 1 + if !filters.StartTime.IsZero() { + conditions = append(conditions, fmt.Sprintf("created_at >= $%d", idx)) + args = append(args, filters.StartTime) + idx++ + } + if !filters.EndTime.IsZero() { + conditions = append(conditions, fmt.Sprintf("created_at <= $%d", idx)) + args = append(args, filters.EndTime) + idx++ + } + if filters.UserID != nil { + conditions = append(conditions, fmt.Sprintf("user_id = $%d", idx)) + args = append(args, *filters.UserID) + idx++ + } + if filters.APIKeyID != nil { + conditions = append(conditions, fmt.Sprintf("api_key_id = $%d", idx)) + args = append(args, *filters.APIKeyID) + idx++ + } + if filters.AccountID != nil { + conditions = append(conditions, fmt.Sprintf("account_id = $%d", idx)) + args = append(args, *filters.AccountID) + idx++ + } + if filters.GroupID != nil { + conditions = append(conditions, fmt.Sprintf("group_id = $%d", idx)) + args = append(args, *filters.GroupID) + idx++ + } + if filters.Model != nil { + model := strings.TrimSpace(*filters.Model) + if model != "" { + conditions = append(conditions, fmt.Sprintf("model = $%d", idx)) + args = append(args, model) + idx++ + } + } + if filters.Stream != nil { + conditions = append(conditions, fmt.Sprintf("stream = $%d", idx)) + args = append(args, *filters.Stream) + idx++ + } + if filters.BillingType != nil { + conditions = append(conditions, fmt.Sprintf("billing_type = $%d", idx)) + args = append(args, *filters.BillingType) + } + return strings.Join(conditions, " AND "), args +} + +func (r *usageCleanupRepository) createTaskWithEnt(ctx context.Context, task *service.UsageCleanupTask) error { + client := clientFromContext(ctx, r.client) + filtersJSON, err := json.Marshal(task.Filters) + if err != nil { + return fmt.Errorf("marshal cleanup filters: %w", err) + } + created, err := client.UsageCleanupTask. + Create(). + SetStatus(task.Status). + SetFilters(json.RawMessage(filtersJSON)). + SetCreatedBy(task.CreatedBy). + SetDeletedRows(task.DeletedRows). 
+ Save(ctx) + if err != nil { + return err + } + task.ID = created.ID + task.CreatedAt = created.CreatedAt + task.UpdatedAt = created.UpdatedAt + return nil +} + +func (r *usageCleanupRepository) createTaskWithSQL(ctx context.Context, task *service.UsageCleanupTask) error { + filtersJSON, err := json.Marshal(task.Filters) + if err != nil { + return fmt.Errorf("marshal cleanup filters: %w", err) + } + query := ` + INSERT INTO usage_cleanup_tasks ( + status, + filters, + created_by, + deleted_rows + ) VALUES ($1, $2, $3, $4) + RETURNING id, created_at, updated_at + ` + if err := scanSingleRow(ctx, r.sql, query, []any{task.Status, filtersJSON, task.CreatedBy, task.DeletedRows}, &task.ID, &task.CreatedAt, &task.UpdatedAt); err != nil { + return err + } + return nil +} + +func (r *usageCleanupRepository) listTasksWithEnt(ctx context.Context, params pagination.PaginationParams) ([]service.UsageCleanupTask, *pagination.PaginationResult, error) { + client := clientFromContext(ctx, r.client) + query := client.UsageCleanupTask.Query() + total, err := query.Clone().Count(ctx) + if err != nil { + return nil, nil, err + } + if total == 0 { + return []service.UsageCleanupTask{}, paginationResultFromTotal(0, params), nil + } + rows, err := query. + Order(dbent.Desc(dbusagecleanuptask.FieldCreatedAt), dbent.Desc(dbusagecleanuptask.FieldID)). + Offset(params.Offset()). + Limit(params.Limit()). + All(ctx) + if err != nil { + return nil, nil, err + } + tasks := make([]service.UsageCleanupTask, 0, len(rows)) + for _, row := range rows { + task, err := usageCleanupTaskFromEnt(row) + if err != nil { + return nil, nil, err + } + tasks = append(tasks, task) + } + return tasks, paginationResultFromTotal(int64(total), params), nil +} + +func (r *usageCleanupRepository) getTaskStatusWithEnt(ctx context.Context, taskID int64) (string, error) { + client := clientFromContext(ctx, r.client) + task, err := client.UsageCleanupTask.Query(). + Where(dbusagecleanuptask.IDEQ(taskID)). + Only(ctx) + if err != nil { + if dbent.IsNotFound(err) { + return "", sql.ErrNoRows + } + return "", err + } + return task.Status, nil +} + +func (r *usageCleanupRepository) updateTaskProgressWithEnt(ctx context.Context, taskID int64, deletedRows int64) error { + client := clientFromContext(ctx, r.client) + now := time.Now() + _, err := client.UsageCleanupTask.Update(). + Where(dbusagecleanuptask.IDEQ(taskID)). + SetDeletedRows(deletedRows). + SetUpdatedAt(now). + Save(ctx) + return err +} + +func (r *usageCleanupRepository) cancelTaskWithEnt(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + client := clientFromContext(ctx, r.client) + now := time.Now() + affected, err := client.UsageCleanupTask.Update(). + Where( + dbusagecleanuptask.IDEQ(taskID), + dbusagecleanuptask.StatusIn(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning), + ). + SetStatus(service.UsageCleanupStatusCanceled). + SetCanceledBy(canceledBy). + SetCanceledAt(now). + SetFinishedAt(now). + ClearErrorMessage(). + SetUpdatedAt(now). + Save(ctx) + if err != nil { + return false, err + } + return affected > 0, nil +} + +func (r *usageCleanupRepository) markTaskSucceededWithEnt(ctx context.Context, taskID int64, deletedRows int64) error { + client := clientFromContext(ctx, r.client) + now := time.Now() + _, err := client.UsageCleanupTask.Update(). + Where(dbusagecleanuptask.IDEQ(taskID)). + SetStatus(service.UsageCleanupStatusSucceeded). + SetDeletedRows(deletedRows). + SetFinishedAt(now). + SetUpdatedAt(now). 
+ Save(ctx) + return err +} + +func (r *usageCleanupRepository) markTaskFailedWithEnt(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + client := clientFromContext(ctx, r.client) + now := time.Now() + _, err := client.UsageCleanupTask.Update(). + Where(dbusagecleanuptask.IDEQ(taskID)). + SetStatus(service.UsageCleanupStatusFailed). + SetDeletedRows(deletedRows). + SetErrorMessage(errorMsg). + SetFinishedAt(now). + SetUpdatedAt(now). + Save(ctx) + return err +} + +func usageCleanupTaskFromEnt(row *dbent.UsageCleanupTask) (service.UsageCleanupTask, error) { + task := service.UsageCleanupTask{ + ID: row.ID, + Status: row.Status, + CreatedBy: row.CreatedBy, + DeletedRows: row.DeletedRows, + CreatedAt: row.CreatedAt, + UpdatedAt: row.UpdatedAt, + } + if len(row.Filters) > 0 { + if err := json.Unmarshal(row.Filters, &task.Filters); err != nil { + return service.UsageCleanupTask{}, fmt.Errorf("parse cleanup filters: %w", err) + } + } + if row.ErrorMessage != nil { + task.ErrorMsg = row.ErrorMessage + } + if row.CanceledBy != nil { + task.CanceledBy = row.CanceledBy + } + if row.CanceledAt != nil { + task.CanceledAt = row.CanceledAt + } + if row.StartedAt != nil { + task.StartedAt = row.StartedAt + } + if row.FinishedAt != nil { + task.FinishedAt = row.FinishedAt + } + return task, nil +} diff --git a/backend/internal/repository/usage_cleanup_repo_ent_test.go b/backend/internal/repository/usage_cleanup_repo_ent_test.go new file mode 100644 index 00000000..6c20b2b9 --- /dev/null +++ b/backend/internal/repository/usage_cleanup_repo_ent_test.go @@ -0,0 +1,251 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + "time" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/enttest" + dbusagecleanuptask "github.com/Wei-Shaw/sub2api/ent/usagecleanuptask" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" + + "entgo.io/ent/dialect" + entsql "entgo.io/ent/dialect/sql" + _ "modernc.org/sqlite" +) + +func newUsageCleanupEntRepo(t *testing.T) (*usageCleanupRepository, *dbent.Client) { + t.Helper() + db, err := sql.Open("sqlite", "file:usage_cleanup?mode=memory&cache=shared") + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + _, err = db.Exec("PRAGMA foreign_keys = ON") + require.NoError(t, err) + + drv := entsql.OpenDB(dialect.SQLite, db) + client := enttest.NewClient(t, enttest.WithOptions(dbent.Driver(drv))) + t.Cleanup(func() { _ = client.Close() }) + + repo := &usageCleanupRepository{client: client, sql: db} + return repo, client +} + +func TestUsageCleanupRepositoryEntCreateAndList(t *testing.T) { + repo, _ := newUsageCleanupEntRepo(t) + + start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end}, + CreatedBy: 9, + } + require.NoError(t, repo.CreateTask(context.Background(), task)) + require.NotZero(t, task.ID) + + task2 := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusRunning, + Filters: service.UsageCleanupFilters{StartTime: start.Add(-24 * time.Hour), EndTime: end.Add(-24 * time.Hour)}, + CreatedBy: 10, + } + require.NoError(t, repo.CreateTask(context.Background(), task2)) + + tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10}) + 
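+	// ListTasks returns newest first (created_at DESC, id DESC), so task2 is expected before task.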
require.NoError(t, err) + require.Len(t, tasks, 2) + require.Equal(t, int64(2), result.Total) + require.Greater(t, tasks[0].ID, tasks[1].ID) + require.Equal(t, start, tasks[1].Filters.StartTime) + require.Equal(t, end, tasks[1].Filters.EndTime) +} + +func TestUsageCleanupRepositoryEntListEmpty(t *testing.T) { + repo, _ := newUsageCleanupEntRepo(t) + + tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10}) + require.NoError(t, err) + require.Empty(t, tasks) + require.Equal(t, int64(0), result.Total) +} + +func TestUsageCleanupRepositoryEntGetStatusAndProgress(t *testing.T) { + repo, client := newUsageCleanupEntRepo(t) + + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)}, + CreatedBy: 3, + } + require.NoError(t, repo.CreateTask(context.Background(), task)) + + status, err := repo.GetTaskStatus(context.Background(), task.ID) + require.NoError(t, err) + require.Equal(t, service.UsageCleanupStatusPending, status) + + _, err = repo.GetTaskStatus(context.Background(), task.ID+99) + require.ErrorIs(t, err, sql.ErrNoRows) + + require.NoError(t, repo.UpdateTaskProgress(context.Background(), task.ID, 42)) + loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID) + require.NoError(t, err) + require.Equal(t, int64(42), loaded.DeletedRows) +} + +func TestUsageCleanupRepositoryEntCancelAndFinish(t *testing.T) { + repo, client := newUsageCleanupEntRepo(t) + + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)}, + CreatedBy: 5, + } + require.NoError(t, repo.CreateTask(context.Background(), task)) + + ok, err := repo.CancelTask(context.Background(), task.ID, 7) + require.NoError(t, err) + require.True(t, ok) + + loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID) + require.NoError(t, err) + require.Equal(t, service.UsageCleanupStatusCanceled, loaded.Status) + require.NotNil(t, loaded.CanceledBy) + require.NotNil(t, loaded.CanceledAt) + require.NotNil(t, loaded.FinishedAt) + + loaded.Status = service.UsageCleanupStatusSucceeded + _, err = client.UsageCleanupTask.Update().Where(dbusagecleanuptask.IDEQ(task.ID)).SetStatus(loaded.Status).Save(context.Background()) + require.NoError(t, err) + + ok, err = repo.CancelTask(context.Background(), task.ID, 7) + require.NoError(t, err) + require.False(t, ok) +} + +func TestUsageCleanupRepositoryEntCancelError(t *testing.T) { + repo, client := newUsageCleanupEntRepo(t) + + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)}, + CreatedBy: 5, + } + require.NoError(t, repo.CreateTask(context.Background(), task)) + + require.NoError(t, client.Close()) + _, err := repo.CancelTask(context.Background(), task.ID, 7) + require.Error(t, err) +} + +func TestUsageCleanupRepositoryEntMarkResults(t *testing.T) { + repo, client := newUsageCleanupEntRepo(t) + + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusRunning, + Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)}, + CreatedBy: 12, + } + require.NoError(t, repo.CreateTask(context.Background(), task)) + + require.NoError(t, 
repo.MarkTaskSucceeded(context.Background(), task.ID, 6)) + loaded, err := client.UsageCleanupTask.Get(context.Background(), task.ID) + require.NoError(t, err) + require.Equal(t, service.UsageCleanupStatusSucceeded, loaded.Status) + require.Equal(t, int64(6), loaded.DeletedRows) + require.NotNil(t, loaded.FinishedAt) + + task2 := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusRunning, + Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)}, + CreatedBy: 12, + } + require.NoError(t, repo.CreateTask(context.Background(), task2)) + + require.NoError(t, repo.MarkTaskFailed(context.Background(), task2.ID, 4, "boom")) + loaded2, err := client.UsageCleanupTask.Get(context.Background(), task2.ID) + require.NoError(t, err) + require.Equal(t, service.UsageCleanupStatusFailed, loaded2.Status) + require.Equal(t, "boom", *loaded2.ErrorMessage) +} + +func TestUsageCleanupRepositoryEntInvalidStatus(t *testing.T) { + repo, _ := newUsageCleanupEntRepo(t) + + task := &service.UsageCleanupTask{ + Status: "invalid", + Filters: service.UsageCleanupFilters{StartTime: time.Now().UTC(), EndTime: time.Now().UTC().Add(time.Hour)}, + CreatedBy: 1, + } + require.Error(t, repo.CreateTask(context.Background(), task)) +} + +func TestUsageCleanupRepositoryEntListInvalidFilters(t *testing.T) { + repo, client := newUsageCleanupEntRepo(t) + + now := time.Now().UTC() + driver, ok := client.Driver().(*entsql.Driver) + require.True(t, ok) + _, err := driver.DB().ExecContext( + context.Background(), + `INSERT INTO usage_cleanup_tasks (status, filters, created_by, deleted_rows, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?)`, + service.UsageCleanupStatusPending, + []byte("invalid-json"), + int64(1), + int64(0), + now, + now, + ) + require.NoError(t, err) + + _, _, err = repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 10}) + require.Error(t, err) +} + +func TestUsageCleanupTaskFromEntFull(t *testing.T) { + start := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + errMsg := "failed" + canceledBy := int64(2) + canceledAt := start.Add(time.Minute) + startedAt := start.Add(2 * time.Minute) + finishedAt := start.Add(3 * time.Minute) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + filtersJSON, err := json.Marshal(filters) + require.NoError(t, err) + + task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{ + ID: 10, + Status: service.UsageCleanupStatusFailed, + Filters: filtersJSON, + CreatedBy: 11, + DeletedRows: 7, + ErrorMessage: &errMsg, + CanceledBy: &canceledBy, + CanceledAt: &canceledAt, + StartedAt: &startedAt, + FinishedAt: &finishedAt, + CreatedAt: start, + UpdatedAt: end, + }) + require.NoError(t, err) + require.Equal(t, int64(10), task.ID) + require.Equal(t, service.UsageCleanupStatusFailed, task.Status) + require.NotNil(t, task.ErrorMsg) + require.NotNil(t, task.CanceledBy) + require.NotNil(t, task.CanceledAt) + require.NotNil(t, task.StartedAt) + require.NotNil(t, task.FinishedAt) +} + +func TestUsageCleanupTaskFromEntInvalidFilters(t *testing.T) { + task, err := usageCleanupTaskFromEnt(&dbent.UsageCleanupTask{ + Filters: json.RawMessage("invalid-json"), + }) + require.Error(t, err) + require.Empty(t, task) +} diff --git a/backend/internal/repository/usage_cleanup_repo_test.go b/backend/internal/repository/usage_cleanup_repo_test.go new file mode 100644 index 00000000..0ca30ec7 --- /dev/null +++ 
b/backend/internal/repository/usage_cleanup_repo_test.go @@ -0,0 +1,482 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func newSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp)) + require.NoError(t, err) + t.Cleanup(func() { _ = db.Close() }) + return db, mock +} + +func TestNewUsageCleanupRepository(t *testing.T) { + db, _ := newSQLMock(t) + repo := NewUsageCleanupRepository(nil, db) + require.NotNil(t, repo) +} + +func TestUsageCleanupRepositoryCreateTask(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: start, EndTime: end}, + CreatedBy: 12, + } + now := time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC) + + mock.ExpectQuery("INSERT INTO usage_cleanup_tasks"). + WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows). + WillReturnRows(sqlmock.NewRows([]string{"id", "created_at", "updated_at"}).AddRow(int64(1), now, now)) + + err := repo.CreateTask(context.Background(), task) + require.NoError(t, err) + require.Equal(t, int64(1), task.ID) + require.Equal(t, now, task.CreatedAt) + require.Equal(t, now, task.UpdatedAt) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCreateTaskNil(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + err := repo.CreateTask(context.Background(), nil) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCreateTaskQueryError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + task := &service.UsageCleanupTask{ + Status: service.UsageCleanupStatusPending, + Filters: service.UsageCleanupFilters{StartTime: time.Now(), EndTime: time.Now().Add(time.Hour)}, + CreatedBy: 1, + } + + mock.ExpectQuery("INSERT INTO usage_cleanup_tasks"). + WithArgs(task.Status, sqlmock.AnyArg(), task.CreatedBy, task.DeletedRows). + WillReturnError(sql.ErrConnDone) + + err := repo.CreateTask(context.Background(), task) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasksEmpty(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). 
+ WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(0))) + + tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.NoError(t, err) + require.Empty(t, tasks) + require.Equal(t, int64(0), result.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasks(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(2 * time.Hour) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + filtersJSON, err := json.Marshal(filters) + require.NoError(t, err) + + createdAt := time.Date(2024, 1, 2, 12, 0, 0, 0, time.UTC) + updatedAt := createdAt.Add(time.Minute) + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "canceled_by", "canceled_at", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(1), + service.UsageCleanupStatusSucceeded, + filtersJSON, + int64(2), + int64(9), + "error", + nil, + nil, + start, + end, + createdAt, + updatedAt, + ) + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message"). + WithArgs(20, 0). + WillReturnRows(rows) + + tasks, result, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.NoError(t, err) + require.Len(t, tasks, 1) + require.Equal(t, int64(1), tasks[0].ID) + require.Equal(t, service.UsageCleanupStatusSucceeded, tasks[0].Status) + require.Equal(t, int64(2), tasks[0].CreatedBy) + require.Equal(t, int64(9), tasks[0].DeletedRows) + require.NotNil(t, tasks[0].ErrorMsg) + require.Equal(t, "error", *tasks[0].ErrorMsg) + require.NotNil(t, tasks[0].StartedAt) + require.NotNil(t, tasks[0].FinishedAt) + require.Equal(t, int64(1), result.Total) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasksQueryError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(2))) + mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message"). + WithArgs(20, 0). + WillReturnError(sql.ErrConnDone) + + _, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryListTasksInvalidFilters(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "canceled_by", "canceled_at", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(1), + service.UsageCleanupStatusSucceeded, + []byte("not-json"), + int64(2), + int64(9), + nil, + nil, + nil, + nil, + nil, + time.Now().UTC(), + time.Now().UTC(), + ) + + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM usage_cleanup_tasks"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(int64(1))) + mock.ExpectQuery("SELECT id, status, filters, created_by, deleted_rows, error_message"). + WithArgs(20, 0). 
+ WillReturnRows(rows) + + _, _, err := repo.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTaskNone(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). + WillReturnRows(sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "started_at", "finished_at", "created_at", "updated_at", + })) + + task, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.NoError(t, err) + require.Nil(t, task) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTask(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + filtersJSON, err := json.Marshal(filters) + require.NoError(t, err) + + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(4), + service.UsageCleanupStatusRunning, + filtersJSON, + int64(7), + int64(0), + nil, + start, + nil, + start, + start, + ) + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). + WillReturnRows(rows) + + task, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.NoError(t, err) + require.NotNil(t, task) + require.Equal(t, int64(4), task.ID) + require.Equal(t, service.UsageCleanupStatusRunning, task.Status) + require.Equal(t, int64(7), task.CreatedBy) + require.NotNil(t, task.StartedAt) + require.Nil(t, task.ErrorMsg) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTaskError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). + WillReturnError(sql.ErrConnDone) + + _, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryClaimNextPendingTaskInvalidFilters(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + rows := sqlmock.NewRows([]string{ + "id", "status", "filters", "created_by", "deleted_rows", "error_message", + "started_at", "finished_at", "created_at", "updated_at", + }).AddRow( + int64(4), + service.UsageCleanupStatusRunning, + []byte("invalid"), + int64(7), + int64(0), + nil, + nil, + nil, + time.Now().UTC(), + time.Now().UTC(), + ) + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning, int64(1800), service.UsageCleanupStatusRunning). 
+ WillReturnRows(rows) + + _, err := repo.ClaimNextPendingTask(context.Background(), 1800) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryMarkTaskSucceeded(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectExec("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusSucceeded, int64(12), int64(9)). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := repo.MarkTaskSucceeded(context.Background(), 9, 12) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryMarkTaskFailed(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectExec("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusFailed, int64(4), "boom", int64(2)). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := repo.MarkTaskFailed(context.Background(), 2, 4, "boom") + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryGetTaskStatus(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks"). + WithArgs(int64(9)). + WillReturnRows(sqlmock.NewRows([]string{"status"}).AddRow(service.UsageCleanupStatusPending)) + + status, err := repo.GetTaskStatus(context.Background(), 9) + require.NoError(t, err) + require.Equal(t, service.UsageCleanupStatusPending, status) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryGetTaskStatusQueryError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("SELECT status FROM usage_cleanup_tasks"). + WithArgs(int64(9)). + WillReturnError(sql.ErrConnDone) + + _, err := repo.GetTaskStatus(context.Background(), 9) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryUpdateTaskProgress(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectExec("UPDATE usage_cleanup_tasks"). + WithArgs(int64(123), int64(8)). + WillReturnResult(sqlmock.NewResult(0, 1)) + + err := repo.UpdateTaskProgress(context.Background(), 8, 123) + require.NoError(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCancelTask(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(6))) + + ok, err := repo.CancelTask(context.Background(), 6, 9) + require.NoError(t, err) + require.True(t, ok) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryCancelTaskNoRows(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + mock.ExpectQuery("UPDATE usage_cleanup_tasks"). + WithArgs(service.UsageCleanupStatusCanceled, int64(6), int64(9), service.UsageCleanupStatusPending, service.UsageCleanupStatusRunning). 
+ WillReturnRows(sqlmock.NewRows([]string{"id"})) + + ok, err := repo.CancelTask(context.Background(), 6, 9) + require.NoError(t, err) + require.False(t, ok) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryDeleteUsageLogsBatchMissingRange(t *testing.T) { + db, _ := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + _, err := repo.DeleteUsageLogsBatch(context.Background(), service.UsageCleanupFilters{}, 10) + require.Error(t, err) +} + +func TestUsageCleanupRepositoryDeleteUsageLogsBatch(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + userID := int64(3) + model := " gpt-4 " + filters := service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + Model: &model, + } + + mock.ExpectQuery("DELETE FROM usage_logs"). + WithArgs(start, end, userID, "gpt-4", 2). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(int64(1)).AddRow(int64(2))) + + deleted, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 2) + require.NoError(t, err) + require.Equal(t, int64(2), deleted) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestUsageCleanupRepositoryDeleteUsageLogsBatchQueryError(t *testing.T) { + db, mock := newSQLMock(t) + repo := &usageCleanupRepository{sql: db} + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + filters := service.UsageCleanupFilters{StartTime: start, EndTime: end} + + mock.ExpectQuery("DELETE FROM usage_logs"). + WithArgs(start, end, 5). + WillReturnError(sql.ErrConnDone) + + _, err := repo.DeleteUsageLogsBatch(context.Background(), filters, 5) + require.Error(t, err) + require.NoError(t, mock.ExpectationsWereMet()) +} + +func TestBuildUsageCleanupWhere(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + userID := int64(1) + apiKeyID := int64(2) + accountID := int64(3) + groupID := int64(4) + model := " gpt-4 " + stream := true + billingType := int8(2) + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + APIKeyID: &apiKeyID, + AccountID: &accountID, + GroupID: &groupID, + Model: &model, + Stream: &stream, + BillingType: &billingType, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2 AND user_id = $3 AND api_key_id = $4 AND account_id = $5 AND group_id = $6 AND model = $7 AND stream = $8 AND billing_type = $9", where) + require.Equal(t, []any{start, end, userID, apiKeyID, accountID, groupID, "gpt-4", stream, billingType}, args) +} + +func TestBuildUsageCleanupWhereModelEmpty(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + model := " " + + where, args := buildUsageCleanupWhere(service.UsageCleanupFilters{ + StartTime: start, + EndTime: end, + Model: &model, + }) + + require.Equal(t, "created_at >= $1 AND created_at <= $2", where) + require.Equal(t, []any{start, end}, args) +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 4a2aaade..963db7ba 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -1411,7 +1411,7 @@ func (r *usageLogRepository) GetBatchAPIKeyUsageStats(ctx context.Context, apiKe } // GetUsageTrendWithFilters returns usage trend data with optional filters -func (r 
*usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) (results []TrendDataPoint, err error) {
+func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) (results []TrendDataPoint, err error) {
 	dateFormat := "YYYY-MM-DD"
 	if granularity == "hour" {
 		dateFormat = "YYYY-MM-DD HH24:00"
@@ -1456,6 +1460,10 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
 		query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
 		args = append(args, *stream)
 	}
+	if billingType != nil {
+		query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
+		args = append(args, int16(*billingType))
+	}
 	query += " GROUP BY date ORDER BY date ASC"
 
 	rows, err := r.sql.QueryContext(ctx, query, args...)
@@ -1479,7 +1483,7 @@ func (r *usageLogRepository) GetUsageTrendWithFilters(ctx context.Context, start
 }
 
 // GetModelStatsWithFilters returns model statistics with optional filters
-func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) (results []ModelStat, err error) {
+func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) (results []ModelStat, err error) {
 	actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost"
 	// When aggregating by account_id alone, actual cost uses the account multiplier (total_cost * account_rate_multiplier).
 	if accountID > 0 && userID == 0 && apiKeyID == 0 {
@@ -1520,6 +1524,10 @@ func (r *usageLogRepository) GetModelStatsWithFilters(ctx context.Context, start
 		query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
 		args = append(args, *stream)
 	}
+	if billingType != nil {
+		query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
+		args = append(args, int16(*billingType))
+	}
 	query += " GROUP BY model ORDER BY total_tokens DESC"
 
 	rows, err := r.sql.QueryContext(ctx, query, args...)
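Both filtered queries above grow the SQL string and the positional-argument list in lockstep: each non-nil filter appends one "AND col = $N" clause, with N derived from the current argument count. A minimal standalone sketch of that pattern (the function and query text below are illustrative, not part of this patch):

package main

import "fmt"

// appendOptionalFilters mirrors the pattern used in the repository above:
// deriving the placeholder index from len(args)+1 keeps $N and the args
// slice aligned no matter which optional filters are set.
func appendOptionalFilters(query string, args []any, stream *bool, billingType *int8) (string, []any) {
	if stream != nil {
		query += fmt.Sprintf(" AND stream = $%d", len(args)+1)
		args = append(args, *stream)
	}
	if billingType != nil {
		// Widened to int16 before binding, as the repository does above.
		query += fmt.Sprintf(" AND billing_type = $%d", len(args)+1)
		args = append(args, int16(*billingType))
	}
	return query, args
}

func main() {
	stream := true
	bt := int8(2)
	q, args := appendOptionalFilters("SELECT model FROM usage_logs WHERE created_at >= $1", []any{"2024-01-01"}, &stream, &bt)
	fmt.Println(q)    // ... WHERE created_at >= $1 AND stream = $2 AND billing_type = $3
	fmt.Println(args) // [2024-01-01 true 2]
}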
@@ -1825,7 +1833,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID } } - models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil) + models, err := r.GetModelStatsWithFilters(ctx, startTime, endTime, 0, 0, accountID, 0, nil, nil) if err != nil { models = []ModelStat{} } diff --git a/backend/internal/repository/usage_log_repo_integration_test.go b/backend/internal/repository/usage_log_repo_integration_test.go index 7174be18..eb220f22 100644 --- a/backend/internal/repository/usage_log_repo_integration_test.go +++ b/backend/internal/repository/usage_log_repo_integration_test.go @@ -944,17 +944,17 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters() { endTime := base.Add(48 * time.Hour) // Test with user filter - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, 0, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters user filter") s.Require().Len(trend, 2) // Test with apiKey filter - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", 0, apiKey.ID, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters apiKey filter") s.Require().Len(trend, 2) // Test with both filters - trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil) + trend, err = s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "day", user.ID, apiKey.ID, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters both filters") s.Require().Len(trend, 2) } @@ -971,7 +971,7 @@ func (s *UsageLogRepoSuite) TestGetUsageTrendWithFilters_HourlyGranularity() { startTime := base.Add(-1 * time.Hour) endTime := base.Add(3 * time.Hour) - trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil) + trend, err := s.repo.GetUsageTrendWithFilters(s.ctx, startTime, endTime, "hour", user.ID, 0, 0, 0, "", nil, nil) s.Require().NoError(err, "GetUsageTrendWithFilters hourly") s.Require().Len(trend, 2) } @@ -1017,17 +1017,17 @@ func (s *UsageLogRepoSuite) TestGetModelStatsWithFilters() { endTime := base.Add(2 * time.Hour) // Test with user filter - stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil) + stats, err := s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, user.ID, 0, 0, 0, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters user filter") s.Require().Len(stats, 2) // Test with apiKey filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, apiKey.ID, 0, 0, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters apiKey filter") s.Require().Len(stats, 2) // Test with account filter - stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil) + stats, err = s.repo.GetModelStatsWithFilters(s.ctx, startTime, endTime, 0, 0, account.ID, 0, nil, nil) s.Require().NoError(err, "GetModelStatsWithFilters account filter") s.Require().Len(stats, 2) } diff --git a/backend/internal/repository/wire.go b/backend/internal/repository/wire.go index 77ed37e1..7a8d85f4 100644 --- a/backend/internal/repository/wire.go +++ 
b/backend/internal/repository/wire.go
@@ -57,6 +57,7 @@ var ProviderSet = wire.NewSet(
 	NewRedeemCodeRepository,
 	NewPromoCodeRepository,
 	NewUsageLogRepository,
+	NewUsageCleanupRepository,
 	NewDashboardAggregationRepository,
 	NewSettingRepository,
 	NewOpsRepository,
diff --git a/backend/internal/server/api_contract_test.go b/backend/internal/server/api_contract_test.go index 356b4a4e..4ce58942 100644 --- a/backend/internal/server/api_contract_test.go +++ b/backend/internal/server/api_contract_test.go @@ -51,7 +51,6 @@ func TestAPIContracts(t *testing.T) {
 				"id": 1,
 				"email": "alice@example.com",
 				"username": "alice",
-				"notes": "hello",
 				"role": "user",
 				"balance": 12.5,
 				"concurrency": 5,
@@ -131,6 +130,153 @@ func TestAPIContracts(t *testing.T) {
 				}
 			}`,
 		},
+		{
+			name: "GET /api/v1/groups/available",
+			setup: func(t *testing.T, deps *contractDeps) {
+				t.Helper()
+				// The group list visible to regular users must not expose internal fields (e.g. model_routing/account_count).
+				deps.groupRepo.SetActive([]service.Group{
+					{
+						ID:                  10,
+						Name:                "Group One",
+						Description:         "desc",
+						Platform:            service.PlatformAnthropic,
+						RateMultiplier:      1.5,
+						IsExclusive:         false,
+						Status:              service.StatusActive,
+						SubscriptionType:    service.SubscriptionTypeStandard,
+						ModelRoutingEnabled: true,
+						ModelRouting: map[string][]int64{
+							"claude-3-*": []int64{101, 102},
+						},
+						AccountCount: 2,
+						CreatedAt:    deps.now,
+						UpdatedAt:    deps.now,
+					},
+				})
+				deps.userSubRepo.SetActiveByUserID(1, nil)
+			},
+			method:     http.MethodGet,
+			path:       "/api/v1/groups/available",
+			wantStatus: http.StatusOK,
+			wantJSON: `{
+				"code": 0,
+				"message": "success",
+				"data": [
+					{
+						"id": 10,
+						"name": "Group One",
+						"description": "desc",
+						"platform": "anthropic",
+						"rate_multiplier": 1.5,
+						"is_exclusive": false,
+						"status": "active",
+						"subscription_type": "standard",
+						"daily_limit_usd": null,
+						"weekly_limit_usd": null,
+						"monthly_limit_usd": null,
+						"image_price_1k": null,
+						"image_price_2k": null,
+						"image_price_4k": null,
+						"claude_code_only": false,
+						"fallback_group_id": null,
+						"created_at": "2025-01-02T03:04:05Z",
+						"updated_at": "2025-01-02T03:04:05Z"
+					}
+				]
+			}`,
+		},
+		{
+			name: "GET /api/v1/subscriptions",
+			setup: func(t *testing.T, deps *contractDeps) {
+				t.Helper()
+				// The user-facing subscription endpoint must not expose admin-only fields such as assigned_* / notes.
+				deps.userSubRepo.SetByUserID(1, []service.UserSubscription{
+					{
+						ID:              501,
+						UserID:          1,
+						GroupID:         10,
+						StartsAt:        deps.now,
+						ExpiresAt:       deps.now.Add(24 * time.Hour),
+						Status:          service.SubscriptionStatusActive,
+						DailyUsageUSD:   1.23,
+						WeeklyUsageUSD:  2.34,
+						MonthlyUsageUSD: 3.45,
+						AssignedBy:      ptr(int64(999)),
+						AssignedAt:      deps.now,
+						Notes:           "admin-note",
+						CreatedAt:       deps.now,
+						UpdatedAt:       deps.now,
+					},
+				})
+			},
+			method:     http.MethodGet,
+			path:       "/api/v1/subscriptions",
+			wantStatus: http.StatusOK,
+			wantJSON: `{
+				"code": 0,
+				"message": "success",
+				"data": [
+					{
+						"id": 501,
+						"user_id": 1,
+						"group_id": 10,
+						"starts_at": "2025-01-02T03:04:05Z",
+						"expires_at": "2025-01-03T03:04:05Z",
+						"status": "active",
+						"daily_window_start": null,
+						"weekly_window_start": null,
+						"monthly_window_start": null,
+						"daily_usage_usd": 1.23,
+						"weekly_usage_usd": 2.34,
+						"monthly_usage_usd": 3.45,
+						"created_at": "2025-01-02T03:04:05Z",
+						"updated_at": "2025-01-02T03:04:05Z"
+					}
+				]
+			}`,
+		},
+		{
+			name: "GET /api/v1/redeem/history",
+			setup: func(t *testing.T, deps *contractDeps) {
+				t.Helper()
+				// The user-facing redeem history must not expose internal fields such as notes.
+				deps.redeemRepo.SetByUser(1, []service.RedeemCode{
+					{
+						ID:     900,
+						Code:   "CODE-123",
+						Type:   service.RedeemTypeBalance,
+						Value:  1.25,
+						Status: service.StatusUsed,
+
UsedBy: ptr(int64(1)), + UsedAt: ptr(deps.now), + Notes: "internal-note", + CreatedAt: deps.now, + }, + }) + }, + method: http.MethodGet, + path: "/api/v1/redeem/history", + wantStatus: http.StatusOK, + wantJSON: `{ + "code": 0, + "message": "success", + "data": [ + { + "id": 900, + "code": "CODE-123", + "type": "balance", + "value": 1.25, + "status": "used", + "used_by": 1, + "used_at": "2025-01-02T03:04:05Z", + "created_at": "2025-01-02T03:04:05Z", + "group_id": null, + "validity_days": 0 + } + ] + }`, + }, { name: "GET /api/v1/usage/stats", setup: func(t *testing.T, deps *contractDeps) { @@ -190,24 +336,25 @@ func TestAPIContracts(t *testing.T) { t.Helper() deps.usageRepo.SetUserLogs(1, []service.UsageLog{ { - ID: 1, - UserID: 1, - APIKeyID: 100, - AccountID: 200, - RequestID: "req_123", - Model: "claude-3", - InputTokens: 10, - OutputTokens: 20, - CacheCreationTokens: 1, - CacheReadTokens: 2, - TotalCost: 0.5, - ActualCost: 0.5, - RateMultiplier: 1, - BillingType: service.BillingTypeBalance, - Stream: true, - DurationMs: ptr(100), - FirstTokenMs: ptr(50), - CreatedAt: deps.now, + ID: 1, + UserID: 1, + APIKeyID: 100, + AccountID: 200, + AccountRateMultiplier: ptr(0.5), + RequestID: "req_123", + Model: "claude-3", + InputTokens: 10, + OutputTokens: 20, + CacheCreationTokens: 1, + CacheReadTokens: 2, + TotalCost: 0.5, + ActualCost: 0.5, + RateMultiplier: 1, + BillingType: service.BillingTypeBalance, + Stream: true, + DurationMs: ptr(100), + FirstTokenMs: ptr(50), + CreatedAt: deps.now, }, }) }, @@ -238,10 +385,9 @@ func TestAPIContracts(t *testing.T) { "output_cost": 0, "cache_creation_cost": 0, "cache_read_cost": 0, - "total_cost": 0.5, + "total_cost": 0.5, "actual_cost": 0.5, "rate_multiplier": 1, - "account_rate_multiplier": null, "billing_type": 0, "stream": true, "duration_ms": 100, @@ -337,7 +483,8 @@ func TestAPIContracts(t *testing.T) { "fallback_model_openai": "gpt-4o", "enable_identity_patch": true, "identity_patch_prompt": "", - "home_content": "" + "home_content": "", + "hide_ccs_import_button": false } }`, }, @@ -385,8 +532,11 @@ type contractDeps struct { now time.Time router http.Handler apiKeyRepo *stubApiKeyRepo + groupRepo *stubGroupRepo + userSubRepo *stubUserSubscriptionRepo usageRepo *stubUsageLogRepo settingRepo *stubSettingRepo + redeemRepo *stubRedeemCodeRepo } func newContractDeps(t *testing.T) *contractDeps { @@ -414,11 +564,11 @@ func newContractDeps(t *testing.T) *contractDeps { apiKeyRepo := newStubApiKeyRepo(now) apiKeyCache := stubApiKeyCache{} - groupRepo := stubGroupRepo{} - userSubRepo := stubUserSubscriptionRepo{} + groupRepo := &stubGroupRepo{} + userSubRepo := &stubUserSubscriptionRepo{} accountRepo := stubAccountRepo{} proxyRepo := stubProxyRepo{} - redeemRepo := stubRedeemCodeRepo{} + redeemRepo := &stubRedeemCodeRepo{} cfg := &config.Config{ Default: config.DefaultConfig{ @@ -433,6 +583,12 @@ func newContractDeps(t *testing.T) *contractDeps { usageRepo := newStubUsageLogRepo() usageService := service.NewUsageService(usageRepo, userRepo, nil, nil) + subscriptionService := service.NewSubscriptionService(groupRepo, userSubRepo, nil) + subscriptionHandler := handler.NewSubscriptionHandler(subscriptionService) + + redeemService := service.NewRedeemService(redeemRepo, userRepo, subscriptionService, nil, nil, nil, nil) + redeemHandler := handler.NewRedeemHandler(redeemService) + settingRepo := newStubSettingRepo() settingService := service.NewSettingService(settingRepo, cfg) @@ -441,7 +597,7 @@ func newContractDeps(t *testing.T) *contractDeps { 
apiKeyHandler := handler.NewAPIKeyHandler(apiKeyService) usageHandler := handler.NewUsageHandler(usageService, apiKeyService) adminSettingHandler := adminhandler.NewSettingHandler(settingService, nil, nil, nil) - adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) + adminAccountHandler := adminhandler.NewAccountHandler(adminService, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil, nil) jwtAuth := func(c *gin.Context) { c.Set(string(middleware.ContextKeyUser), middleware.AuthSubject{ @@ -472,12 +628,21 @@ func newContractDeps(t *testing.T) *contractDeps { v1Keys.Use(jwtAuth) v1Keys.GET("/keys", apiKeyHandler.List) v1Keys.POST("/keys", apiKeyHandler.Create) + v1Keys.GET("/groups/available", apiKeyHandler.GetAvailableGroups) v1Usage := v1.Group("") v1Usage.Use(jwtAuth) v1Usage.GET("/usage", usageHandler.List) v1Usage.GET("/usage/stats", usageHandler.Stats) + v1Subs := v1.Group("") + v1Subs.Use(jwtAuth) + v1Subs.GET("/subscriptions", subscriptionHandler.List) + + v1Redeem := v1.Group("") + v1Redeem.Use(jwtAuth) + v1Redeem.GET("/redeem/history", redeemHandler.GetHistory) + v1Admin := v1.Group("/admin") v1Admin.Use(adminAuth) v1Admin.GET("/settings", adminSettingHandler.GetSettings) @@ -487,8 +652,11 @@ func newContractDeps(t *testing.T) *contractDeps { now: now, router: r, apiKeyRepo: apiKeyRepo, + groupRepo: groupRepo, + userSubRepo: userSubRepo, usageRepo: usageRepo, settingRepo: settingRepo, + redeemRepo: redeemRepo, } } @@ -618,7 +786,21 @@ func (stubApiKeyCache) DeleteAuthCache(ctx context.Context, key string) error { return nil } -type stubGroupRepo struct{} +func (stubApiKeyCache) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return nil +} + +func (stubApiKeyCache) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + return nil +} + +type stubGroupRepo struct { + active []service.Group +} + +func (r *stubGroupRepo) SetActive(groups []service.Group) { + r.active = append([]service.Group(nil), groups...) 
+} func (stubGroupRepo) Create(ctx context.Context, group *service.Group) error { return errors.New("not implemented") @@ -652,12 +834,19 @@ func (stubGroupRepo) ListWithFilters(ctx context.Context, params pagination.Pagi return nil, nil, errors.New("not implemented") } -func (stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { - return nil, errors.New("not implemented") +func (r *stubGroupRepo) ListActive(ctx context.Context) ([]service.Group, error) { + return append([]service.Group(nil), r.active...), nil } -func (stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { - return nil, errors.New("not implemented") +func (r *stubGroupRepo) ListActiveByPlatform(ctx context.Context, platform string) ([]service.Group, error) { + out := make([]service.Group, 0, len(r.active)) + for i := range r.active { + g := r.active[i] + if g.Platform == platform { + out = append(out, g) + } + } + return out, nil } func (stubGroupRepo) ExistsByName(ctx context.Context, name string) (bool, error) { @@ -736,6 +925,10 @@ func (s *stubAccountRepo) SetError(ctx context.Context, id int64, errorMsg strin return errors.New("not implemented") } +func (s *stubAccountRepo) ClearError(ctx context.Context, id int64) error { + return errors.New("not implemented") +} + func (s *stubAccountRepo) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return errors.New("not implemented") } @@ -871,7 +1064,16 @@ func (stubProxyRepo) ListAccountSummariesByProxyID(ctx context.Context, proxyID return nil, errors.New("not implemented") } -type stubRedeemCodeRepo struct{} +type stubRedeemCodeRepo struct { + byUser map[int64][]service.RedeemCode +} + +func (r *stubRedeemCodeRepo) SetByUser(userID int64, codes []service.RedeemCode) { + if r.byUser == nil { + r.byUser = make(map[int64][]service.RedeemCode) + } + r.byUser[userID] = append([]service.RedeemCode(nil), codes...) +} func (stubRedeemCodeRepo) Create(ctx context.Context, code *service.RedeemCode) error { return errors.New("not implemented") @@ -909,11 +1111,35 @@ func (stubRedeemCodeRepo) ListWithFilters(ctx context.Context, params pagination return nil, nil, errors.New("not implemented") } -func (stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) { - return nil, errors.New("not implemented") +func (r *stubRedeemCodeRepo) ListByUser(ctx context.Context, userID int64, limit int) ([]service.RedeemCode, error) { + if r.byUser == nil { + return nil, nil + } + codes := r.byUser[userID] + if limit > 0 && len(codes) > limit { + codes = codes[:limit] + } + return append([]service.RedeemCode(nil), codes...), nil } -type stubUserSubscriptionRepo struct{} +type stubUserSubscriptionRepo struct { + byUser map[int64][]service.UserSubscription + activeByUser map[int64][]service.UserSubscription +} + +func (r *stubUserSubscriptionRepo) SetByUserID(userID int64, subs []service.UserSubscription) { + if r.byUser == nil { + r.byUser = make(map[int64][]service.UserSubscription) + } + r.byUser[userID] = append([]service.UserSubscription(nil), subs...) +} + +func (r *stubUserSubscriptionRepo) SetActiveByUserID(userID int64, subs []service.UserSubscription) { + if r.activeByUser == nil { + r.activeByUser = make(map[int64][]service.UserSubscription) + } + r.activeByUser[userID] = append([]service.UserSubscription(nil), subs...) 
+} func (stubUserSubscriptionRepo) Create(ctx context.Context, sub *service.UserSubscription) error { return errors.New("not implemented") @@ -933,11 +1159,17 @@ func (stubUserSubscriptionRepo) Update(ctx context.Context, sub *service.UserSub func (stubUserSubscriptionRepo) Delete(ctx context.Context, id int64) error { return errors.New("not implemented") } -func (stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { - return nil, errors.New("not implemented") +func (r *stubUserSubscriptionRepo) ListByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + if r.byUser == nil { + return nil, nil + } + return append([]service.UserSubscription(nil), r.byUser[userID]...), nil } -func (stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { - return nil, errors.New("not implemented") +func (r *stubUserSubscriptionRepo) ListActiveByUserID(ctx context.Context, userID int64) ([]service.UserSubscription, error) { + if r.activeByUser == nil { + return nil, nil + } + return append([]service.UserSubscription(nil), r.activeByUser[userID]...), nil } func (stubUserSubscriptionRepo) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]service.UserSubscription, *pagination.PaginationResult, error) { return nil, nil, errors.New("not implemented") @@ -1242,11 +1474,11 @@ func (r *stubUsageLogRepo) GetDashboardStats(ctx context.Context) (*usagestats.D return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { +func (r *stubUsageLogRepo) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { return nil, errors.New("not implemented") } -func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { +func (r *stubUsageLogRepo) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { return nil, errors.New("not implemented") } diff --git a/backend/internal/server/routes/admin.go b/backend/internal/server/routes/admin.go index ff05b32a..050e724d 100644 --- a/backend/internal/server/routes/admin.go +++ b/backend/internal/server/routes/admin.go @@ -354,6 +354,9 @@ func registerUsageRoutes(admin *gin.RouterGroup, h *handler.Handlers) { usage.GET("/stats", h.Admin.Usage.Stats) usage.GET("/search-users", h.Admin.Usage.SearchUsers) usage.GET("/search-api-keys", h.Admin.Usage.SearchAPIKeys) + usage.GET("/cleanup-tasks", h.Admin.Usage.ListCleanupTasks) + usage.POST("/cleanup-tasks", h.Admin.Usage.CreateCleanupTask) + usage.POST("/cleanup-tasks/:id/cancel", h.Admin.Usage.CancelCleanupTask) } } diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 4fda300e..e710560f 100644 --- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -592,6 +592,44 @@ func (a *Account) IsAnthropicOAuthOrSetupToken() bool { return a.Platform == 
PlatformAnthropic && (a.Type == AccountTypeOAuth || a.Type == AccountTypeSetupToken) }
+// IsTLSFingerprintEnabled reports whether TLS fingerprint masquerading is enabled.
+// Applies only to Anthropic OAuth/SetupToken accounts.
+// When enabled, requests imitate the TLS handshake characteristics of the Claude Code (Node.js) client.
+func (a *Account) IsTLSFingerprintEnabled() bool {
+	// Only Anthropic OAuth/SetupToken accounts are supported
+	if !a.IsAnthropicOAuthOrSetupToken() {
+		return false
+	}
+	if a.Extra == nil {
+		return false
+	}
+	if v, ok := a.Extra["enable_tls_fingerprint"]; ok {
+		if enabled, ok := v.(bool); ok {
+			return enabled
+		}
+	}
+	return false
+}
+
+// IsSessionIDMaskingEnabled reports whether session ID masking is enabled.
+// Applies only to Anthropic OAuth/SetupToken accounts.
+// When enabled, the session ID in metadata.user_id is pinned for a period (15 minutes)
+// so the upstream treats the requests as coming from a single session.
+func (a *Account) IsSessionIDMaskingEnabled() bool {
+	if !a.IsAnthropicOAuthOrSetupToken() {
+		return false
+	}
+	if a.Extra == nil {
+		return false
+	}
+	if v, ok := a.Extra["session_id_masking_enabled"]; ok {
+		if enabled, ok := v.(bool); ok {
+			return enabled
+		}
+	}
+	return false
+}
+
 // GetWindowCostLimit returns the 5h-window cost threshold in USD
 // A return value of 0 means the limit is disabled
 func (a *Account) GetWindowCostLimit() float64 {
@@ -668,6 +706,23 @@ func (a *Account) CheckWindowCostSchedulability(currentWindowCost float64) Windo
 	return WindowCostNotSchedulable
 }
 
+// GetCurrentWindowStartTime returns the currently effective window start time.
+// Logic:
+// 1. If the window has not expired (SessionWindowEnd is set and later than now), use the recorded SessionWindowStart.
+// 2. Otherwise (window expired or unset), use a freshly predicted window start (the top of the current hour).
+func (a *Account) GetCurrentWindowStartTime() time.Time {
+	now := time.Now()
+
+	// Window still valid: use the recorded window start time
+	if a.SessionWindowStart != nil && a.SessionWindowEnd != nil && now.Before(*a.SessionWindowEnd) {
+		return *a.SessionWindowStart
+	}
+
+	// Window expired or unset: predict a new window start (top of the current hour).
+	// Keep this consistent with the prediction logic of UpdateSessionWindow in ratelimit_service.go.
+	return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
+}
+
 // parseExtraFloat64 parses a float64 value from the extra field
 func parseExtraFloat64(value any) float64 {
 	switch v := value.(type) {
diff --git a/backend/internal/service/account_service.go b/backend/internal/service/account_service.go index ede5b12f..90365d2f 100644 --- a/backend/internal/service/account_service.go +++ b/backend/internal/service/account_service.go @@ -37,6 +37,7 @@ type AccountRepository interface {
 	UpdateLastUsed(ctx context.Context, id int64) error
 	BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
 	SetError(ctx context.Context, id int64, errorMsg string) error
+	ClearError(ctx context.Context, id int64) error
 	SetSchedulable(ctx context.Context, id int64, schedulable bool) error
 	AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
 	BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
diff --git a/backend/internal/service/account_service_delete_test.go b/backend/internal/service/account_service_delete_test.go index 36af719c..e5eabfc6 100644 --- a/backend/internal/service/account_service_delete_test.go +++ b/backend/internal/service/account_service_delete_test.go @@ -99,6 +99,10 @@ func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg strin
 	panic("unexpected SetError call")
 }
 
+func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
+	panic("unexpected ClearError call")
+}
+
 func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
 	panic("unexpected SetSchedulable call")
 }
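Two details of the Account additions above are easy to miss: the Extra-map flags default to false for any missing or non-bool value, and GetCurrentWindowStartTime snaps an expired or unset window to the top of the current hour with time.Date, staying in the account's own time zone. A standalone sketch of that prediction rule (the harness below is illustrative; the real method lives on Account):

package main

import (
	"fmt"
	"time"
)

// windowStart mirrors the prediction rule: keep a still-valid recorded
// window start, otherwise snap to the top of the current hour.
func windowStart(start, end *time.Time, now time.Time) time.Time {
	if start != nil && end != nil && now.Before(*end) {
		return *start
	}
	return time.Date(now.Year(), now.Month(), now.Day(), now.Hour(), 0, 0, 0, now.Location())
}

func main() {
	now := time.Date(2025, 1, 2, 3, 42, 7, 0, time.UTC)

	// No window recorded: predict from the current hour.
	fmt.Println(windowStart(nil, nil, now)) // 2025-01-02 03:00:00 +0000 UTC

	// Recorded window still valid: keep its start.
	s, e := now.Add(-2*time.Hour), now.Add(3*time.Hour)
	fmt.Println(windowStart(&s, &e, now)) // 2025-01-02 01:42:07 +0000 UTC
}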
diff --git a/backend/internal/service/account_test_service.go b/backend/internal/service/account_test_service.go index 8419c2b4..46376c69 100644 --- a/backend/internal/service/account_test_service.go +++ b/backend/internal/service/account_test_service.go @@ -265,7 +265,7 @@ func (s *AccountTestService) testClaudeAccountConnection(c *gin.Context, account
 		proxyURL = account.Proxy.URL()
 	}
 
-	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
 	if err != nil {
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
 	}
@@ -375,7 +375,7 @@ func (s *AccountTestService) testOpenAIAccountConnection(c *gin.Context, account
 		proxyURL = account.Proxy.URL()
 	}
 
-	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
 	if err != nil {
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
 	}
@@ -446,7 +446,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
 		proxyURL = account.Proxy.URL()
 	}
 
-	resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
+	resp, err := s.httpUpstream.DoWithTLS(req, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
 	if err != nil {
 		return s.sendErrorAndEnd(c, fmt.Sprintf("Request failed: %s", err.Error()))
 	}
diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index 6f012385..f3b3e20d 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -32,8 +32,8 @@ type UsageLogRepository interface {
 	// Admin dashboard stats
 	GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
-	GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error)
-	GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error)
+	GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error)
+	GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error)
 	GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
 	GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
 	GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
@@ -157,9 +157,20 @@ type ClaudeUsageResponse struct {
 	} `json:"seven_day_sonnet"`
 }
 
+// ClaudeUsageFetchOptions holds every option needed to fetch Claude usage data
+type ClaudeUsageFetchOptions struct {
+	AccessToken          string       // OAuth access token
+	ProxyURL             string       // proxy URL (optional)
+	AccountID            int64        // account ID (used to select the TLS fingerprint)
+	EnableTLSFingerprint bool         // whether TLS fingerprint masquerading is enabled
+	Fingerprint          *Fingerprint // cached fingerprint info (User-Agent, etc.)
+}
+
 // ClaudeUsageFetcher fetches usage data from Anthropic OAuth API
 type ClaudeUsageFetcher interface {
 	FetchUsage(ctx context.Context, accessToken, proxyURL string) (*ClaudeUsageResponse, error)
+	// FetchUsageWithOptions fetches usage data with the full option set, supporting TLS fingerprints and a custom User-Agent
+	FetchUsageWithOptions(ctx context.Context, opts *ClaudeUsageFetchOptions) (*ClaudeUsageResponse, error)
 }
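With ClaudeUsageFetcher growing a second entry point, an implementation can keep the legacy FetchUsage as a thin wrapper that fills a default option set and delegates, so both paths share one request builder. A sketch of that delegation under assumed, simplified types (the real fetcher lives in the repository package; nothing here is its actual implementation):

package usage

import "context"

// Response and Options are simplified stand-ins for ClaudeUsageResponse
// and ClaudeUsageFetchOptions.
type Response struct{}

type Options struct {
	AccessToken          string
	ProxyURL             string
	AccountID            int64
	EnableTLSFingerprint bool
}

type fetcher struct{}

// FetchUsage keeps the legacy signature working by delegating to
// FetchUsageWithOptions with a default option set.
func (f *fetcher) FetchUsage(ctx context.Context, accessToken, proxyURL string) (*Response, error) {
	return f.FetchUsageWithOptions(ctx, &Options{AccessToken: accessToken, ProxyURL: proxyURL})
}

func (f *fetcher) FetchUsageWithOptions(ctx context.Context, opts *Options) (*Response, error) {
	// A real implementation would build the HTTP request here and pick a
	// TLS-fingerprinted transport when opts.EnableTLSFingerprint is set.
	return &Response{}, nil
}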
 
 // AccountUsageService provides account usage queries
@@ -170,6 +181,7 @@ type AccountUsageService struct {
 	geminiQuotaService      *GeminiQuotaService
 	antigravityQuotaFetcher *AntigravityQuotaFetcher
 	cache                   *UsageCache
+	identityCache           IdentityCache
 }
 
 // NewAccountUsageService creates an AccountUsageService instance
@@ -180,6 +192,7 @@ func NewAccountUsageService(
 	geminiQuotaService *GeminiQuotaService,
 	antigravityQuotaFetcher *AntigravityQuotaFetcher,
 	cache *UsageCache,
+	identityCache IdentityCache,
 ) *AccountUsageService {
 	return &AccountUsageService{
 		accountRepo: accountRepo,
@@ -188,6 +201,7 @@ func NewAccountUsageService(
 		antigravityQuotaFetcher: antigravityQuotaFetcher,
 		cache:                   cache,
+		identityCache:           identityCache,
 	}
 }
 
@@ -272,7 +286,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
 	}
 	dayStart := geminiDailyWindowStart(now)
-	stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil)
+	stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID, 0, nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
 	}
@@ -294,7 +308,7 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
 	// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
 	minuteStart := now.Truncate(time.Minute)
 	minuteResetAt := minuteStart.Add(time.Minute)
-	minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil)
+	minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID, 0, nil, nil)
 	if err != nil {
 		return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
 	}
@@ -369,12 +383,8 @@ func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Accou
 	// If nothing is cached, query the database
 	if windowStats == nil {
-		var startTime time.Time
-		if account.SessionWindowStart != nil {
-			startTime = *account.SessionWindowStart
-		} else {
-			startTime = time.Now().Add(-5 * time.Hour)
-		}
+		// Use the unified window-start computation (which accounts for window expiry)
+		startTime := account.GetCurrentWindowStartTime()
 
 		stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime)
 		if err != nil {
@@ -428,6 +438,8 @@ func (s *AccountUsageService) GetAccountUsageStats(ctx context.Context, accountI
 }
 
 // fetchOAuthUsageRaw fetches the raw response from the Anthropic API (without building UsageInfo)
+// If the account has TLS fingerprinting enabled, TLS fingerprint masquerading is used.
+// If a cached Fingerprint exists, its cached User-Agent and related fields are used.
 func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *Account) (*ClaudeUsageResponse, error) {
 	accessToken := account.GetCredential("access_token")
 	if accessToken == "" {
@@ -439,7 +451,22 @@ func (s *AccountUsageService) fetchOAuthUsageRaw(ctx context.Context, account *A
 		proxyURL = account.Proxy.URL()
 	}
 
-	return s.usageFetcher.FetchUsage(ctx, accessToken, proxyURL)
+	// Build the full option set
+	opts := &ClaudeUsageFetchOptions{
+		AccessToken:          accessToken,
+		ProxyURL:             proxyURL,
+		AccountID:            account.ID,
+		EnableTLSFingerprint: account.IsTLSFingerprintEnabled(),
+	}
+
+	// Try to load the cached Fingerprint (User-Agent and related info)
+	if s.identityCache != nil {
+		if fp, err := s.identityCache.GetFingerprint(ctx, account.ID); err == nil && fp != nil {
+			opts.Fingerprint = fp
+		}
+	}
+
+	return s.usageFetcher.FetchUsageWithOptions(ctx, opts)
 }
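Note how the fingerprint lookup in fetchOAuthUsageRaw is strictly best-effort: a nil cache, a cache error, or a miss all leave opts.Fingerprint nil and the fetch proceeds with default headers instead of failing. The same nil-tolerant enrichment step in isolation (the types are stand-ins for IdentityCache and Fingerprint, not the real ones):

package usage

import "context"

type Fingerprint struct{ UserAgent string }

// identityCache is a stand-in for the repository's IdentityCache.
type identityCache interface {
	GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error)
}

type fetchOptions struct{ Fingerprint *Fingerprint }

// enrich attaches a cached fingerprint when one is available and silently
// skips on nil cache, miss, or error, so usage fetching never fails
// because of the cache.
func enrich(ctx context.Context, cache identityCache, accountID int64, opts *fetchOptions) {
	if cache == nil {
		return
	}
	if fp, err := cache.GetFingerprint(ctx, accountID); err == nil && fp != nil {
		opts.Fingerprint = fp
	}
}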
 
 // parseTime tries multiple formats when parsing a time
diff --git a/backend/internal/service/admin_service.go b/backend/internal/service/admin_service.go index c0694e4e..0afa0716 100644 --- a/backend/internal/service/admin_service.go +++ b/backend/internal/service/admin_service.go @@ -42,6 +42,7 @@ type AdminService interface {
 	DeleteAccount(ctx context.Context, id int64) error
 	RefreshAccountCredentials(ctx context.Context, id int64) (*Account, error)
 	ClearAccountError(ctx context.Context, id int64) (*Account, error)
+	SetAccountError(ctx context.Context, id int64, errorMsg string) error
 	SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error)
 	BulkUpdateAccounts(ctx context.Context, input *BulkUpdateAccountsInput) (*BulkUpdateAccountsResult, error)
@@ -1101,6 +1102,10 @@ func (s *adminServiceImpl) ClearAccountError(ctx context.Context, id int64) (*Ac
 	return account, nil
 }
 
+func (s *adminServiceImpl) SetAccountError(ctx context.Context, id int64, errorMsg string) error {
+	return s.accountRepo.SetError(ctx, id, errorMsg)
+}
+
 func (s *adminServiceImpl) SetAccountSchedulable(ctx context.Context, id int64, schedulable bool) (*Account, error) {
 	if err := s.accountRepo.SetSchedulable(ctx, id, schedulable); err != nil {
 		return nil, err
diff --git a/backend/internal/service/antigravity_gateway_service.go b/backend/internal/service/antigravity_gateway_service.go index 7f3e97a2..043f338d 100644 --- a/backend/internal/service/antigravity_gateway_service.go +++ b/backend/internal/service/antigravity_gateway_service.go @@ -12,6 +12,7 @@ import (
 	mathrand "math/rand"
 	"net"
 	"net/http"
+	"os"
 	"strings"
 	"sync/atomic"
 	"time"
@@ -28,6 +29,207 @@ const (
 	antigravityRetryMaxDelay = 16 * time.Second
 )
 
+const antigravityScopeRateLimitEnv = "GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT"
+
+// antigravityRetryLoopParams holds the parameters of the retry loop
+type antigravityRetryLoopParams struct {
+	ctx            context.Context
+	prefix         string
+	account        *Account
+	proxyURL       string
+	accessToken    string
+	action         string
+	body           []byte
+	quotaScope     AntigravityQuotaScope
+	c              *gin.Context
+	httpUpstream   HTTPUpstream
+	settingService *SettingService
+	handleError    func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope)
+}
+
+// antigravityRetryLoopResult holds the result of the retry loop
+type antigravityRetryLoopResult struct {
+	resp *http.Response
+}
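Bundling a dozen positional arguments into antigravityRetryLoopParams is what lets Forward and ForwardGemini (later in this diff) share one loop: call sites stay self-describing, optional collaborators such as the gin context can be nil, and the error handler is injected as a closure. A reduced sketch of the shape of such a loop (names and fields here are illustrative, not the repository's):

package gateway

import (
	"context"
	"fmt"
)

type retryParams struct {
	ctx      context.Context
	prefix   string
	body     []byte
	maxTries int
	// handleError is injected so the loop itself stays transport-only.
	handleError func(status int, body []byte)
}

func retryLoop(p retryParams) error {
	for attempt := 1; attempt <= p.maxTries; attempt++ {
		// Bail out promptly if the client disconnected.
		select {
		case <-p.ctx.Done():
			return p.ctx.Err()
		default:
		}
		// ... issue the upstream request with p.body here ...
	}
	return fmt.Errorf("%s: retries exhausted", p.prefix)
}

Named fields also make later additions non-breaking for existing call sites, which is exactly how quotaScope and handleError slot in here.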
+
+// antigravityRetryLoop runs the retry loop with URL fallback
+func antigravityRetryLoop(p antigravityRetryLoopParams) (*antigravityRetryLoopResult, error) {
+	availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
+	if len(availableURLs) == 0 {
+		availableURLs = antigravity.BaseURLs
+	}
+
+	var resp *http.Response
+	var usedBaseURL string
+	logBody := p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBody
+	maxBytes := 2048
+	if p.settingService != nil && p.settingService.cfg != nil && p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 {
+		maxBytes = p.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
+	}
+	getUpstreamDetail := func(body []byte) string {
+		if !logBody {
+			return ""
+		}
+		return truncateString(string(body), maxBytes)
+	}
+
+urlFallbackLoop:
+	for urlIdx, baseURL := range availableURLs {
+		usedBaseURL = baseURL
+		for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
+			select {
+			case <-p.ctx.Done():
+				log.Printf("%s status=context_canceled error=%v", p.prefix, p.ctx.Err())
+				return nil, p.ctx.Err()
+			default:
+			}
+
+			upstreamReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, p.body)
+			if err != nil {
+				return nil, err
+			}
+
+			// Capture upstream request body for ops retry of this attempt.
+			if p.c != nil && len(p.body) > 0 {
+				p.c.Set(OpsUpstreamRequestBodyKey, string(p.body))
+			}
+
+			resp, err = p.httpUpstream.Do(upstreamReq, p.proxyURL, p.account.ID, p.account.Concurrency)
+			if err != nil {
+				safeErr := sanitizeUpstreamErrorMessage(err.Error())
+				appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
+					Platform:           p.account.Platform,
+					AccountID:          p.account.ID,
+					AccountName:        p.account.Name,
+					UpstreamStatusCode: 0,
+					Kind:               "request_error",
+					Message:            safeErr,
+				})
+				if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
+					log.Printf("%s URL fallback (connection error): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
+					continue urlFallbackLoop
+				}
+				if attempt < antigravityMaxRetries {
+					log.Printf("%s status=request_failed retry=%d/%d error=%v", p.prefix, attempt, antigravityMaxRetries, err)
+					if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
+						log.Printf("%s status=context_canceled_during_backoff", p.prefix)
+						return nil, p.ctx.Err()
+					}
+					continue
+				}
+				log.Printf("%s status=request_failed retries_exhausted error=%v", p.prefix, err)
+				setOpsUpstreamError(p.c, 0, safeErr, "")
+				return nil, fmt.Errorf("upstream request failed after retries: %w", err)
+			}
+
+			// 429 rate-limit handling: distinguish URL-level limits from account quota limits
+			if resp.StatusCode == http.StatusTooManyRequests {
+				respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+				_ = resp.Body.Close()
+
+				// "Resource has been exhausted" is URL-level rate limiting: switch URLs
+				if isURLLevelRateLimit(respBody) && urlIdx < len(availableURLs)-1 {
+					log.Printf("%s URL fallback (429): %s -> %s", p.prefix, baseURL, availableURLs[urlIdx+1])
+					continue urlFallbackLoop
+				}
+
+				// Account/model quota limit: retry 3 times with exponential backoff
+				if attempt < antigravityMaxRetries {
+					upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
+					upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+					appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
+						Platform:           p.account.Platform,
+						AccountID:          p.account.ID,
+						AccountName:        p.account.Name,
+						UpstreamStatusCode: resp.StatusCode,
+						UpstreamRequestID:  resp.Header.Get("x-request-id"),
+						Kind:               "retry",
+						Message:            upstreamMsg,
+						Detail:             getUpstreamDetail(respBody),
+					})
+					log.Printf("%s status=429 retry=%d/%d body=%s", p.prefix, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
+					if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
+						log.Printf("%s status=context_canceled_during_backoff", p.prefix)
+						return nil, p.ctx.Err()
+					}
+					continue
+				}
+
+				// Retries exhausted: mark the account rate-limited
+				p.handleError(p.ctx, p.prefix, p.account, resp.StatusCode, resp.Header, respBody, p.quotaScope)
+				log.Printf("%s status=429 rate_limited base_url=%s body=%s", p.prefix, baseURL, truncateForLog(respBody, 200))
+				resp = &http.Response{
+					StatusCode: resp.StatusCode,
+					Header:     resp.Header.Clone(),
+					Body:       io.NopCloser(bytes.NewReader(respBody)),
+				}
+				break urlFallbackLoop
+			}
+
+			// Other retryable errors
+			if resp.StatusCode >= 400 && shouldRetryAntigravityError(resp.StatusCode) {
+				respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
+				_ = resp.Body.Close()
+
+				if attempt < antigravityMaxRetries {
+					upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
+					upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
+					appendOpsUpstreamError(p.c, OpsUpstreamErrorEvent{
+						Platform:           p.account.Platform,
+						AccountID:          p.account.ID,
+						AccountName:        p.account.Name,
+						UpstreamStatusCode: resp.StatusCode,
+						UpstreamRequestID:  resp.Header.Get("x-request-id"),
+						Kind:               "retry",
+						Message:            upstreamMsg,
+						Detail:             getUpstreamDetail(respBody),
+					})
+					log.Printf("%s status=%d retry=%d/%d body=%s", p.prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500))
+					if !sleepAntigravityBackoffWithContext(p.ctx, attempt) {
+						log.Printf("%s status=context_canceled_during_backoff", p.prefix)
+						return nil, p.ctx.Err()
+					}
+					continue
+				}
+				resp = &http.Response{
+					StatusCode: resp.StatusCode,
+					Header:     resp.Header.Clone(),
+					Body:       io.NopCloser(bytes.NewReader(respBody)),
+				}
+				break urlFallbackLoop
+			}
+
+			break urlFallbackLoop
+		}
+	}
+
+	if resp != nil && resp.StatusCode < 400 && usedBaseURL != "" {
+		antigravity.DefaultURLAvailability.MarkSuccess(usedBaseURL)
+	}
+
+	return &antigravityRetryLoopResult{resp: resp}, nil
+}
+
+// shouldRetryAntigravityError reports whether the status code should be retried
+func shouldRetryAntigravityError(statusCode int) bool {
+	switch statusCode {
+	case 429, 500, 502, 503, 504, 529:
+		return true
+	default:
+		return false
+	}
+}
+
+// isURLLevelRateLimit reports whether a 429 is URL-level rate limiting (worth retrying on another URL).
+// "Resource has been exhausted" is URL/node-level rate limiting; switching URLs may succeed.
+// "exhausted your capacity on this model" is account/model quota rate limiting; switching URLs will not help.
+func isURLLevelRateLimit(body []byte) bool {
+	// Quick check: contains "Resource has been exhausted" and does not contain "capacity on this model"
+	bodyStr := string(body)
+	return strings.Contains(bodyStr, "Resource has been exhausted") &&
+		!strings.Contains(bodyStr, "capacity on this model")
+}
+
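Because the 429 classification above is purely substring-based, it is cheap to pin down with examples: only the node-level message justifies burning another base URL, while per-model capacity exhaustion falls through to backoff and, on the last attempt, handleError. A small demonstration against the two message families named in the comments (the exact JSON wrappers are illustrative):

package main

import (
	"fmt"
	"strings"
)

// isURLLevelRateLimit mirrors the classifier above.
func isURLLevelRateLimit(body []byte) bool {
	s := string(body)
	return strings.Contains(s, "Resource has been exhausted") &&
		!strings.Contains(s, "capacity on this model")
}

func main() {
	nodeLevel := []byte(`{"error":{"message":"Resource has been exhausted (e.g. check quota)."}}`)
	accountLevel := []byte(`{"error":{"message":"You have exhausted your capacity on this model."}}`)
	fmt.Println(isURLLevelRateLimit(nodeLevel))    // true: try the next base URL
	fmt.Println(isURLLevelRateLimit(accountLevel)) // false: back off, then rate-limit the account
}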
 // isAntigravityConnectionError reports whether the error is a connection error (network timeout, DNS failure, connection refused)
 func isAntigravityConnectionError(err error) bool {
 	if err == nil {
@@ -238,7 +440,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
 		if err != nil {
 			lastErr = fmt.Errorf("请求失败: %w", err)
 			if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 {
-				antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
 				log.Printf("[antigravity-Test] URL fallback: %s -> %s", baseURL, availableURLs[urlIdx+1])
 				continue
 			}
@@ -254,7 +455,6 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
 		// Check whether URL fallback is needed
 		if shouldAntigravityFallbackToNextURL(nil, resp.StatusCode) && urlIdx < len(availableURLs)-1 {
-			antigravity.DefaultURLAvailability.MarkUnavailable(baseURL)
 			log.Printf("[antigravity-Test] URL fallback (HTTP %d): %s -> %s", resp.StatusCode, baseURL, availableURLs[urlIdx+1])
 			continue
 		}
@@ -266,6 +466,8 @@ func (s *AntigravityGatewayService) TestConnection(ctx context.Context, account
 		// Parse the streaming response and extract the text
 		text := extractTextFromSSEResponse(respBody)
+		// Mark the successful URL so it is preferred next time
+		antigravity.DefaultURLAvailability.MarkSuccess(baseURL)
 		return &TestConnectionResult{
 			Text:        text,
 			MappedModel: mappedModel,
@@ -276,13 +478,14 @@
 }
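The test-request builders that follow shrink the connectivity probe to near-zero cost: a single "." as input and a hard cap of one output token. Stripped of the v1internal wrapping and the identity-patch systemInstruction, the Gemini payload they produce looks roughly like this (a sketch, not the exact wire bytes):

package main

import (
	"encoding/json"
	"fmt"
)

func main() {
	payload := map[string]any{
		"contents": []map[string]any{
			{"role": "user", "parts": []map[string]any{{"text": "."}}},
		},
		"generationConfig": map[string]any{
			"maxOutputTokens": 1, // cap the probe at one output token
		},
	}
	b, _ := json.MarshalIndent(payload, "", "  ")
	fmt.Println(string(b))
}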
 // buildGeminiTestRequest builds a Gemini-format test request
+// Minimal token cost: input "." plus maxOutputTokens: 1
 func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model string) ([]byte, error) {
 	payload := map[string]any{
 		"contents": []map[string]any{
 			{
 				"role": "user",
 				"parts": []map[string]any{
-					{"text": "hi"},
+					{"text": "."},
 				},
 			},
 		},
@@ -292,22 +495,26 @@ func (s *AntigravityGatewayService) buildGeminiTestRequest(projectID, model stri
 				{"text": antigravity.GetDefaultIdentityPatch()},
 			},
 		},
+		"generationConfig": map[string]any{
+			"maxOutputTokens": 1,
+		},
 	}
 	payloadBytes, _ := json.Marshal(payload)
 	return s.wrapV1InternalRequest(projectID, model, payloadBytes)
 }
 
 // buildClaudeTestRequest builds a Claude-format test request and converts it to Gemini format
+// Minimal token cost: input "." plus MaxTokens: 1
 func (s *AntigravityGatewayService) buildClaudeTestRequest(projectID, mappedModel string) ([]byte, error) {
 	claudeReq := &antigravity.ClaudeRequest{
 		Model: mappedModel,
 		Messages: []antigravity.ClaudeMessage{
 			{
 				Role:    "user",
-				Content: json.RawMessage(`"hi"`),
+				Content: json.RawMessage(`"."`),
 			},
 		},
-		MaxTokens: 1024,
+		MaxTokens: 1,
 		Stream:    false,
 	}
 	return antigravity.TransformClaudeToGemini(claudeReq, projectID, mappedModel)
@@ -523,9 +730,6 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
 		proxyURL = account.Proxy.URL()
 	}
 
-	// Sanitize thinking blocks (clean cache_control and flatten history thinking)
-	sanitizeThinkingBlocks(&claudeReq)
-
 	// Get the transform options
 	// The Antigravity upstream requires the identity prompt; otherwise it returns 429
 	transformOpts := s.getClaudeTransformOptions(ctx)
@@ -537,150 +741,29 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
 		return nil, fmt.Errorf("transform request: %w", err)
 	}
 
-	// Safety net: ensure no cache_control leaked into Gemini request
-	geminiBody = cleanCacheControlFromGeminiJSON(geminiBody)
-
 	// The Antigravity upstream only supports streaming requests; always use streamGenerateContent
 	// If the client asked for non-streaming, the full streamed response is collected and converted during response handling
 	action := "streamGenerateContent"
 
-	// URL fallback loop
-	availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs()
-	if len(availableURLs) == 0 {
-		availableURLs = antigravity.BaseURLs // retry all URLs when none are marked available
-	}
-
-	// Retry loop
-	var resp *http.Response
-urlFallbackLoop:
-	for urlIdx, baseURL := range availableURLs {
-		for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
-			// Check whether the context was canceled (client disconnected)
-			select {
-			case <-ctx.Done():
-				log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err())
-				return nil, ctx.Err()
-			default:
-			}
-
-			upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, action, accessToken, geminiBody)
-			// Capture upstream request body for ops retry of this attempt.
- if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(geminiBody)) - } - if err != nil { - return nil, err - } - - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - // 检查是否应触发 URL 降级 - if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) - continue urlFallbackLoop - } - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - setOpsUpstreamError(c, 0, safeErr, "") - return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") - } - - // 检查是否应触发 URL 降级(仅 429) - if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) - continue urlFallbackLoop - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - 
AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 500)) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - } - // 最后一次尝试也失败 - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - break urlFallbackLoop - } + // 执行带重试的请求 + result, err := antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: geminiBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + handleError: s.handleUpstreamError, + }) + if err != nil { + return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries") } + resp := result.resp defer func() { _ = resp.Body.Close() }() if resp.StatusCode >= 400 { @@ -739,11 +822,20 @@ urlFallbackLoop: if txErr != nil { continue } - retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody) - if buildErr != nil { - continue - } - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + retryResult, retryErr := antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: action, + body: retryGeminiBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + handleError: s.handleUpstreamError, + }) if retryErr != nil { appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -757,6 +849,7 @@ urlFallbackLoop: continue } + retryResp := retryResult.resp if retryResp.StatusCode < 400 { _ = resp.Body.Close() resp = retryResp @@ -766,6 +859,13 @@ urlFallbackLoop: retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20)) _ = retryResp.Body.Close() + if retryResp.StatusCode == http.StatusTooManyRequests { + retryBaseURL := "" + if retryResp.Request != nil && retryResp.Request.URL != nil { + retryBaseURL = retryResp.Request.URL.Scheme + "://" + retryResp.Request.URL.Host + } + log.Printf("%s status=429 rate_limited base_url=%s retry_stage=%s body=%s", prefix, retryBaseURL, stage.name, truncateForLog(retryBody, 200)) + } kind := "signature_retry" if strings.TrimSpace(stage.name) != "" { kind = "signature_retry_" + strings.ReplaceAll(stage.name, "+", "_") @@ -920,143 +1020,6 @@ func extractAntigravityErrorMessage(body []byte) string { return "" } -// cleanCacheControlFromGeminiJSON removes cache_control from Gemini JSON (emergency fix) -// This should not be needed if transformation is correct, but serves as a safety net -func cleanCacheControlFromGeminiJSON(body []byte) []byte { - // Try a more robust approach: parse and clean - var data map[string]any - if err := json.Unmarshal(body, &data); err != nil { - log.Printf("[Antigravity] Failed to parse Gemini JSON for cache_control cleaning: %v", 
err) - return body - } - - cleaned := removeCacheControlFromAny(data) - if !cleaned { - return body - } - - if result, err := json.Marshal(data); err == nil { - log.Printf("[Antigravity] Successfully cleaned cache_control from Gemini JSON") - return result - } - - return body -} - -// removeCacheControlFromAny recursively removes cache_control fields -func removeCacheControlFromAny(v any) bool { - cleaned := false - - switch val := v.(type) { - case map[string]any: - for k, child := range val { - if k == "cache_control" { - delete(val, k) - cleaned = true - } else if removeCacheControlFromAny(child) { - cleaned = true - } - } - case []any: - for _, item := range val { - if removeCacheControlFromAny(item) { - cleaned = true - } - } - } - - return cleaned -} - -// sanitizeThinkingBlocks cleans cache_control and flattens history thinking blocks -// Thinking blocks do NOT support cache_control field (Anthropic API/Vertex AI requirement) -// Additionally, history thinking blocks are flattened to text to avoid upstream validation errors -func sanitizeThinkingBlocks(req *antigravity.ClaudeRequest) { - if req == nil { - return - } - - log.Printf("[Antigravity] sanitizeThinkingBlocks: processing request with %d messages", len(req.Messages)) - - // Clean system blocks - if len(req.System) > 0 { - var systemBlocks []map[string]any - if err := json.Unmarshal(req.System, &systemBlocks); err == nil { - for i := range systemBlocks { - if blockType, _ := systemBlocks[i]["type"].(string); blockType == "thinking" || systemBlocks[i]["thinking"] != nil { - if removeCacheControlFromAny(systemBlocks[i]) { - log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in system[%d]", i) - } - } - } - // Marshal back - if cleaned, err := json.Marshal(systemBlocks); err == nil { - req.System = cleaned - } - } - } - - // Clean message content blocks and flatten history - lastMsgIdx := len(req.Messages) - 1 - for msgIdx := range req.Messages { - raw := req.Messages[msgIdx].Content - if len(raw) == 0 { - continue - } - - // Try to parse as blocks array - var blocks []map[string]any - if err := json.Unmarshal(raw, &blocks); err != nil { - continue - } - - cleaned := false - for blockIdx := range blocks { - blockType, _ := blocks[blockIdx]["type"].(string) - - // Check for thinking blocks (typed or untyped) - if blockType == "thinking" || blocks[blockIdx]["thinking"] != nil { - // 1. Clean cache_control - if removeCacheControlFromAny(blocks[blockIdx]) { - log.Printf("[Antigravity] Deep cleaned cache_control from thinking block in messages[%d].content[%d]", msgIdx, blockIdx) - cleaned = true - } - - // 2. 
Flatten to text if it's a history message (not the last one) - if msgIdx < lastMsgIdx { - log.Printf("[Antigravity] Flattening history thinking block to text at messages[%d].content[%d]", msgIdx, blockIdx) - - // Extract thinking content - var textContent string - if t, ok := blocks[blockIdx]["thinking"].(string); ok { - textContent = t - } else { - // Fallback for non-string content (marshal it) - if b, err := json.Marshal(blocks[blockIdx]["thinking"]); err == nil { - textContent = string(b) - } - } - - // Convert to text block - blocks[blockIdx]["type"] = "text" - blocks[blockIdx]["text"] = textContent - delete(blocks[blockIdx], "thinking") - delete(blocks[blockIdx], "signature") - delete(blocks[blockIdx], "cache_control") // Ensure it's gone - cleaned = true - } - } - } - - // Marshal back if modified - if cleaned { - if marshaled, err := json.Marshal(blocks); err == nil { - req.Messages[msgIdx].Content = marshaled - } - } - } -} - // stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request. // This preserves the thinking content while avoiding signature validation errors. // Note: redacted_thinking blocks are removed because they cannot be converted to text. @@ -1352,138 +1315,25 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co // 如果客户端请求非流式,在响应处理阶段会收集完整流式响应后返回 upstreamAction := "streamGenerateContent" - // URL fallback 循环 - availableURLs := antigravity.DefaultURLAvailability.GetAvailableURLs() - if len(availableURLs) == 0 { - availableURLs = antigravity.BaseURLs // 所有 URL 都不可用时,重试所有 - } - - // 重试循环 - var resp *http.Response -urlFallbackLoop: - for urlIdx, baseURL := range availableURLs { - for attempt := 1; attempt <= antigravityMaxRetries; attempt++ { - // 检查 context 是否已取消(客户端断开连接) - select { - case <-ctx.Done(): - log.Printf("%s status=context_canceled error=%v", prefix, ctx.Err()) - return nil, ctx.Err() - default: - } - - upstreamReq, err := antigravity.NewAPIRequestWithURL(ctx, baseURL, upstreamAction, accessToken, wrappedBody) - if err != nil { - return nil, err - } - - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) - if err != nil { - safeErr := sanitizeUpstreamErrorMessage(err.Error()) - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: 0, - Kind: "request_error", - Message: safeErr, - }) - // 检查是否应触发 URL 降级 - if shouldAntigravityFallbackToNextURL(err, 0) && urlIdx < len(availableURLs)-1 { - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (connection error): %s -> %s", prefix, baseURL, availableURLs[urlIdx+1]) - continue urlFallbackLoop - } - if attempt < antigravityMaxRetries { - log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - log.Printf("%s status=request_failed retries_exhausted error=%v", prefix, err) - setOpsUpstreamError(c, 0, safeErr, "") - return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") - } - - // 检查是否应触发 URL 降级(仅 429) - if resp.StatusCode == http.StatusTooManyRequests && urlIdx < len(availableURLs)-1 { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - upstreamMsg := 
strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - antigravity.DefaultURLAvailability.MarkUnavailable(baseURL) - log.Printf("%s URL fallback (HTTP 429): %s -> %s body=%s", prefix, baseURL, availableURLs[urlIdx+1], truncateForLog(respBody, 200)) - continue urlFallbackLoop - } - - if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(resp.StatusCode) { - respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) - _ = resp.Body.Close() - - if attempt < antigravityMaxRetries { - upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody)) - upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) - logBody := s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBody - maxBytes := 2048 - if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes > 0 { - maxBytes = s.settingService.cfg.Gateway.LogUpstreamErrorBodyMaxBytes - } - upstreamDetail := "" - if logBody { - upstreamDetail = truncateString(string(respBody), maxBytes) - } - appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ - Platform: account.Platform, - AccountID: account.ID, - AccountName: account.Name, - UpstreamStatusCode: resp.StatusCode, - UpstreamRequestID: resp.Header.Get("x-request-id"), - Kind: "retry", - Message: upstreamMsg, - Detail: upstreamDetail, - }) - log.Printf("%s status=%d retry=%d/%d", prefix, resp.StatusCode, attempt, antigravityMaxRetries) - if !sleepAntigravityBackoffWithContext(ctx, attempt) { - log.Printf("%s status=context_canceled_during_backoff", prefix) - return nil, ctx.Err() - } - continue - } - // 所有重试都失败,标记限流状态 - if resp.StatusCode == 429 { - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - } - resp = &http.Response{ - StatusCode: resp.StatusCode, - Header: resp.Header.Clone(), - Body: io.NopCloser(bytes.NewReader(respBody)), - } - break urlFallbackLoop - } - - break urlFallbackLoop - } + // 执行带重试的请求 + result, err := antigravityRetryLoop(antigravityRetryLoopParams{ + ctx: ctx, + prefix: prefix, + account: account, + proxyURL: proxyURL, + accessToken: accessToken, + action: upstreamAction, + body: wrappedBody, + quotaScope: quotaScope, + c: c, + httpUpstream: s.httpUpstream, + settingService: s.settingService, + handleError: s.handleUpstreamError, + }) + if err != nil { + return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries") } + resp := result.resp defer func() { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -1525,8 +1375,6 @@ urlFallbackLoop: goto handleSuccess } - s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope) - requestID := 
resp.Header.Get("x-request-id")
 	if requestID != "" {
 		c.Header("x-request-id", requestID)
@@ -1537,6 +1385,7 @@ urlFallbackLoop:
 		if unwrapErr != nil || len(unwrappedForOps) == 0 {
 			unwrappedForOps = respBody
 		}
+		s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope)
 
 		upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(unwrappedForOps))
 		upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
@@ -1581,6 +1430,7 @@ urlFallbackLoop:
 			Message:            upstreamMsg,
 			Detail:             upstreamDetail,
 		})
+		log.Printf("[antigravity-Forward] upstream error status=%d body=%s", resp.StatusCode, truncateForLog(unwrappedForOps, 500))
 		c.Data(resp.StatusCode, contentType, unwrappedForOps)
 		return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
 	}
@@ -1637,15 +1487,6 @@ handleSuccess:
 	}, nil
 }
 
-func (s *AntigravityGatewayService) shouldRetryUpstreamError(statusCode int) bool {
-	switch statusCode {
-	case 429, 500, 502, 503, 504, 529:
-		return true
-	default:
-		return false
-	}
-}
-
 func (s *AntigravityGatewayService) shouldFailoverUpstreamError(statusCode int) bool {
 	switch statusCode {
 	case 401, 403, 429, 529:
@@ -1679,33 +1520,48 @@ func sleepAntigravityBackoffWithContext(ctx context.Context, attempt int) bool {
 	}
 }
 
+func antigravityUseScopeRateLimit() bool {
+	v := strings.ToLower(strings.TrimSpace(os.Getenv(antigravityScopeRateLimitEnv)))
+	return v == "1" || v == "true" || v == "yes" || v == "on"
+}
+
 func (s *AntigravityGatewayService) handleUpstreamError(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) {
 	// 429 responses are parsed in Gemini format (reset time extracted from the body)
 	if statusCode == 429 {
+		useScopeLimit := antigravityUseScopeRateLimit() && quotaScope != ""
 		resetAt := ParseGeminiRateLimitResetTime(body)
 		if resetAt == nil {
-			// Parse failure: use 5 minutes when Gemini supplies a retry hint, 1 minute for Claude without one
-			defaultDur := 1 * time.Minute
-			if bytes.Contains(body, []byte("Please retry in")) || bytes.Contains(body, []byte("retryDelay")) {
-				defaultDur = 5 * time.Minute
+			// Parse failure: use the configured fallback duration and rate-limit the whole account
+			fallbackMinutes := 5
+			if s.settingService != nil && s.settingService.cfg != nil && s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes > 0 {
+				fallbackMinutes = s.settingService.cfg.Gateway.AntigravityFallbackCooldownMinutes
 			}
+			defaultDur := time.Duration(fallbackMinutes) * time.Minute
 			ra := time.Now().Add(defaultDur)
-			log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
-			if quotaScope == "" {
-				return
-			}
-			if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
-				log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+			if useScopeLimit {
+				log.Printf("%s status=429 rate_limited scope=%s reset_in=%v (fallback)", prefix, quotaScope, defaultDur)
+				if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, ra); err != nil {
+					log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+				}
+			} else {
+				log.Printf("%s status=429 rate_limited account=%d reset_in=%v (fallback)", prefix, account.ID, defaultDur)
+				if err := s.accountRepo.SetRateLimited(ctx, account.ID, ra); err != nil {
+					log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
+				}
 			}
 			return
 		}
 		resetTime := time.Unix(*resetAt, 0)
-		log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
-		if quotaScope == "" {
-			return
-		}
-		if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
-			log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+		if useScopeLimit {
+			log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
+			if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil {
+				log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err)
+			}
+		} else {
+			log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second))
+			if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil {
+				log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err)
+			}
 		}
 		return
 	}
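The env toggle keeps the new per-scope behaviour opt-in: unless GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT is set to a truthy value (and a quota scope is known), a 429 now rate-limits the whole account. A standalone sketch of the same truthy parsing:

package main

import (
	"fmt"
	"os"
	"strings"
)

// scopeLimitEnabled mirrors antigravityUseScopeRateLimit: trim,
// lower-case, then accept the usual truthy spellings.
func scopeLimitEnabled() bool {
	v := strings.ToLower(strings.TrimSpace(os.Getenv("GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT")))
	return v == "1" || v == "true" || v == "yes" || v == "on"
}

func main() {
	os.Setenv("GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT", "yes")
	fmt.Println(scopeLimitEnabled()) // true: per-scope quota limits
	os.Setenv("GATEWAY_ANTIGRAVITY_429_SCOPE_LIMIT", "0")
	fmt.Println(scopeLimitEnabled()) // false: whole-account rate limit
}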
resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) - if quotaScope == "" { - return - } - if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { - log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + if useScopeLimit { + log.Printf("%s status=429 rate_limited scope=%s reset_at=%v reset_in=%v", prefix, quotaScope, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) + if err := s.accountRepo.SetAntigravityQuotaScopeLimit(ctx, account.ID, quotaScope, resetTime); err != nil { + log.Printf("%s status=429 rate_limit_set_failed scope=%s error=%v", prefix, quotaScope, err) + } + } else { + log.Printf("%s status=429 rate_limited account=%d reset_at=%v reset_in=%v", prefix, account.ID, resetTime.Format("15:04:05"), time.Until(resetTime).Truncate(time.Second)) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetTime); err != nil { + log.Printf("%s status=429 rate_limit_set_failed account=%d error=%v", prefix, account.ID, err) + } } return } @@ -1884,7 +1740,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context } // handleGeminiStreamToNonStreaming 读取上游流式响应,合并为非流式响应返回给客户端 -// Gemini 流式响应中每个 chunk 都包含累积的完整文本,只需保留最后一个有效响应 +// Gemini 流式响应是增量的,需要累积所有 chunk 的内容 func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Context, resp *http.Response, startTime time.Time) (*antigravityStreamResult, error) { scanner := bufio.NewScanner(resp.Body) maxLineSize := defaultMaxLineSize @@ -1897,6 +1753,8 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont var firstTokenMs *int var last map[string]any var lastWithParts map[string]any + var collectedImageParts []map[string]any // 收集所有包含图片的 parts + var collectedTextParts []string // 收集所有文本片段 type scanEvent struct { line string @@ -1999,6 +1857,16 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont // 保留最后一个有 parts 的响应 if parts := extractGeminiParts(parsed); len(parts) > 0 { lastWithParts = parsed + // 收集包含图片和文本的 parts + for _, part := range parts { + if inlineData, ok := part["inlineData"].(map[string]any); ok { + collectedImageParts = append(collectedImageParts, part) + _ = inlineData // 避免 unused 警告 + } + if text, ok := part["text"].(string); ok && text != "" { + collectedTextParts = append(collectedTextParts, text) + } + } } case <-intervalCh: @@ -2020,6 +1888,16 @@ returnResponse: log.Printf("[antigravity-Forward] warning: empty stream response, no valid chunks received") } + // 如果收集到了图片 parts,需要合并到最终响应中 + if len(collectedImageParts) > 0 { + finalResponse = mergeImagePartsToResponse(finalResponse, collectedImageParts) + } + + // 如果收集到了文本,需要合并到最终响应中 + if len(collectedTextParts) > 0 { + finalResponse = mergeTextPartsToResponse(finalResponse, collectedTextParts) + } + respBody, err := json.Marshal(finalResponse) if err != nil { return nil, fmt.Errorf("failed to marshal response: %w", err) @@ -2029,6 +1907,115 @@ returnResponse: return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, nil } +// getOrCreateGeminiParts 获取 Gemini 响应的 parts 结构,返回深拷贝和更新回调 +func getOrCreateGeminiParts(response map[string]any) (result map[string]any, existingParts []any, setParts func([]any)) { + // 深拷贝 response + result = make(map[string]any) + for k, v := range response { + result[k] = v + } + + // 获取或创建 candidates + candidates, ok := result["candidates"].([]any) + if !ok || len(candidates) == 0 { + candidates = 
[]any{map[string]any{}} + } + + // 获取第一个 candidate + candidate, ok := candidates[0].(map[string]any) + if !ok { + candidate = make(map[string]any) + candidates[0] = candidate + } + + // 获取或创建 content + content, ok := candidate["content"].(map[string]any) + if !ok { + content = map[string]any{"role": "model"} + candidate["content"] = content + } + + // 获取现有 parts + existingParts, ok = content["parts"].([]any) + if !ok { + existingParts = []any{} + } + + // 返回更新回调 + setParts = func(newParts []any) { + content["parts"] = newParts + result["candidates"] = candidates + } + + return result, existingParts, setParts +} + +// mergeImagePartsToResponse 将收集到的图片 parts 合并到 Gemini 响应中 +func mergeImagePartsToResponse(response map[string]any, imageParts []map[string]any) map[string]any { + if len(imageParts) == 0 { + return response + } + + result, existingParts, setParts := getOrCreateGeminiParts(response) + + // 检查现有 parts 中是否已经有图片 + for _, p := range existingParts { + if pm, ok := p.(map[string]any); ok { + if _, hasInline := pm["inlineData"]; hasInline { + return result // 已有图片,不重复添加 + } + } + } + + // 添加收集到的图片 parts + for _, imgPart := range imageParts { + existingParts = append(existingParts, imgPart) + } + setParts(existingParts) + return result +} + +// mergeTextPartsToResponse 将收集到的文本合并到 Gemini 响应中 +func mergeTextPartsToResponse(response map[string]any, textParts []string) map[string]any { + if len(textParts) == 0 { + return response + } + + mergedText := strings.Join(textParts, "") + result, existingParts, setParts := getOrCreateGeminiParts(response) + + // 查找并更新第一个 text part,或创建新的 + newParts := make([]any, 0, len(existingParts)+1) + textUpdated := false + + for _, p := range existingParts { + pm, ok := p.(map[string]any) + if !ok { + newParts = append(newParts, p) + continue + } + if _, hasText := pm["text"]; hasText && !textUpdated { + // 用累积的文本替换 + newPart := make(map[string]any) + for k, v := range pm { + newPart[k] = v + } + newPart["text"] = mergedText + newParts = append(newParts, newPart) + textUpdated = true + } else { + newParts = append(newParts, pm) + } + } + + if !textUpdated { + newParts = append([]any{map[string]any{"text": mergedText}}, newParts...) 
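// The prepend above is the fallback path: if no existing text part was found
// (for example the final chunk only carried finishReason/usageMetadata, or
// just an inlineData image part), the accumulated stream text would otherwise
// be dropped, so it is inserted as a new leading part of the merged response.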
+ } + + setParts(newParts) + return result +} + func (s *AntigravityGatewayService) writeClaudeError(c *gin.Context, status int, errType, message string) error { c.JSON(status, gin.H{ "type": "error", diff --git a/backend/internal/service/antigravity_model_mapping_test.go b/backend/internal/service/antigravity_model_mapping_test.go index 39000e4f..179a3520 100644 --- a/backend/internal/service/antigravity_model_mapping_test.go +++ b/backend/internal/service/antigravity_model_mapping_test.go @@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) { {"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true}, // Gemini 前缀透传 - {"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true}, + {"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true}, {"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true}, {"Gemini前缀 - gemini-future-version", "gemini-future-version", true}, @@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) { expected: "gemini-2.5-flash", }, { - name: "Gemini透传 - gemini-1.5-pro", - requestedModel: "gemini-1.5-pro", + name: "Gemini透传 - gemini-2.5-pro", + requestedModel: "gemini-2.5-pro", accountMapping: nil, - expected: "gemini-1.5-pro", + expected: "gemini-2.5-pro", }, { name: "Gemini透传 - gemini-future-model", diff --git a/backend/internal/service/antigravity_oauth_service.go b/backend/internal/service/antigravity_oauth_service.go index ecf0a553..52293cd5 100644 --- a/backend/internal/service/antigravity_oauth_service.go +++ b/backend/internal/service/antigravity_oauth_service.go @@ -82,13 +82,14 @@ type AntigravityExchangeCodeInput struct { // AntigravityTokenInfo token 信息 type AntigravityTokenInfo struct { - AccessToken string `json:"access_token"` - RefreshToken string `json:"refresh_token"` - ExpiresIn int64 `json:"expires_in"` - ExpiresAt int64 `json:"expires_at"` - TokenType string `json:"token_type"` - Email string `json:"email,omitempty"` - ProjectID string `json:"project_id,omitempty"` + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int64 `json:"expires_in"` + ExpiresAt int64 `json:"expires_at"` + TokenType string `json:"token_type"` + Email string `json:"email,omitempty"` + ProjectID string `json:"project_id,omitempty"` + ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id } // ExchangeCode 用 authorization code 交换 token @@ -149,12 +150,6 @@ func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *Antig result.ProjectID = loadResp.CloudAICompanionProject } - // 兜底:随机生成 project_id - if result.ProjectID == "" { - result.ProjectID = antigravity.GenerateMockProjectID() - fmt.Printf("[AntigravityOAuth] 使用随机生成的 project_id: %s\n", result.ProjectID) - } - return result, nil } @@ -236,16 +231,24 @@ func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, accou return nil, err } - // 保留原有的 project_id 和 email - existingProjectID := strings.TrimSpace(account.GetCredential("project_id")) - if existingProjectID != "" { - tokenInfo.ProjectID = existingProjectID - } + // 保留原有的 email existingEmail := strings.TrimSpace(account.GetCredential("email")) if existingEmail != "" { tokenInfo.Email = existingEmail } + // 每次刷新都调用 LoadCodeAssist 获取 project_id + client := antigravity.NewClient(proxyURL) + loadResp, _, err := client.LoadCodeAssist(ctx, tokenInfo.AccessToken) + if err != nil || loadResp == nil || loadResp.CloudAICompanionProject == "" { + // LoadCodeAssist 失败或返回空,保留原有 project_id,标记缺失 + existingProjectID := 
strings.TrimSpace(account.GetCredential("project_id")) + tokenInfo.ProjectID = existingProjectID + tokenInfo.ProjectIDMissing = true + } else { + tokenInfo.ProjectID = loadResp.CloudAICompanionProject + } + return tokenInfo, nil } diff --git a/backend/internal/service/antigravity_quota_fetcher.go b/backend/internal/service/antigravity_quota_fetcher.go index c9024e33..07eb563d 100644 --- a/backend/internal/service/antigravity_quota_fetcher.go +++ b/backend/internal/service/antigravity_quota_fetcher.go @@ -31,11 +31,6 @@ func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Accou accessToken := account.GetCredential("access_token") projectID := account.GetCredential("project_id") - // 如果没有 project_id,生成一个随机的 - if projectID == "" { - projectID = antigravity.GenerateMockProjectID() - } - client := antigravity.NewClient(proxyURL) // 调用 API 获取配额 diff --git a/backend/internal/service/antigravity_rate_limit_test.go b/backend/internal/service/antigravity_rate_limit_test.go new file mode 100644 index 00000000..53ec6fdf --- /dev/null +++ b/backend/internal/service/antigravity_rate_limit_test.go @@ -0,0 +1,190 @@ +//go:build unit + +package service + +import ( + "context" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/antigravity" + "github.com/stretchr/testify/require" +) + +type stubAntigravityUpstream struct { + firstBase string + secondBase string + calls []string +} + +func (s *stubAntigravityUpstream) Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) { + url := req.URL.String() + s.calls = append(s.calls, url) + if strings.HasPrefix(url, s.firstBase) { + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Resource has been exhausted"}}`)), + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("ok")), + }, nil +} + +func (s *stubAntigravityUpstream) DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) { + return s.Do(req, proxyURL, accountID, accountConcurrency) +} + +type scopeLimitCall struct { + accountID int64 + scope AntigravityQuotaScope + resetAt time.Time +} + +type rateLimitCall struct { + accountID int64 + resetAt time.Time +} + +type stubAntigravityAccountRepo struct { + AccountRepository + scopeCalls []scopeLimitCall + rateCalls []rateLimitCall +} + +func (s *stubAntigravityAccountRepo) SetAntigravityQuotaScopeLimit(ctx context.Context, id int64, scope AntigravityQuotaScope, resetAt time.Time) error { + s.scopeCalls = append(s.scopeCalls, scopeLimitCall{accountID: id, scope: scope, resetAt: resetAt}) + return nil +} + +func (s *stubAntigravityAccountRepo) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error { + s.rateCalls = append(s.rateCalls, rateLimitCall{accountID: id, resetAt: resetAt}) + return nil +} + +func TestAntigravityRetryLoop_URLFallback_UsesLatestSuccess(t *testing.T) { + oldBaseURLs := append([]string(nil), antigravity.BaseURLs...) 
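// Snapshot-and-restore: antigravity.BaseURLs (above) and
// DefaultURLAvailability (below) are package-level globals, so the test saves
// the old values and restores both in a defer; the defer still runs when a
// require.* assertion aborts the test midway, so the mutation cannot leak
// into other tests.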
+ oldAvailability := antigravity.DefaultURLAvailability + defer func() { + antigravity.BaseURLs = oldBaseURLs + antigravity.DefaultURLAvailability = oldAvailability + }() + + base1 := "https://ag-1.test" + base2 := "https://ag-2.test" + antigravity.BaseURLs = []string{base1, base2} + antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute) + + upstream := &stubAntigravityUpstream{firstBase: base1, secondBase: base2} + account := &Account{ + ID: 1, + Name: "acc-1", + Platform: PlatformAntigravity, + Schedulable: true, + Status: StatusActive, + Concurrency: 1, + } + + var handleErrorCalled bool + result, err := antigravityRetryLoop(antigravityRetryLoopParams{ + prefix: "[test]", + ctx: context.Background(), + account: account, + proxyURL: "", + accessToken: "token", + action: "generateContent", + body: []byte(`{"input":"test"}`), + quotaScope: AntigravityQuotaScopeClaude, + httpUpstream: upstream, + handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, quotaScope AntigravityQuotaScope) { + handleErrorCalled = true + }, + }) + + require.NoError(t, err) + require.NotNil(t, result) + require.NotNil(t, result.resp) + defer func() { _ = result.resp.Body.Close() }() + require.Equal(t, http.StatusOK, result.resp.StatusCode) + require.False(t, handleErrorCalled) + require.Len(t, upstream.calls, 2) + require.True(t, strings.HasPrefix(upstream.calls[0], base1)) + require.True(t, strings.HasPrefix(upstream.calls[1], base2)) + + available := antigravity.DefaultURLAvailability.GetAvailableURLs() + require.NotEmpty(t, available) + require.Equal(t, base2, available[0]) +} + +func TestAntigravityHandleUpstreamError_UsesScopeLimitWhenEnabled(t *testing.T) { + t.Setenv(antigravityScopeRateLimitEnv, "true") + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 9, Name: "acc-9", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("3s") + svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) + + require.Len(t, repo.scopeCalls, 1) + require.Empty(t, repo.rateCalls) + call := repo.scopeCalls[0] + require.Equal(t, account.ID, call.accountID) + require.Equal(t, AntigravityQuotaScopeClaude, call.scope) + require.WithinDuration(t, time.Now().Add(3*time.Second), call.resetAt, 2*time.Second) +} + +func TestAntigravityHandleUpstreamError_UsesAccountLimitWhenScopeDisabled(t *testing.T) { + t.Setenv(antigravityScopeRateLimitEnv, "false") + repo := &stubAntigravityAccountRepo{} + svc := &AntigravityGatewayService{accountRepo: repo} + account := &Account{ID: 10, Name: "acc-10", Platform: PlatformAntigravity} + + body := buildGeminiRateLimitBody("2s") + svc.handleUpstreamError(context.Background(), "[test]", account, http.StatusTooManyRequests, http.Header{}, body, AntigravityQuotaScopeClaude) + + require.Len(t, repo.rateCalls, 1) + require.Empty(t, repo.scopeCalls) + call := repo.rateCalls[0] + require.Equal(t, account.ID, call.accountID) + require.WithinDuration(t, time.Now().Add(2*time.Second), call.resetAt, 2*time.Second) +} + +func TestAccountIsSchedulableForModel_AntigravityRateLimits(t *testing.T) { + now := time.Now() + future := now.Add(10 * time.Minute) + + account := &Account{ + ID: 1, + Name: "acc", + Platform: PlatformAntigravity, + Status: StatusActive, + Schedulable: true, + } + + account.RateLimitResetAt = &future + require.False(t, 
account.IsSchedulableForModel("claude-sonnet-4-5")) + require.False(t, account.IsSchedulableForModel("gemini-3-flash")) + + account.RateLimitResetAt = nil + account.Extra = map[string]any{ + antigravityQuotaScopesKey: map[string]any{ + "claude": map[string]any{ + "rate_limit_reset_at": future.Format(time.RFC3339), + }, + }, + } + + require.False(t, account.IsSchedulableForModel("claude-sonnet-4-5")) + require.True(t, account.IsSchedulableForModel("gemini-3-flash")) +} + +func buildGeminiRateLimitBody(delay string) []byte { + return []byte(fmt.Sprintf(`{"error":{"message":"too many requests","details":[{"metadata":{"quotaResetDelay":%q}}]}}`, delay)) +} diff --git a/backend/internal/service/antigravity_token_refresher.go b/backend/internal/service/antigravity_token_refresher.go index 9dd4463f..a07c86e6 100644 --- a/backend/internal/service/antigravity_token_refresher.go +++ b/backend/internal/service/antigravity_token_refresher.go @@ -61,5 +61,10 @@ func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Accoun } } + // 如果 project_id 获取失败,返回 credentials 但同时返回错误让账户被标记 + if tokenInfo.ProjectIDMissing { + return newCredentials, fmt.Errorf("missing_project_id: 账户缺少project id,可能无法使用Antigravity") + } + return newCredentials, nil } diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 521f1da5..eb5c7534 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -94,6 +94,20 @@ func (s *APIKeyService) initAuthCache(cfg *config.Config) { s.authCacheL1 = cache } +// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation. +// This should be called after the service is fully initialized. 
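The subscriber declared below is one half of a standard Redis Pub/Sub invalidation scheme: whichever instance evicts an entry publishes the cache key, and every peer drops its local L1 copy when the message arrives. A minimal self-contained sketch of the pattern, assuming github.com/redis/go-redis/v9; the channel name "auth_cache_invalidate" is illustrative, since the real channel is hidden behind the APIKeyCache interface:

package main

import (
	"context"
	"fmt"
	"time"

	"github.com/redis/go-redis/v9"
)

// subscribeInvalidations drops local L1 entries whenever a peer publishes
// the key it just evicted.
func subscribeInvalidations(ctx context.Context, rdb *redis.Client, evictL1 func(cacheKey string)) {
	sub := rdb.Subscribe(ctx, "auth_cache_invalidate")
	go func() {
		for msg := range sub.Channel() {
			evictL1(msg.Payload) // message payload is the cache key to drop
		}
	}()
}

// publishInvalidation tells every other instance to drop cacheKey.
func publishInvalidation(ctx context.Context, rdb *redis.Client, cacheKey string) error {
	return rdb.Publish(ctx, "auth_cache_invalidate", cacheKey).Err()
}

func main() {
	ctx := context.Background()
	rdb := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
	subscribeInvalidations(ctx, rdb, func(k string) { fmt.Println("evict L1:", k) })
	_ = publishInvalidation(ctx, rdb, "sha256-of-raw-key")
	time.Sleep(100 * time.Millisecond) // give the subscriber goroutine time to run
}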
+func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) { + if s.cache == nil || s.authCacheL1 == nil { + return + } + if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) { + s.authCacheL1.Del(cacheKey) + }); err != nil { + // Log but don't fail - L1 cache will still work, just without cross-instance invalidation + println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error()) + } +} + func (s *APIKeyService) authCacheKey(key string) string { sum := sha256.Sum256([]byte(key)) return hex.EncodeToString(sum[:]) @@ -149,6 +163,8 @@ func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) { return } _ = s.cache.DeleteAuthCache(ctx, cacheKey) + // Publish invalidation message to other instances + _ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey) } func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) { diff --git a/backend/internal/service/api_key_service.go b/backend/internal/service/api_key_service.go index ecc570c7..ef1ff990 100644 --- a/backend/internal/service/api_key_service.go +++ b/backend/internal/service/api_key_service.go @@ -65,6 +65,10 @@ type APIKeyCache interface { GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error DeleteAuthCache(ctx context.Context, key string) error + + // Pub/Sub for L1 cache invalidation across instances + PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error + SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error } // APIKeyAuthCacheInvalidator 提供认证缓存失效能力 diff --git a/backend/internal/service/api_key_service_cache_test.go b/backend/internal/service/api_key_service_cache_test.go index 5f2d69c4..c5e9cd47 100644 --- a/backend/internal/service/api_key_service_cache_test.go +++ b/backend/internal/service/api_key_service_cache_test.go @@ -142,6 +142,14 @@ func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error { return nil } +func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return nil +} + +func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + return nil +} + func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) { cache := &authCacheStub{} repo := &authRepoStub{ diff --git a/backend/internal/service/api_key_service_delete_test.go b/backend/internal/service/api_key_service_delete_test.go index 32ae884e..092b7fce 100644 --- a/backend/internal/service/api_key_service_delete_test.go +++ b/backend/internal/service/api_key_service_delete_test.go @@ -168,6 +168,14 @@ func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error return nil } +func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error { + return nil +} + +func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error { + return nil +} + // TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。 // 预期行为: // - GetKeyAndOwnerID 返回所有者 ID 为 1 diff --git a/backend/internal/service/dashboard_aggregation_service.go b/backend/internal/service/dashboard_aggregation_service.go index da5c0e7d..10c68868 100644 --- a/backend/internal/service/dashboard_aggregation_service.go +++ 
b/backend/internal/service/dashboard_aggregation_service.go @@ -20,12 +20,16 @@ var ( // ErrDashboardBackfillDisabled 当配置禁用回填时返回。 ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用") // ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。 - ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") + ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大") + errDashboardAggregationRunning = errors.New("聚合作业正在运行") ) // DashboardAggregationRepository 定义仪表盘预聚合仓储接口。 type DashboardAggregationRepository interface { AggregateRange(ctx context.Context, start, end time.Time) error + // RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。 + // 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。 + RecomputeRange(ctx context.Context, start, end time.Time) error GetAggregationWatermark(ctx context.Context) (time.Time, error) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error @@ -112,6 +116,41 @@ func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) erro return nil } +// TriggerRecomputeRange 触发指定范围的重新计算(异步)。 +// 与 TriggerBackfill 不同: +// - 不依赖 backfill_enabled(这是内部一致性修复) +// - 不更新 watermark(避免影响正常增量聚合游标) +func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error { + if s == nil || s.repo == nil { + return errors.New("聚合服务未初始化") + } + if !s.cfg.Enabled { + return errors.New("聚合服务已禁用") + } + if !end.After(start) { + return errors.New("重新计算时间范围无效") + } + + go func() { + const maxRetries = 3 + for i := 0; i < maxRetries; i++ { + ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout) + err := s.recomputeRange(ctx, start, end) + cancel() + if err == nil { + return + } + if !errors.Is(err, errDashboardAggregationRunning) { + log.Printf("[DashboardAggregation] 重新计算失败: %v", err) + return + } + time.Sleep(5 * time.Second) + } + log.Printf("[DashboardAggregation] 重新计算放弃: 聚合作业持续占用") + }() + return nil +} + func (s *DashboardAggregationService) recomputeRecentDays() { days := s.cfg.RecomputeDays if days <= 0 { @@ -128,6 +167,24 @@ func (s *DashboardAggregationService) recomputeRecentDays() { } } +func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error { + if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { + return errDashboardAggregationRunning + } + defer atomic.StoreInt32(&s.running, 0) + + jobStart := time.Now().UTC() + if err := s.repo.RecomputeRange(ctx, start, end); err != nil { + return err + } + log.Printf("[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)", + start.UTC().Format(time.RFC3339), + end.UTC().Format(time.RFC3339), + time.Since(jobStart).String(), + ) + return nil +} + func (s *DashboardAggregationService) runScheduledAggregation() { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { return @@ -179,7 +236,7 @@ func (s *DashboardAggregationService) runScheduledAggregation() { func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error { if !atomic.CompareAndSwapInt32(&s.running, 0, 1) { - return errors.New("聚合作业正在运行") + return errDashboardAggregationRunning } defer atomic.StoreInt32(&s.running, 0) diff --git a/backend/internal/service/dashboard_aggregation_service_test.go b/backend/internal/service/dashboard_aggregation_service_test.go index 2fc22105..a7058985 100644 --- a/backend/internal/service/dashboard_aggregation_service_test.go +++ b/backend/internal/service/dashboard_aggregation_service_test.go @@ -27,6 +27,10 @@ func (s 
*dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, s return s.aggregateErr } +func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + return s.AggregateRange(ctx, start, end) +} + func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { return s.watermark, nil } diff --git a/backend/internal/service/dashboard_service.go b/backend/internal/service/dashboard_service.go index a9811919..cd11923e 100644 --- a/backend/internal/service/dashboard_service.go +++ b/backend/internal/service/dashboard_service.go @@ -124,16 +124,16 @@ func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.D return stats, nil } -func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool) ([]usagestats.TrendDataPoint, error) { - trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream) +func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) { + trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, stream, billingType) if err != nil { return nil, fmt.Errorf("get usage trend with filters: %w", err) } return trend, nil } -func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool) ([]usagestats.ModelStat, error) { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream) +func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) { + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, stream, billingType) if err != nil { return nil, fmt.Errorf("get model stats with filters: %w", err) } diff --git a/backend/internal/service/dashboard_service_test.go b/backend/internal/service/dashboard_service_test.go index db3c78c3..59b83e66 100644 --- a/backend/internal/service/dashboard_service_test.go +++ b/backend/internal/service/dashboard_service_test.go @@ -101,6 +101,10 @@ func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start return nil } +func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + return nil +} + func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { if s.err != nil { return time.Time{}, s.err diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 49bb86a7..da1b9377 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -93,13 +93,14 @@ const ( SettingKeyLinuxDoConnectRedirectURL = "linuxdo_connect_redirect_url" // OEM设置 - SettingKeySiteName = "site_name" // 网站名称 - SettingKeySiteLogo = "site_logo" // 网站Logo (base64) - SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 - 
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入) - SettingKeyContactInfo = "contact_info" // 客服联系方式 - SettingKeyDocURL = "doc_url" // 文档链接 - SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) + SettingKeySiteName = "site_name" // 网站名称 + SettingKeySiteLogo = "site_logo" // 网站Logo (base64) + SettingKeySiteSubtitle = "site_subtitle" // 网站副标题 + SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入) + SettingKeyContactInfo = "contact_info" // 客服联系方式 + SettingKeyDocURL = "doc_url" // 文档链接 + SettingKeyHomeContent = "home_content" // 首页内容(支持 Markdown/HTML,或 URL 作为 iframe src) + SettingKeyHideCcsImportButton = "hide_ccs_import_button" // 是否隐藏 API Keys 页面的导入 CCS 按钮 // 默认配置 SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量 diff --git a/backend/internal/service/gateway_multiplatform_test.go b/backend/internal/service/gateway_multiplatform_test.go index f543ef1a..4d17d5e1 100644 --- a/backend/internal/service/gateway_multiplatform_test.go +++ b/backend/internal/service/gateway_multiplatform_test.go @@ -105,6 +105,9 @@ func (m *mockAccountRepoForPlatform) BatchUpdateLastUsed(ctx context.Context, up func (m *mockAccountRepoForPlatform) SetError(ctx context.Context, id int64, errorMsg string) error { return nil } +func (m *mockAccountRepoForPlatform) ClearError(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForPlatform) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return nil } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 8893dac1..f04397e8 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -11,6 +11,8 @@ import ( "fmt" "io" "log" + "log/slog" + mathrand "math/rand" "net/http" "os" "regexp" @@ -819,11 +821,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context } // SelectAccountWithLoadAwareness selects account with load-awareness and wait plan. 
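The non-batch selection path below is reshaped into a select, acquire, check, release loop: once a concurrency slot is acquired, the per-account session cap is checked via checkAndRegisterSession, and on rejection the slot is released, the account is excluded locally, and selection repeats. A compact sketch of that loop under stated assumptions (the account type, pick helper, and sessionOK callback are illustrative stand-ins, not the service's real signatures):

package main

import (
	"errors"
	"fmt"
)

type account struct{ id int64 }

// pick returns the first fixture account not yet excluded; it stands in for
// SelectAccountForModelWithExclusions.
func pick(excluded map[int64]struct{}) (*account, error) {
	for _, a := range []*account{{id: 1}, {id: 2}} {
		if _, skip := excluded[a.id]; !skip {
			return a, nil
		}
	}
	return nil, errors.New("no available accounts")
}

// selectWithSessionLimit retries selection until an account passes the
// session-cap check, excluding each rejected account before reselecting.
func selectWithSessionLimit(sessionOK func(*account) bool) (*account, error) {
	excluded := map[int64]struct{}{}
	for {
		a, err := pick(excluded)
		if err != nil {
			return nil, err
		}
		if !sessionOK(a) {
			excluded[a.id] = struct{}{} // release the slot, exclude, reselect
			continue
		}
		return a, nil
	}
}

func main() {
	a, err := selectWithSessionLimit(func(a *account) bool { return a.id != 1 })
	if err != nil {
		panic(err)
	}
	fmt.Println(a.id) // 2: account 1 hit its session cap and was skipped
}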
-// metadataUserID: 原始 metadata.user_id 字段(用于提取会话 UUID 进行会话数量限制) +// metadataUserID: 已废弃参数,会话限制现在统一使用 sessionHash func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, metadataUserID string) (*AccountSelectionResult, error) { + // 调试日志:记录调度入口参数 + excludedIDsList := make([]int64, 0, len(excludedIDs)) + for id := range excludedIDs { + excludedIDsList = append(excludedIDsList, id) + } + slog.Debug("account_scheduling_starting", + "group_id", derefGroupID(groupID), + "model", requestedModel, + "session", shortSessionHash(sessionHash), + "excluded_ids", excludedIDsList) + cfg := s.schedulingConfig() - // 提取会话 UUID(用于会话数量限制) - sessionUUID := extractSessionUUID(metadataUserID) var stickyAccountID int64 if sessionHash != "" && s.cache != nil { @@ -849,41 +860,63 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } if s.concurrencyService == nil || !cfg.LoadBatchEnabled { - account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs) - if err != nil { - return nil, err + // 复制排除列表,用于会话限制拒绝时的重试 + localExcluded := make(map[int64]struct{}) + for k, v := range excludedIDs { + localExcluded[k] = v } - result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) - if err == nil && result.Acquired { - return &AccountSelectionResult{ - Account: account, - Acquired: true, - ReleaseFunc: result.ReleaseFunc, - }, nil - } - if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { - waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) - if waitingCount < cfg.StickySessionMaxWaiting { + + for { + account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, localExcluded) + if err != nil { + return nil, err + } + + result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency) + if err == nil && result.Acquired { + // 获取槽位后检查会话限制(使用 sessionHash 作为会话标识符) + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + result.ReleaseFunc() // 释放槽位 + localExcluded[account.ID] = struct{}{} // 排除此账号 + continue // 重新选择 + } return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, + Account: account, + Acquired: true, + ReleaseFunc: result.ReleaseFunc, }, nil } + + // 对于等待计划的情况,也需要先检查会话限制 + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + localExcluded[account.ID] = struct{}{} + continue + } + + if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil { + waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID) + if waitingCount < cfg.StickySessionMaxWaiting { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } + } + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: account.ID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.FallbackWaitTimeout, + MaxWaiting: cfg.FallbackMaxWaiting, + }, + }, nil } - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: account.ID, - MaxConcurrency: 
account.Concurrency, - Timeout: cfg.FallbackWaitTimeout, - MaxWaiting: cfg.FallbackMaxWaiting, - }, - }, nil } platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID, group) @@ -999,7 +1032,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, stickyAccount, sessionUUID) { + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { result.ReleaseFunc() // 释放槽位 // 继续到负载感知选择 } else { @@ -1017,15 +1050,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: stickyAccount, - WaitPlan: &AccountWaitPlan{ - AccountID: stickyAccountID, - MaxConcurrency: stickyAccount.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) { + // 会话限制已满,继续到负载感知选择 + } else { + return &AccountSelectionResult{ + Account: stickyAccount, + WaitPlan: &AccountWaitPlan{ + AccountID: stickyAccountID, + MaxConcurrency: stickyAccount.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } // 粘性账号槽位满且等待队列已满,继续使用负载感知选择 } @@ -1086,7 +1124,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } @@ -1104,20 +1142,26 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } } - // 5. 所有路由账号槽位满,返回等待计划(选择负载最低的) - acc := routingAvailable[0].account - if s.debugModelRoutingEnabled() { - log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID) + // 5. 
所有路由账号槽位满,尝试返回等待计划(选择负载最低的) + // 遍历找到第一个满足会话限制的账号 + for _, item := range routingAvailable { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { + continue // 会话限制已满,尝试下一个 + } + if s.debugModelRoutingEnabled() { + log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID) + } + return &AccountSelectionResult{ + Account: item.account, + WaitPlan: &AccountWaitPlan{ + AccountID: item.account.ID, + MaxConcurrency: item.account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil } - return &AccountSelectionResult{ - Account: acc, - WaitPlan: &AccountWaitPlan{ - AccountID: acc.ID, - MaxConcurrency: acc.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + // 所有路由账号会话限制都已满,继续到 Layer 2 回退 } // 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退 log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel) @@ -1137,7 +1181,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, account, sessionUUID) { + if !s.checkAndRegisterSession(ctx, account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续到 Layer 2 } else { _ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL) @@ -1151,15 +1195,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID) if waitingCount < cfg.StickySessionMaxWaiting { - return &AccountSelectionResult{ - Account: account, - WaitPlan: &AccountWaitPlan{ - AccountID: accountID, - MaxConcurrency: account.Concurrency, - Timeout: cfg.StickySessionWaitTimeout, - MaxWaiting: cfg.StickySessionMaxWaiting, - }, - }, nil + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, account, sessionHash) { + // 会话限制已满,继续到 Layer 2 + } else { + return &AccountSelectionResult{ + Account: account, + WaitPlan: &AccountWaitPlan{ + AccountID: accountID, + MaxConcurrency: account.Concurrency, + Timeout: cfg.StickySessionWaitTimeout, + MaxWaiting: cfg.StickySessionMaxWaiting, + }, + }, nil + } } } } @@ -1208,7 +1257,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads) if err != nil { - if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth, sessionUUID); ok { + if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok { return result, nil } } else { @@ -1258,7 +1307,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, item.account, sessionUUID) { + if !s.checkAndRegisterSession(ctx, item.account, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } @@ -1276,8 +1325,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // ============ Layer 3: 兜底排队 ============ - sortAccountsByPriorityAndLastUsed(candidates, preferOAuth) + 
s.sortCandidatesForFallback(candidates, preferOAuth, cfg.FallbackSelectionMode) for _, acc := range candidates { + // 会话数量限制检查(等待计划也需要占用会话配额) + if !s.checkAndRegisterSession(ctx, acc, sessionHash) { + continue // 会话限制已满,尝试下一个账号 + } return &AccountSelectionResult{ Account: acc, WaitPlan: &AccountWaitPlan{ @@ -1291,7 +1344,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro return nil, errors.New("no available accounts") } -func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool, sessionUUID string) (*AccountSelectionResult, bool) { +func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) { ordered := append([]*Account(nil), candidates...) sortAccountsByPriorityAndLastUsed(ordered, preferOAuth) @@ -1299,7 +1352,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency) if err == nil && result.Acquired { // 会话数量限制检查 - if !s.checkAndRegisterSession(ctx, acc, sessionUUID) { + if !s.checkAndRegisterSession(ctx, acc, sessionHash) { result.ReleaseFunc() // 释放槽位,继续尝试下一个账号 continue } @@ -1456,7 +1509,24 @@ func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64, gr func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) { if s.schedulerSnapshot != nil { - return s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + accounts, useMixed, err := s.schedulerSnapshot.ListSchedulableAccounts(ctx, groupID, platform, hasForcePlatform) + if err == nil { + slog.Debug("account_scheduling_list_snapshot", + "group_id", derefGroupID(groupID), + "platform", platform, + "use_mixed", useMixed, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } + } + return accounts, useMixed, err } useMixed := (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform if useMixed { @@ -1469,6 +1539,10 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms) } if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) return nil, useMixed, err } filtered := make([]Account, 0, len(accounts)) @@ -1478,6 +1552,20 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i } filtered = append(filtered, acc) } + slog.Debug("account_scheduling_list_mixed", + "group_id", derefGroupID(groupID), + "platform", platform, + "raw_count", len(accounts), + "filtered_count", len(filtered)) + for _, acc := range filtered { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } return filtered, useMixed, nil } @@ -1492,8 +1580,25 @@ func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *i accounts, err = 
s.accountRepo.ListSchedulableByPlatform(ctx, platform) } if err != nil { + slog.Debug("account_scheduling_list_failed", + "group_id", derefGroupID(groupID), + "platform", platform, + "error", err) return nil, useMixed, err } + slog.Debug("account_scheduling_list_single", + "group_id", derefGroupID(groupID), + "platform", platform, + "count", len(accounts)) + for _, acc := range accounts { + slog.Debug("account_scheduling_account_detail", + "account_id", acc.ID, + "name", acc.Name, + "platform", acc.Platform, + "type", acc.Type, + "status", acc.Status, + "tls_fingerprint", acc.IsTLSFingerprintEnabled()) + } return accounts, useMixed, nil } @@ -1559,12 +1664,8 @@ func (s *GatewayService) isAccountSchedulableForWindowCost(ctx context.Context, // 缓存未命中,从数据库查询 { - var startTime time.Time - if account.SessionWindowStart != nil { - startTime = *account.SessionWindowStart - } else { - startTime = time.Now().Add(-5 * time.Hour) - } + // 使用统一的窗口开始时间计算逻辑(考虑窗口过期情况) + startTime := account.GetCurrentWindowStartTime() stats, err := s.usageLogRepo.GetAccountWindowStats(ctx, account.ID, startTime) if err != nil { @@ -1597,15 +1698,16 @@ checkSchedulability: // checkAndRegisterSession 检查并注册会话,用于会话数量限制 // 仅适用于 Anthropic OAuth/SetupToken 账号 +// sessionID: 会话标识符(使用粘性会话的 hash) // 返回 true 表示允许(在限制内或会话已存在),false 表示拒绝(超出限制且是新会话) -func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionUUID string) bool { +func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *Account, sessionID string) bool { // 只检查 Anthropic OAuth/SetupToken 账号 if !account.IsAnthropicOAuthOrSetupToken() { return true } maxSessions := account.GetMaxSessions() - if maxSessions <= 0 || sessionUUID == "" { + if maxSessions <= 0 || sessionID == "" { return true // 未启用会话限制或无会话ID } @@ -1615,7 +1717,7 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A idleTimeout := time.Duration(account.GetSessionIdleTimeoutMinutes()) * time.Minute - allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionUUID, maxSessions, idleTimeout) + allowed, err := s.sessionLimitCache.RegisterSession(ctx, account.ID, sessionID, maxSessions, idleTimeout) if err != nil { // 失败开放:缓存错误时允许通过 return true @@ -1623,18 +1725,6 @@ func (s *GatewayService) checkAndRegisterSession(ctx context.Context, account *A return allowed } -// extractSessionUUID 从 metadata.user_id 中提取会话 UUID -// 格式: user_{64位hex}_account__session_{uuid} -func extractSessionUUID(metadataUserID string) string { - if metadataUserID == "" { - return "" - } - if match := sessionIDRegex.FindStringSubmatch(metadataUserID); len(match) > 1 { - return match[1] - } - return "" -} - func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID int64) (*Account, error) { if s.schedulerSnapshot != nil { return s.schedulerSnapshot.GetAccount(ctx, accountID) @@ -1664,6 +1754,56 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) { }) } +// sortCandidatesForFallback 根据配置选择排序策略 +// mode: "last_used"(按最后使用时间) 或 "random"(随机) +func (s *GatewayService) sortCandidatesForFallback(accounts []*Account, preferOAuth bool, mode string) { + if mode == "random" { + // 先按优先级排序,然后在同优先级内随机打乱 + sortAccountsByPriorityOnly(accounts, preferOAuth) + shuffleWithinPriority(accounts) + } else { + // 默认按最后使用时间排序 + sortAccountsByPriorityAndLastUsed(accounts, preferOAuth) + } +} + +// sortAccountsByPriorityOnly 仅按优先级排序 +func sortAccountsByPriorityOnly(accounts []*Account, preferOAuth bool) { + 
sort.SliceStable(accounts, func(i, j int) bool { + a, b := accounts[i], accounts[j] + if a.Priority != b.Priority { + return a.Priority < b.Priority + } + if preferOAuth && a.Type != b.Type { + return a.Type == AccountTypeOAuth + } + return false + }) +} + +// shuffleWithinPriority 在同优先级内随机打乱顺序 +func shuffleWithinPriority(accounts []*Account) { + if len(accounts) <= 1 { + return + } + r := mathrand.New(mathrand.NewSource(time.Now().UnixNano())) + start := 0 + for start < len(accounts) { + priority := accounts[start].Priority + end := start + 1 + for end < len(accounts) && accounts[end].Priority == priority { + end++ + } + // 对 [start, end) 范围内的账户随机打乱 + if end-start > 1 { + r.Shuffle(end-start, func(i, j int) { + accounts[start+i], accounts[start+j] = accounts[start+j], accounts[start+i] + }) + } + start = end + } +} + // selectAccountForModelWithPlatform 选择单平台账户(完全隔离) func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) { preferOAuth := platform == PlatformGemini @@ -2524,6 +2664,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A proxyURL = account.Proxy.URL() } + // 调试日志:记录即将转发的账号信息 + log.Printf("[Forward] Using account: ID=%d Name=%s Platform=%s Type=%s TLSFingerprint=%v Proxy=%s", + account.ID, account.Name, account.Platform, account.Type, account.IsTLSFingerprintEnabled(), proxyURL) + // 重试循环 var resp *http.Response retryStart := time.Now() @@ -2537,7 +2681,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A } // 发送请求 - resp, err = s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + resp, err = s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { if resp != nil && resp.Body != nil { _ = resp.Body.Close() @@ -2611,7 +2755,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr == nil { - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { if retryResp.StatusCode < 400 { log.Printf("Account %d: signature error retry succeeded (thinking downgraded)", account.ID) @@ -2643,7 +2787,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A filteredBody2 := FilterSignatureSensitiveBlocksForRetry(body) retryReq2, buildErr2 := s.buildUpstreamRequest(ctx, c, account, filteredBody2, token, tokenType, reqModel, reqStream, shouldMimicClaudeCode) if buildErr2 == nil { - retryResp2, retryErr2 := s.httpUpstream.Do(retryReq2, proxyURL, account.ID, account.Concurrency) + retryResp2, retryErr2 := s.httpUpstream.DoWithTLS(retryReq2, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr2 == nil { resp = retryResp2 break @@ -2758,6 +2902,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) + // 调试日志:打印重试耗尽后的错误响应 + log.Printf("[Forward] Upstream error (retry exhausted, failover): Account=%d(%s) 
Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + s.handleRetryExhaustedSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -2785,6 +2933,10 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A _ = resp.Body.Close() resp.Body = io.NopCloser(bytes.NewReader(respBody)) + // 调试日志:打印上游错误响应 + log.Printf("[Forward] Upstream error (failover): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(respBody), 1000)) + s.handleFailoverSideEffects(ctx, resp, account) appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ Platform: account.Platform, @@ -2914,9 +3066,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex fingerprint = fp // 2. 重写metadata.user_id(需要指纹中的ClientID和账号的account_uuid) + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 accountUUID := account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { body = newBody } } @@ -3183,6 +3336,10 @@ func extractUpstreamErrorMessage(body []byte) string { func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) { body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20)) + // 调试日志:打印上游错误响应 + log.Printf("[Forward] Upstream error (non-retryable): Account=%d(%s) Status=%d RequestID=%s Body=%s", + account.ID, account.Name, resp.StatusCode, resp.Header.Get("x-request-id"), truncateString(string(body), 1000)) + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) @@ -4171,7 +4328,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, } // 发送请求 - resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + resp, err := s.httpUpstream.DoWithTLS(upstreamReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if err != nil { setOpsUpstreamError(c, 0, sanitizeUpstreamErrorMessage(err.Error()), "") s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed") @@ -4193,7 +4350,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, filteredBody := FilterThinkingBlocksForRetry(body) retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel, shouldMimicClaudeCode) if buildErr == nil { - retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency) + retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled()) if retryErr == nil { resp = retryResp respBody, err = io.ReadAll(resp.Body) @@ -4271,12 +4428,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con } // OAuth 账号:应用统一指纹和重写 userID + // 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值 if account.IsOAuth() && s.identityService != nil { fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header) if err == nil { accountUUID := 
account.GetExtraString("account_uuid") if accountUUID != "" && fp.ClientID != "" { - if newBody, err := s.identityService.RewriteUserID(body, account.ID, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { + if newBody, err := s.identityService.RewriteUserIDWithMasking(ctx, body, account, accountUUID, fp.ClientID); err == nil && len(newBody) > 0 { body = newBody } } diff --git a/backend/internal/service/gemini_multiplatform_test.go b/backend/internal/service/gemini_multiplatform_test.go index 03f5d757..262a05d9 100644 --- a/backend/internal/service/gemini_multiplatform_test.go +++ b/backend/internal/service/gemini_multiplatform_test.go @@ -88,6 +88,9 @@ func (m *mockAccountRepoForGemini) BatchUpdateLastUsed(ctx context.Context, upda func (m *mockAccountRepoForGemini) SetError(ctx context.Context, id int64, errorMsg string) error { return nil } +func (m *mockAccountRepoForGemini) ClearError(ctx context.Context, id int64) error { + return nil +} func (m *mockAccountRepoForGemini) SetSchedulable(ctx context.Context, id int64, schedulable bool) error { return nil } @@ -599,7 +602,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) { name: "Gemini平台-有映射配置-只支持配置的模型", account: &Account{ Platform: PlatformGemini, - Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}}, + Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}}, }, model: "gemini-2.5-flash", expected: false, diff --git a/backend/internal/service/http_upstream_port.go b/backend/internal/service/http_upstream_port.go index 9357f763..0e4cfbec 100644 --- a/backend/internal/service/http_upstream_port.go +++ b/backend/internal/service/http_upstream_port.go @@ -10,6 +10,7 @@ import "net/http" // - 支持可选代理配置 // - 支持账户级连接池隔离 // - 实现类负责连接池管理和复用 +// - 支持可选的 TLS 指纹伪装 type HTTPUpstream interface { // Do 执行 HTTP 请求 // @@ -27,4 +28,28 @@ type HTTPUpstream interface { // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 // - 响应体可能已被包装以跟踪请求生命周期 Do(req *http.Request, proxyURL string, accountID int64, accountConcurrency int) (*http.Response, error) + + // DoWithTLS 执行带 TLS 指纹伪装的 HTTP 请求 + // + // 参数: + // - req: HTTP 请求对象,由调用方构建 + // - proxyURL: 代理服务器地址,空字符串表示直连 + // - accountID: 账户 ID,用于连接池隔离和 TLS 指纹模板选择 + // - accountConcurrency: 账户并发限制,用于动态调整连接池大小 + // - enableTLSFingerprint: 是否启用 TLS 指纹伪装 + // + // 返回: + // - *http.Response: HTTP 响应,调用方必须关闭 Body + // - error: 请求错误(网络错误、超时等) + // + // TLS 指纹说明: + // - 当 enableTLSFingerprint=true 时,使用 utls 库模拟 Claude CLI 的 TLS 指纹 + // - TLS 指纹模板根据 accountID % len(profiles) 自动选择 + // - 支持直连、HTTP/HTTPS 代理、SOCKS5 代理三种场景 + // - 如果 enableTLSFingerprint=false,行为与 Do 方法相同 + // + // 注意: + // - 调用方必须关闭 resp.Body,否则会导致连接泄漏 + // - TLS 指纹客户端与普通客户端使用不同的缓存键,互不影响 + DoWithTLS(req *http.Request, proxyURL string, accountID int64, accountConcurrency int, enableTLSFingerprint bool) (*http.Response, error) } diff --git a/backend/internal/service/identity_service.go b/backend/internal/service/identity_service.go index 4ab1ab96..4e227fea 100644 --- a/backend/internal/service/identity_service.go +++ b/backend/internal/service/identity_service.go @@ -8,9 +8,11 @@ import ( "encoding/json" "fmt" "log" + "log/slog" "net/http" "regexp" "strconv" + "strings" "time" ) @@ -49,6 +51,13 @@ type Fingerprint struct { type IdentityCache interface { GetFingerprint(ctx context.Context, accountID int64) (*Fingerprint, error) SetFingerprint(ctx context.Context, accountID int64, fp *Fingerprint) error + // GetMaskedSessionID 获取固定的会话ID(用于会话ID伪装功能) + // 返回的 sessionID 是一个 UUID 格式的字符串 + 
// 如果不存在或已过期(15分钟无请求),返回空字符串 + GetMaskedSessionID(ctx context.Context, accountID int64) (string, error) + // SetMaskedSessionID 设置固定的会话ID,TTL 为 15 分钟 + // 每次调用都会刷新 TTL + SetMaskedSessionID(ctx context.Context, accountID int64, sessionID string) error } // IdentityService 管理OAuth账号的请求身份指纹 @@ -203,6 +212,94 @@ func (s *IdentityService) RewriteUserID(body []byte, accountID int64, accountUUI return json.Marshal(reqMap) } +// RewriteUserIDWithMasking 重写body中的metadata.user_id,支持会话ID伪装 +// 如果账号启用了会话ID伪装(session_id_masking_enabled), +// 则在完成常规重写后,将 session 部分替换为固定的伪装ID(15分钟内保持不变) +func (s *IdentityService) RewriteUserIDWithMasking(ctx context.Context, body []byte, account *Account, accountUUID, cachedClientID string) ([]byte, error) { + // 先执行常规的 RewriteUserID 逻辑 + newBody, err := s.RewriteUserID(body, account.ID, accountUUID, cachedClientID) + if err != nil { + return newBody, err + } + + // 检查是否启用会话ID伪装 + if !account.IsSessionIDMaskingEnabled() { + return newBody, nil + } + + // 解析重写后的 body,提取 user_id + var reqMap map[string]any + if err := json.Unmarshal(newBody, &reqMap); err != nil { + return newBody, nil + } + + metadata, ok := reqMap["metadata"].(map[string]any) + if !ok { + return newBody, nil + } + + userID, ok := metadata["user_id"].(string) + if !ok || userID == "" { + return newBody, nil + } + + // 查找 _session_ 的位置,替换其后的内容 + const sessionMarker = "_session_" + idx := strings.LastIndex(userID, sessionMarker) + if idx == -1 { + return newBody, nil + } + + // 获取或生成固定的伪装 session ID + maskedSessionID, err := s.cache.GetMaskedSessionID(ctx, account.ID) + if err != nil { + log.Printf("Warning: failed to get masked session ID for account %d: %v", account.ID, err) + return newBody, nil + } + + if maskedSessionID == "" { + // 首次或已过期,生成新的伪装 session ID + maskedSessionID = generateRandomUUID() + log.Printf("Generated new masked session ID for account %d: %s", account.ID, maskedSessionID) + } + + // 刷新 TTL(每次请求都刷新,保持 15 分钟有效期) + if err := s.cache.SetMaskedSessionID(ctx, account.ID, maskedSessionID); err != nil { + log.Printf("Warning: failed to set masked session ID for account %d: %v", account.ID, err) + } + + // 替换 session 部分:保留 _session_ 之前的内容,替换之后的内容 + newUserID := userID[:idx+len(sessionMarker)] + maskedSessionID + + slog.Debug("session_id_masking_applied", + "account_id", account.ID, + "before", userID, + "after", newUserID, + ) + + metadata["user_id"] = newUserID + reqMap["metadata"] = metadata + + return json.Marshal(reqMap) +} + +// generateRandomUUID 生成随机 UUID v4 格式字符串 +func generateRandomUUID() string { + b := make([]byte, 16) + if _, err := rand.Read(b); err != nil { + // fallback: 使用时间戳生成 + h := sha256.Sum256([]byte(fmt.Sprintf("%d", time.Now().UnixNano()))) + b = h[:16] + } + + // 设置 UUID v4 版本和变体位 + b[6] = (b[6] & 0x0f) | 0x40 + b[8] = (b[8] & 0x3f) | 0x80 + + return fmt.Sprintf("%x-%x-%x-%x-%x", + b[0:4], b[4:6], b[6:8], b[8:10], b[10:16]) +} + // generateClientID 生成64位十六进制客户端ID(32字节随机数) func generateClientID() string { b := make([]byte, 32) diff --git a/backend/internal/service/oauth_service.go b/backend/internal/service/oauth_service.go index 0039cb44..03c3438a 100644 --- a/backend/internal/service/oauth_service.go +++ b/backend/internal/service/oauth_service.go @@ -48,8 +48,7 @@ type GenerateAuthURLResult struct { // GenerateAuthURL generates an OAuth authorization URL with full scope func (s *OAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*GenerateAuthURLResult, error) { - scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) - return 
s.generateAuthURLWithScope(ctx, scope, proxyID) + return s.generateAuthURLWithScope(ctx, oauth.ScopeOAuth, proxyID) } // GenerateSetupTokenURL generates an OAuth authorization URL for setup token (inference only) @@ -176,7 +175,8 @@ func (s *OAuthService) CookieAuth(ctx context.Context, input *CookieAuthInput) ( } // Determine scope and if this is a setup token - scope := fmt.Sprintf("%s %s", oauth.ScopeProfile, oauth.ScopeInference) + // Internal API call uses ScopeAPI (org:create_api_key not supported) + scope := oauth.ScopeAPI isSetupToken := false if input.Scope == "inference" { scope = oauth.ScopeInference diff --git a/backend/internal/service/openai_codex_transform.go b/backend/internal/service/openai_codex_transform.go index 264bdf95..48c72593 100644 --- a/backend/internal/service/openai_codex_transform.go +++ b/backend/internal/service/openai_codex_transform.go @@ -394,19 +394,35 @@ func normalizeCodexTools(reqBody map[string]any) bool { } modified := false - for idx, tool := range tools { + validTools := make([]any, 0, len(tools)) + + for _, tool := range tools { toolMap, ok := tool.(map[string]any) if !ok { + // Keep unknown structure as-is to avoid breaking upstream behavior. + validTools = append(validTools, tool) continue } toolType, _ := toolMap["type"].(string) - if strings.TrimSpace(toolType) != "function" { + toolType = strings.TrimSpace(toolType) + if toolType != "function" { + validTools = append(validTools, toolMap) continue } - function, ok := toolMap["function"].(map[string]any) - if !ok { + // OpenAI Responses-style tools use top-level name/parameters. + if name, ok := toolMap["name"].(string); ok && strings.TrimSpace(name) != "" { + validTools = append(validTools, toolMap) + continue + } + + // ChatCompletions-style tools use {type:"function", function:{...}}. + functionValue, hasFunction := toolMap["function"] + function, ok := functionValue.(map[string]any) + if !hasFunction || functionValue == nil || !ok || function == nil { + // Drop invalid function tools. 
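+			// Illustrative shapes (values assumed, matching the test below):
+			//   kept:    {"type":"function","name":"bash","parameters":{...}}   (Responses style)
+			//   kept:    {"type":"function","function":{"name":"bash",...}}     (ChatCompletions style)
+			//   dropped: {"type":"function","function":null}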
+ modified = true continue } @@ -435,11 +451,11 @@ func normalizeCodexTools(reqBody map[string]any) bool { } } - tools[idx] = toolMap + validTools = append(validTools, toolMap) } if modified { - reqBody["tools"] = tools + reqBody["tools"] = validTools } return modified diff --git a/backend/internal/service/openai_codex_transform_test.go b/backend/internal/service/openai_codex_transform_test.go index 0ff9485a..4cd72ab6 100644 --- a/backend/internal/service/openai_codex_transform_test.go +++ b/backend/internal/service/openai_codex_transform_test.go @@ -129,6 +129,37 @@ func TestFilterCodexInput_RemovesItemReferenceWhenNotPreserved(t *testing.T) { require.False(t, hasID) } +func TestApplyCodexOAuthTransform_NormalizeCodexTools_PreservesResponsesFunctionTools(t *testing.T) { + setupCodexCache(t) + + reqBody := map[string]any{ + "model": "gpt-5.1", + "tools": []any{ + map[string]any{ + "type": "function", + "name": "bash", + "description": "desc", + "parameters": map[string]any{"type": "object"}, + }, + map[string]any{ + "type": "function", + "function": nil, + }, + }, + } + + applyCodexOAuthTransform(reqBody) + + tools, ok := reqBody["tools"].([]any) + require.True(t, ok) + require.Len(t, tools, 1) + + first, ok := tools[0].(map[string]any) + require.True(t, ok) + require.Equal(t, "function", first["type"]) + require.Equal(t, "bash", first["name"]) +} + func TestApplyCodexOAuthTransform_EmptyInput(t *testing.T) { // 空 input 应保持为空且不触发异常。 setupCodexCache(t) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 66ac1601..ff731be5 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -133,12 +133,30 @@ func NewOpenAIGatewayService( } } -// GenerateSessionHash generates session hash from header (OpenAI uses session_id header) -func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string { - sessionID := c.GetHeader("session_id") +// GenerateSessionHash generates a sticky-session hash for OpenAI requests. +// +// Priority: +// 1. Header: session_id +// 2. Header: conversation_id +// 3. 
Body: prompt_cache_key (opencode) +func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string { + if c == nil { + return "" + } + + sessionID := strings.TrimSpace(c.GetHeader("session_id")) + if sessionID == "" { + sessionID = strings.TrimSpace(c.GetHeader("conversation_id")) + } + if sessionID == "" && reqBody != nil { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + sessionID = strings.TrimSpace(v) + } + } if sessionID == "" { return "" } + hash := sha256.Sum256([]byte(sessionID)) return hex.EncodeToString(hash[:]) } diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 57b73245..14dd7699 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -68,6 +68,49 @@ func (c stubConcurrencyCache) GetAccountsLoadBatch(ctx context.Context, accounts return out, nil } +func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/openai/v1/responses", nil) + + svc := &OpenAIGatewayService{} + + // 1) session_id header wins + c.Request.Header.Set("session_id", "sess-123") + c.Request.Header.Set("conversation_id", "conv-456") + h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + if h1 == "" { + t.Fatalf("expected non-empty hash") + } + + // 2) conversation_id used when session_id absent + c.Request.Header.Del("session_id") + h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + if h2 == "" { + t.Fatalf("expected non-empty hash") + } + if h1 == h2 { + t.Fatalf("expected different hashes for different keys") + } + + // 3) prompt_cache_key used when both headers absent + c.Request.Header.Del("conversation_id") + h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"}) + if h3 == "" { + t.Fatalf("expected non-empty hash") + } + if h2 == h3 { + t.Fatalf("expected different hashes for different keys") + } + + // 4) empty when no signals + h4 := svc.GenerateSessionHash(c, map[string]any{}) + if h4 != "" { + t.Fatalf("expected empty hash when no signals") + } +} + func TestOpenAISelectAccountWithLoadAwareness_FiltersUnschedulable(t *testing.T) { now := time.Now() resetAt := now.Add(10 * time.Minute) diff --git a/backend/internal/service/openai_tool_corrector.go b/backend/internal/service/openai_tool_corrector.go index 9c9eab84..f4719275 100644 --- a/backend/internal/service/openai_tool_corrector.go +++ b/backend/internal/service/openai_tool_corrector.go @@ -27,6 +27,11 @@ var codexToolNameMapping = map[string]string{ "executeBash": "bash", "exec_bash": "bash", "execBash": "bash", + + // Some clients output generic fetch names. 
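+	// e.g. a tool call named "fetch", "web_fetch" or "webFetch" is corrected to "webfetch".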
+ "fetch": "webfetch", + "web_fetch": "webfetch", + "webFetch": "webfetch", } // ToolCorrectionStats 记录工具修正的统计信息(导出用于 JSON 序列化) @@ -208,27 +213,67 @@ func (c *CodexToolCorrector) correctToolParameters(toolName string, functionCall // 根据工具名称应用特定的参数修正规则 switch toolName { case "bash": - // 移除 workdir 参数(OpenCode 不支持) - if _, exists := argsMap["workdir"]; exists { - delete(argsMap, "workdir") - corrected = true - log.Printf("[CodexToolCorrector] Removed 'workdir' parameter from bash tool") - } - if _, exists := argsMap["work_dir"]; exists { - delete(argsMap, "work_dir") - corrected = true - log.Printf("[CodexToolCorrector] Removed 'work_dir' parameter from bash tool") + // OpenCode bash 支持 workdir;有些来源会输出 work_dir。 + if _, hasWorkdir := argsMap["workdir"]; !hasWorkdir { + if workDir, exists := argsMap["work_dir"]; exists { + argsMap["workdir"] = workDir + delete(argsMap, "work_dir") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'work_dir' to 'workdir' in bash tool") + } + } else { + if _, exists := argsMap["work_dir"]; exists { + delete(argsMap, "work_dir") + corrected = true + log.Printf("[CodexToolCorrector] Removed duplicate 'work_dir' parameter from bash tool") + } } case "edit": - // OpenCode edit 使用 old_string/new_string,Codex 可能使用其他名称 - // 这里可以添加参数名称的映射逻辑 - if _, exists := argsMap["file_path"]; !exists { - if path, exists := argsMap["path"]; exists { - argsMap["file_path"] = path + // OpenCode edit 参数为 filePath/oldString/newString(camelCase)。 + if _, exists := argsMap["filePath"]; !exists { + if filePath, exists := argsMap["file_path"]; exists { + argsMap["filePath"] = filePath + delete(argsMap, "file_path") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'file_path' to 'filePath' in edit tool") + } else if filePath, exists := argsMap["path"]; exists { + argsMap["filePath"] = filePath delete(argsMap, "path") corrected = true - log.Printf("[CodexToolCorrector] Renamed 'path' to 'file_path' in edit tool") + log.Printf("[CodexToolCorrector] Renamed 'path' to 'filePath' in edit tool") + } else if filePath, exists := argsMap["file"]; exists { + argsMap["filePath"] = filePath + delete(argsMap, "file") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'file' to 'filePath' in edit tool") + } + } + + if _, exists := argsMap["oldString"]; !exists { + if oldString, exists := argsMap["old_string"]; exists { + argsMap["oldString"] = oldString + delete(argsMap, "old_string") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'old_string' to 'oldString' in edit tool") + } + } + + if _, exists := argsMap["newString"]; !exists { + if newString, exists := argsMap["new_string"]; exists { + argsMap["newString"] = newString + delete(argsMap, "new_string") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'new_string' to 'newString' in edit tool") + } + } + + if _, exists := argsMap["replaceAll"]; !exists { + if replaceAll, exists := argsMap["replace_all"]; exists { + argsMap["replaceAll"] = replaceAll + delete(argsMap, "replace_all") + corrected = true + log.Printf("[CodexToolCorrector] Renamed 'replace_all' to 'replaceAll' in edit tool") } } } diff --git a/backend/internal/service/openai_tool_corrector_test.go b/backend/internal/service/openai_tool_corrector_test.go index 3e885b4b..ff518ea6 100644 --- a/backend/internal/service/openai_tool_corrector_test.go +++ b/backend/internal/service/openai_tool_corrector_test.go @@ -416,22 +416,23 @@ func TestCorrectToolParameters(t *testing.T) { expected map[string]bool // key: 期待存在的参数, value: 
true表示应该存在 }{ { - name: "remove workdir from bash tool", + name: "rename work_dir to workdir in bash tool", input: `{ "tool_calls": [{ "function": { "name": "bash", - "arguments": "{\"command\":\"ls\",\"workdir\":\"/tmp\"}" + "arguments": "{\"command\":\"ls\",\"work_dir\":\"/tmp\"}" } }] }`, expected: map[string]bool{ - "command": true, - "workdir": false, + "command": true, + "workdir": true, + "work_dir": false, }, }, { - name: "rename path to file_path in edit tool", + name: "rename snake_case edit params to camelCase", input: `{ "tool_calls": [{ "function": { @@ -441,10 +442,12 @@ func TestCorrectToolParameters(t *testing.T) { }] }`, expected: map[string]bool{ - "file_path": true, + "filePath": true, "path": false, - "old_string": true, - "new_string": true, + "oldString": true, + "old_string": false, + "newString": true, + "new_string": false, }, }, } diff --git a/backend/internal/service/pricing_service.go b/backend/internal/service/pricing_service.go index 392fb65c..0ade72cd 100644 --- a/backend/internal/service/pricing_service.go +++ b/backend/internal/service/pricing_service.go @@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string func normalizeModelNameForPricing(model string) string { // Common Gemini/VertexAI forms: // - models/gemini-2.0-flash-exp - // - publishers/google/models/gemini-1.5-pro - // - projects/.../locations/.../publishers/google/models/gemini-1.5-pro + // - publishers/google/models/gemini-2.5-pro + // - projects/.../locations/.../publishers/google/models/gemini-2.5-pro model = strings.TrimSpace(model) model = strings.TrimLeft(model, "/") model = strings.TrimPrefix(model, "models/") diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 47a04cf5..41bd253c 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -73,10 +73,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc return false } - tempMatched := false + // 先尝试临时不可调度规则(401除外) + // 如果匹配成功,直接返回,不执行后续禁用逻辑 if statusCode != 401 { - tempMatched = s.tryTempUnschedulable(ctx, account, statusCode, responseBody) + if s.tryTempUnschedulable(ctx, account, statusCode, responseBody) { + return true + } } + upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(responseBody)) upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg) if upstreamMsg != "" { @@ -84,6 +88,14 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } switch statusCode { + case 400: + // 只有当错误信息包含 "organization has been disabled" 时才禁用 + if strings.Contains(strings.ToLower(upstreamMsg), "organization has been disabled") { + msg := "Organization disabled (400): " + upstreamMsg + s.handleAuthError(ctx, account, msg) + shouldDisable = true + } + // 其他 400 错误(如参数问题)不处理,不禁用账号 case 401: // 对所有 OAuth 账号在 401 错误时调用缓存失效并强制下次刷新 if account.Type == AccountTypeOAuth { @@ -148,9 +160,6 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } } - if tempMatched { - return true - } return shouldDisable } @@ -190,7 +199,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, start := geminiDailyWindowStart(now) totals, ok := s.getGeminiUsageTotals(account.ID, start, now) if !ok { - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) if err 
!= nil { return true, err } @@ -237,7 +246,7 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, if limit > 0 { start := now.Truncate(time.Minute) - stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID, 0, nil, nil) if err != nil { return true, err } diff --git a/backend/internal/service/session_limit_cache.go b/backend/internal/service/session_limit_cache.go index f6f0c26a..5482d610 100644 --- a/backend/internal/service/session_limit_cache.go +++ b/backend/internal/service/session_limit_cache.go @@ -38,8 +38,9 @@ type SessionLimitCache interface { GetActiveSessionCount(ctx context.Context, accountID int64) (int, error) // GetActiveSessionCountBatch 批量获取多个账号的活跃会话数 + // idleTimeouts: 每个账号的空闲超时时间配置,key 为 accountID;若为 nil 或某账号不在其中,则使用默认超时 // 返回 map[accountID]count,查询失败的账号不在 map 中 - GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) + GetActiveSessionCountBatch(ctx context.Context, accountIDs []int64, idleTimeouts map[int64]time.Duration) (map[int64]int, error) // IsSessionActive 检查特定会话是否活跃(未过期) IsSessionActive(ctx context.Context, accountID int64, sessionUUID string) (bool, error) diff --git a/backend/internal/service/setting_service.go b/backend/internal/service/setting_service.go index 0a7426f8..5ab73588 100644 --- a/backend/internal/service/setting_service.go +++ b/backend/internal/service/setting_service.go @@ -69,6 +69,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings SettingKeyContactInfo, SettingKeyDocURL, SettingKeyHomeContent, + SettingKeyHideCcsImportButton, SettingKeyLinuxDoConnectEnabled, } @@ -96,6 +97,7 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: settings[SettingKeyHideCcsImportButton] == "true", LinuxDoOAuthEnabled: linuxDoEnabled, }, nil } @@ -132,6 +134,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ContactInfo string `json:"contact_info,omitempty"` DocURL string `json:"doc_url,omitempty"` HomeContent string `json:"home_content,omitempty"` + HideCcsImportButton bool `json:"hide_ccs_import_button"` LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"` Version string `json:"version,omitempty"` }{ @@ -146,6 +149,7 @@ func (s *SettingService) GetPublicSettingsForInjection(ctx context.Context) (any ContactInfo: settings.ContactInfo, DocURL: settings.DocURL, HomeContent: settings.HomeContent, + HideCcsImportButton: settings.HideCcsImportButton, LinuxDoOAuthEnabled: settings.LinuxDoOAuthEnabled, Version: s.version, }, nil @@ -193,6 +197,7 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet updates[SettingKeyContactInfo] = settings.ContactInfo updates[SettingKeyDocURL] = settings.DocURL updates[SettingKeyHomeContent] = settings.HomeContent + updates[SettingKeyHideCcsImportButton] = strconv.FormatBool(settings.HideCcsImportButton) // 默认配置 updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency) @@ -339,6 +344,7 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin ContactInfo: settings[SettingKeyContactInfo], DocURL: settings[SettingKeyDocURL], HomeContent: settings[SettingKeyHomeContent], + HideCcsImportButton: 
settings[SettingKeyHideCcsImportButton] == "true", } // 解析整数类型 diff --git a/backend/internal/service/settings_view.go b/backend/internal/service/settings_view.go index e4ee2826..05494272 100644 --- a/backend/internal/service/settings_view.go +++ b/backend/internal/service/settings_view.go @@ -25,13 +25,14 @@ type SystemSettings struct { LinuxDoConnectClientSecretConfigured bool LinuxDoConnectRedirectURL string - SiteName string - SiteLogo string - SiteSubtitle string - APIBaseURL string - ContactInfo string - DocURL string - HomeContent string + SiteName string + SiteLogo string + SiteSubtitle string + APIBaseURL string + ContactInfo string + DocURL string + HomeContent string + HideCcsImportButton bool DefaultConcurrency int DefaultBalance float64 @@ -66,6 +67,7 @@ type PublicSettings struct { ContactInfo string DocURL string HomeContent string + HideCcsImportButton bool LinuxDoOAuthEnabled bool Version string } diff --git a/backend/internal/service/subscription_service.go b/backend/internal/service/subscription_service.go index d960c86f..c25c58a2 100644 --- a/backend/internal/service/subscription_service.go +++ b/backend/internal/service/subscription_service.go @@ -27,6 +27,7 @@ var ( ErrWeeklyLimitExceeded = infraerrors.TooManyRequests("WEEKLY_LIMIT_EXCEEDED", "weekly usage limit exceeded") ErrMonthlyLimitExceeded = infraerrors.TooManyRequests("MONTHLY_LIMIT_EXCEEDED", "monthly usage limit exceeded") ErrSubscriptionNilInput = infraerrors.BadRequest("SUBSCRIPTION_NIL_INPUT", "subscription input cannot be nil") + ErrAdjustWouldExpire = infraerrors.BadRequest("ADJUST_WOULD_EXPIRE", "adjustment would result in expired subscription (remaining days must be > 0)") ) // SubscriptionService 订阅服务 @@ -308,17 +309,20 @@ func (s *SubscriptionService) RevokeSubscription(ctx context.Context, subscripti return nil } -// ExtendSubscription 延长订阅 +// ExtendSubscription 调整订阅时长(正数延长,负数缩短) func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscriptionID int64, days int) (*UserSubscription, error) { sub, err := s.userSubRepo.GetByID(ctx, subscriptionID) if err != nil { return nil, ErrSubscriptionNotFound } - // 限制延长天数 + // 限制调整天数范围 if days > MaxValidityDays { days = MaxValidityDays } + if days < -MaxValidityDays { + days = -MaxValidityDays + } // 计算新的过期时间 newExpiresAt := sub.ExpiresAt.AddDate(0, 0, days) @@ -326,6 +330,14 @@ func (s *SubscriptionService) ExtendSubscription(ctx context.Context, subscripti newExpiresAt = MaxExpiresAt } + // 如果是缩短(负数),检查新的过期时间必须大于当前时间 + if days < 0 { + now := time.Now() + if !newExpiresAt.After(now) { + return nil, ErrAdjustWouldExpire + } + } + if err := s.userSubRepo.ExtendExpiry(ctx, subscriptionID, newExpiresAt); err != nil { return nil, err } diff --git a/backend/internal/service/token_refresh_service.go b/backend/internal/service/token_refresh_service.go index 26cfd97d..02e7d445 100644 --- a/backend/internal/service/token_refresh_service.go +++ b/backend/internal/service/token_refresh_service.go @@ -166,11 +166,25 @@ func (s *TokenRefreshService) refreshWithRetry(ctx context.Context, account *Acc for attempt := 1; attempt <= s.cfg.MaxRetries; attempt++ { newCredentials, err := refresher.Refresh(ctx, account) - if err == nil { - // 刷新成功,更新账号credentials + + // 如果有新凭证,先更新(即使有错误也要保存 token) + if newCredentials != nil { account.Credentials = newCredentials - if err := s.accountRepo.Update(ctx, account); err != nil { - return fmt.Errorf("failed to save credentials: %w", err) + if saveErr := s.accountRepo.Update(ctx, account); saveErr != nil { + return 
fmt.Errorf("failed to save credentials: %w", saveErr) + } + } + + if err == nil { + // Antigravity 账户:如果之前是因为缺少 project_id 而标记为 error,现在成功获取到了,清除错误状态 + if account.Platform == PlatformAntigravity && + account.Status == StatusError && + strings.Contains(account.ErrorMessage, "missing_project_id:") { + if clearErr := s.accountRepo.ClearError(ctx, account.ID); clearErr != nil { + log.Printf("[TokenRefresh] Failed to clear error status for account %d: %v", account.ID, clearErr) + } else { + log.Printf("[TokenRefresh] Account %d: cleared missing_project_id error", account.ID) + } } // 对所有 OAuth 账号调用缓存失效(InvalidateToken 内部根据平台判断是否需要处理) if s.cacheInvalidator != nil && account.Type == AccountTypeOAuth { @@ -230,6 +244,7 @@ func isNonRetryableRefreshError(err error) bool { "invalid_client", // 客户端配置错误 "unauthorized_client", // 客户端未授权 "access_denied", // 访问被拒绝 + "missing_project_id", // 缺少 project_id } for _, needle := range nonRetryable { if strings.Contains(msg, needle) { diff --git a/backend/internal/service/usage_cleanup.go b/backend/internal/service/usage_cleanup.go new file mode 100644 index 00000000..7e3ffbb9 --- /dev/null +++ b/backend/internal/service/usage_cleanup.go @@ -0,0 +1,74 @@ +package service + +import ( + "context" + "time" + + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + UsageCleanupStatusPending = "pending" + UsageCleanupStatusRunning = "running" + UsageCleanupStatusSucceeded = "succeeded" + UsageCleanupStatusFailed = "failed" + UsageCleanupStatusCanceled = "canceled" +) + +// UsageCleanupFilters 定义清理任务过滤条件 +// 时间范围为必填,其他字段可选 +// JSON 序列化用于存储任务参数 +// +// start_time/end_time 使用 RFC3339 时间格式 +// 以 UTC 或用户时区解析后的时间为准 +// +// 说明: +// - nil 表示未设置该过滤条件 +// - 过滤条件均为精确匹配 +type UsageCleanupFilters struct { + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` + UserID *int64 `json:"user_id,omitempty"` + APIKeyID *int64 `json:"api_key_id,omitempty"` + AccountID *int64 `json:"account_id,omitempty"` + GroupID *int64 `json:"group_id,omitempty"` + Model *string `json:"model,omitempty"` + Stream *bool `json:"stream,omitempty"` + BillingType *int8 `json:"billing_type,omitempty"` +} + +// UsageCleanupTask 表示使用记录清理任务 +// 状态包含 pending/running/succeeded/failed/canceled +type UsageCleanupTask struct { + ID int64 + Status string + Filters UsageCleanupFilters + CreatedBy int64 + DeletedRows int64 + ErrorMsg *string + CanceledBy *int64 + CanceledAt *time.Time + StartedAt *time.Time + FinishedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +// UsageCleanupRepository 定义清理任务持久层接口 +type UsageCleanupRepository interface { + CreateTask(ctx context.Context, task *UsageCleanupTask) error + ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) + // ClaimNextPendingTask 抢占下一条可执行任务: + // - 优先 pending + // - 若 running 超过 staleRunningAfterSeconds(可能由于进程退出/崩溃/超时),允许重新抢占继续执行 + ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) + // GetTaskStatus 查询任务状态;若不存在返回 sql.ErrNoRows + GetTaskStatus(ctx context.Context, taskID int64) (string, error) + // UpdateTaskProgress 更新任务进度(deleted_rows)用于断点续跑/展示 + UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error + // CancelTask 将任务标记为 canceled(仅允许 pending/running) + CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) + MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error + MarkTaskFailed(ctx context.Context, taskID int64, 
deletedRows int64, errorMsg string) error + DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) +} diff --git a/backend/internal/service/usage_cleanup_service.go b/backend/internal/service/usage_cleanup_service.go new file mode 100644 index 00000000..37f6d375 --- /dev/null +++ b/backend/internal/service/usage_cleanup_service.go @@ -0,0 +1,404 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "fmt" + "log" + "net/http" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" +) + +const ( + usageCleanupWorkerName = "usage_cleanup_worker" +) + +// UsageCleanupService 负责创建与执行使用记录清理任务 +type UsageCleanupService struct { + repo UsageCleanupRepository + timingWheel *TimingWheelService + dashboard *DashboardAggregationService + cfg *config.Config + + running int32 + startOnce sync.Once + stopOnce sync.Once + + workerCtx context.Context + workerCancel context.CancelFunc +} + +func NewUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboard *DashboardAggregationService, cfg *config.Config) *UsageCleanupService { + workerCtx, workerCancel := context.WithCancel(context.Background()) + return &UsageCleanupService{ + repo: repo, + timingWheel: timingWheel, + dashboard: dashboard, + cfg: cfg, + workerCtx: workerCtx, + workerCancel: workerCancel, + } +} + +func describeUsageCleanupFilters(filters UsageCleanupFilters) string { + var parts []string + parts = append(parts, "start="+filters.StartTime.UTC().Format(time.RFC3339)) + parts = append(parts, "end="+filters.EndTime.UTC().Format(time.RFC3339)) + if filters.UserID != nil { + parts = append(parts, fmt.Sprintf("user_id=%d", *filters.UserID)) + } + if filters.APIKeyID != nil { + parts = append(parts, fmt.Sprintf("api_key_id=%d", *filters.APIKeyID)) + } + if filters.AccountID != nil { + parts = append(parts, fmt.Sprintf("account_id=%d", *filters.AccountID)) + } + if filters.GroupID != nil { + parts = append(parts, fmt.Sprintf("group_id=%d", *filters.GroupID)) + } + if filters.Model != nil { + parts = append(parts, "model="+strings.TrimSpace(*filters.Model)) + } + if filters.Stream != nil { + parts = append(parts, fmt.Sprintf("stream=%t", *filters.Stream)) + } + if filters.BillingType != nil { + parts = append(parts, fmt.Sprintf("billing_type=%d", *filters.BillingType)) + } + return strings.Join(parts, " ") +} + +func (s *UsageCleanupService) Start() { + if s == nil { + return + } + if s.cfg != nil && !s.cfg.UsageCleanup.Enabled { + log.Printf("[UsageCleanup] not started (disabled)") + return + } + if s.repo == nil || s.timingWheel == nil { + log.Printf("[UsageCleanup] not started (missing deps)") + return + } + + interval := s.workerInterval() + s.startOnce.Do(func() { + s.timingWheel.ScheduleRecurring(usageCleanupWorkerName, interval, s.runOnce) + log.Printf("[UsageCleanup] started (interval=%s max_range_days=%d batch_size=%d task_timeout=%s)", interval, s.maxRangeDays(), s.batchSize(), s.taskTimeout()) + }) +} + +func (s *UsageCleanupService) Stop() { + if s == nil { + return + } + s.stopOnce.Do(func() { + if s.workerCancel != nil { + s.workerCancel() + } + if s.timingWheel != nil { + s.timingWheel.Cancel(usageCleanupWorkerName) + } + log.Printf("[UsageCleanup] stopped") + }) +} + +func (s *UsageCleanupService) ListTasks(ctx context.Context, params pagination.PaginationParams) 
([]UsageCleanupTask, *pagination.PaginationResult, error) { + if s == nil || s.repo == nil { + return nil, nil, fmt.Errorf("cleanup service not ready") + } + return s.repo.ListTasks(ctx, params) +} + +func (s *UsageCleanupService) CreateTask(ctx context.Context, filters UsageCleanupFilters, createdBy int64) (*UsageCleanupTask, error) { + if s == nil || s.repo == nil { + return nil, fmt.Errorf("cleanup service not ready") + } + if s.cfg != nil && !s.cfg.UsageCleanup.Enabled { + return nil, infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled") + } + if createdBy <= 0 { + return nil, infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CREATOR", "invalid creator") + } + + log.Printf("[UsageCleanup] create_task requested: operator=%d %s", createdBy, describeUsageCleanupFilters(filters)) + sanitizeUsageCleanupFilters(&filters) + if err := s.validateFilters(filters); err != nil { + log.Printf("[UsageCleanup] create_task rejected: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters)) + return nil, err + } + + task := &UsageCleanupTask{ + Status: UsageCleanupStatusPending, + Filters: filters, + CreatedBy: createdBy, + } + if err := s.repo.CreateTask(ctx, task); err != nil { + log.Printf("[UsageCleanup] create_task persist failed: operator=%d err=%v %s", createdBy, err, describeUsageCleanupFilters(filters)) + return nil, fmt.Errorf("create cleanup task: %w", err) + } + log.Printf("[UsageCleanup] create_task persisted: task=%d operator=%d status=%s deleted_rows=%d %s", task.ID, createdBy, task.Status, task.DeletedRows, describeUsageCleanupFilters(filters)) + go s.runOnce() + return task, nil +} + +func (s *UsageCleanupService) runOnce() { + svc := s + if svc == nil { + return + } + if !atomic.CompareAndSwapInt32(&svc.running, 0, 1) { + log.Printf("[UsageCleanup] run_once skipped: already_running=true") + return + } + defer atomic.StoreInt32(&svc.running, 0) + + parent := context.Background() + if svc.workerCtx != nil { + parent = svc.workerCtx + } + ctx, cancel := context.WithTimeout(parent, svc.taskTimeout()) + defer cancel() + + task, err := svc.repo.ClaimNextPendingTask(ctx, int64(svc.taskTimeout().Seconds())) + if err != nil { + log.Printf("[UsageCleanup] claim pending task failed: %v", err) + return + } + if task == nil { + log.Printf("[UsageCleanup] run_once done: no_task=true") + return + } + + log.Printf("[UsageCleanup] task claimed: task=%d status=%s created_by=%d deleted_rows=%d %s", task.ID, task.Status, task.CreatedBy, task.DeletedRows, describeUsageCleanupFilters(task.Filters)) + svc.executeTask(ctx, task) +} + +func (s *UsageCleanupService) executeTask(ctx context.Context, task *UsageCleanupTask) { + if task == nil { + return + } + + batchSize := s.batchSize() + deletedTotal := task.DeletedRows + start := time.Now() + log.Printf("[UsageCleanup] task started: task=%d batch_size=%d deleted_rows=%d %s", task.ID, batchSize, deletedTotal, describeUsageCleanupFilters(task.Filters)) + var batchNum int + + for { + if ctx != nil && ctx.Err() != nil { + log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, ctx.Err()) + return + } + canceled, err := s.isTaskCanceled(ctx, task.ID) + if err != nil { + s.markTaskFailed(task.ID, deletedTotal, err) + return + } + if canceled { + log.Printf("[UsageCleanup] task canceled: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start)) + return + } + + batchNum++ + deleted, err := s.repo.DeleteUsageLogsBatch(ctx, task.Filters, batchSize) + if err != nil { + 
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + // 任务被中断(例如服务停止/超时),保持 running 状态,后续通过 stale reclaim 续跑。 + log.Printf("[UsageCleanup] task interrupted: task=%d err=%v", task.ID, err) + return + } + s.markTaskFailed(task.ID, deletedTotal, err) + return + } + deletedTotal += deleted + if deleted > 0 { + updateCtx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + if err := s.repo.UpdateTaskProgress(updateCtx, task.ID, deletedTotal); err != nil { + log.Printf("[UsageCleanup] task progress update failed: task=%d deleted_rows=%d err=%v", task.ID, deletedTotal, err) + } + cancel() + } + if batchNum <= 3 || batchNum%20 == 0 || deleted < int64(batchSize) { + log.Printf("[UsageCleanup] task batch done: task=%d batch=%d deleted=%d deleted_total=%d", task.ID, batchNum, deleted, deletedTotal) + } + if deleted == 0 || deleted < int64(batchSize) { + break + } + } + + updateCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.repo.MarkTaskSucceeded(updateCtx, task.ID, deletedTotal); err != nil { + log.Printf("[UsageCleanup] update task succeeded failed: task=%d err=%v", task.ID, err) + } else { + log.Printf("[UsageCleanup] task succeeded: task=%d deleted_rows=%d duration=%s", task.ID, deletedTotal, time.Since(start)) + } + + if s.dashboard != nil { + if err := s.dashboard.TriggerRecomputeRange(task.Filters.StartTime, task.Filters.EndTime); err != nil { + log.Printf("[UsageCleanup] trigger dashboard recompute failed: task=%d err=%v", task.ID, err) + } else { + log.Printf("[UsageCleanup] trigger dashboard recompute: task=%d start=%s end=%s", task.ID, task.Filters.StartTime.UTC().Format(time.RFC3339), task.Filters.EndTime.UTC().Format(time.RFC3339)) + } + } +} + +func (s *UsageCleanupService) markTaskFailed(taskID int64, deletedRows int64, err error) { + msg := strings.TrimSpace(err.Error()) + if len(msg) > 500 { + msg = msg[:500] + } + log.Printf("[UsageCleanup] task failed: task=%d deleted_rows=%d err=%s", taskID, deletedRows, msg) + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if updateErr := s.repo.MarkTaskFailed(ctx, taskID, deletedRows, msg); updateErr != nil { + log.Printf("[UsageCleanup] update task failed failed: task=%d err=%v", taskID, updateErr) + } +} + +func (s *UsageCleanupService) isTaskCanceled(ctx context.Context, taskID int64) (bool, error) { + if s == nil || s.repo == nil { + return false, fmt.Errorf("cleanup service not ready") + } + checkCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + status, err := s.repo.GetTaskStatus(checkCtx, taskID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } + return false, err + } + if status == UsageCleanupStatusCanceled { + log.Printf("[UsageCleanup] task cancel detected: task=%d", taskID) + } + return status == UsageCleanupStatusCanceled, nil +} + +func (s *UsageCleanupService) validateFilters(filters UsageCleanupFilters) error { + if filters.StartTime.IsZero() || filters.EndTime.IsZero() { + return infraerrors.BadRequest("USAGE_CLEANUP_MISSING_RANGE", "start_date and end_date are required") + } + if filters.EndTime.Before(filters.StartTime) { + return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_RANGE", "end_date must be after start_date") + } + maxDays := s.maxRangeDays() + if maxDays > 0 { + delta := filters.EndTime.Sub(filters.StartTime) + if delta > time.Duration(maxDays)*24*time.Hour { + return 
infraerrors.BadRequest("USAGE_CLEANUP_RANGE_TOO_LARGE", fmt.Sprintf("date range exceeds %d days", maxDays)) + } + } + return nil +} + +func (s *UsageCleanupService) CancelTask(ctx context.Context, taskID int64, canceledBy int64) error { + if s == nil || s.repo == nil { + return fmt.Errorf("cleanup service not ready") + } + if s.cfg != nil && !s.cfg.UsageCleanup.Enabled { + return infraerrors.New(http.StatusServiceUnavailable, "USAGE_CLEANUP_DISABLED", "usage cleanup is disabled") + } + if canceledBy <= 0 { + return infraerrors.BadRequest("USAGE_CLEANUP_INVALID_CANCELLER", "invalid canceller") + } + status, err := s.repo.GetTaskStatus(ctx, taskID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return infraerrors.New(http.StatusNotFound, "USAGE_CLEANUP_TASK_NOT_FOUND", "cleanup task not found") + } + return err + } + log.Printf("[UsageCleanup] cancel_task requested: task=%d operator=%d status=%s", taskID, canceledBy, status) + if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning { + return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") + } + ok, err := s.repo.CancelTask(ctx, taskID, canceledBy) + if err != nil { + return err + } + if !ok { + // 状态可能并发改变 + return infraerrors.New(http.StatusConflict, "USAGE_CLEANUP_CANCEL_CONFLICT", "cleanup task cannot be canceled in current status") + } + log.Printf("[UsageCleanup] cancel_task done: task=%d operator=%d", taskID, canceledBy) + return nil +} + +func sanitizeUsageCleanupFilters(filters *UsageCleanupFilters) { + if filters == nil { + return + } + if filters.UserID != nil && *filters.UserID <= 0 { + filters.UserID = nil + } + if filters.APIKeyID != nil && *filters.APIKeyID <= 0 { + filters.APIKeyID = nil + } + if filters.AccountID != nil && *filters.AccountID <= 0 { + filters.AccountID = nil + } + if filters.GroupID != nil && *filters.GroupID <= 0 { + filters.GroupID = nil + } + if filters.Model != nil { + model := strings.TrimSpace(*filters.Model) + if model == "" { + filters.Model = nil + } else { + filters.Model = &model + } + } + if filters.BillingType != nil && *filters.BillingType < 0 { + filters.BillingType = nil + } +} + +func (s *UsageCleanupService) maxRangeDays() int { + if s == nil || s.cfg == nil { + return 31 + } + if s.cfg.UsageCleanup.MaxRangeDays > 0 { + return s.cfg.UsageCleanup.MaxRangeDays + } + return 31 +} + +func (s *UsageCleanupService) batchSize() int { + if s == nil || s.cfg == nil { + return 5000 + } + if s.cfg.UsageCleanup.BatchSize > 0 { + return s.cfg.UsageCleanup.BatchSize + } + return 5000 +} + +func (s *UsageCleanupService) workerInterval() time.Duration { + if s == nil || s.cfg == nil { + return 10 * time.Second + } + if s.cfg.UsageCleanup.WorkerIntervalSeconds > 0 { + return time.Duration(s.cfg.UsageCleanup.WorkerIntervalSeconds) * time.Second + } + return 10 * time.Second +} + +func (s *UsageCleanupService) taskTimeout() time.Duration { + if s == nil || s.cfg == nil { + return 30 * time.Minute + } + if s.cfg.UsageCleanup.TaskTimeoutSeconds > 0 { + return time.Duration(s.cfg.UsageCleanup.TaskTimeoutSeconds) * time.Second + } + return 30 * time.Minute +} diff --git a/backend/internal/service/usage_cleanup_service_test.go b/backend/internal/service/usage_cleanup_service_test.go new file mode 100644 index 00000000..05c423bc --- /dev/null +++ b/backend/internal/service/usage_cleanup_service_test.go @@ -0,0 +1,815 @@ +package service + +import ( + "context" + "database/sql" + "errors" + "net/http" + 
"strings" + "sync" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors" + "github.com/Wei-Shaw/sub2api/internal/pkg/pagination" + "github.com/stretchr/testify/require" +) + +type cleanupDeleteResponse struct { + deleted int64 + err error +} + +type cleanupDeleteCall struct { + filters UsageCleanupFilters + limit int +} + +type cleanupMarkCall struct { + taskID int64 + deletedRows int64 + errMsg string +} + +type cleanupRepoStub struct { + mu sync.Mutex + created []*UsageCleanupTask + createErr error + listTasks []UsageCleanupTask + listResult *pagination.PaginationResult + listErr error + claimQueue []*UsageCleanupTask + claimErr error + deleteQueue []cleanupDeleteResponse + deleteCalls []cleanupDeleteCall + markSucceeded []cleanupMarkCall + markFailed []cleanupMarkCall + statusByID map[int64]string + statusErr error + progressCalls []cleanupMarkCall + updateErr error + cancelCalls []int64 + cancelErr error + cancelResult *bool + markFailedErr error +} + +type dashboardRepoStub struct { + recomputeErr error +} + +func (s *dashboardRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error { + return nil +} + +func (s *dashboardRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error { + return s.recomputeErr +} + +func (s *dashboardRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) { + return time.Time{}, nil +} + +func (s *dashboardRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error { + return nil +} + +func (s *dashboardRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error { + return nil +} + +func (s *dashboardRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error { + return nil +} + +func (s *dashboardRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error { + return nil +} + +func (s *cleanupRepoStub) CreateTask(ctx context.Context, task *UsageCleanupTask) error { + if task == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if s.createErr != nil { + return s.createErr + } + if task.ID == 0 { + task.ID = int64(len(s.created) + 1) + } + if task.CreatedAt.IsZero() { + task.CreatedAt = time.Now().UTC() + } + if task.UpdatedAt.IsZero() { + task.UpdatedAt = task.CreatedAt + } + clone := *task + s.created = append(s.created, &clone) + return nil +} + +func (s *cleanupRepoStub) ListTasks(ctx context.Context, params pagination.PaginationParams) ([]UsageCleanupTask, *pagination.PaginationResult, error) { + s.mu.Lock() + defer s.mu.Unlock() + return s.listTasks, s.listResult, s.listErr +} + +func (s *cleanupRepoStub) ClaimNextPendingTask(ctx context.Context, staleRunningAfterSeconds int64) (*UsageCleanupTask, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.claimErr != nil { + return nil, s.claimErr + } + if len(s.claimQueue) == 0 { + return nil, nil + } + task := s.claimQueue[0] + s.claimQueue = s.claimQueue[1:] + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[task.ID] = UsageCleanupStatusRunning + return task, nil +} + +func (s *cleanupRepoStub) GetTaskStatus(ctx context.Context, taskID int64) (string, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.statusErr != nil { + return "", s.statusErr + } + if s.statusByID == nil { + return "", sql.ErrNoRows + } + status, ok := s.statusByID[taskID] + if !ok { + return "", sql.ErrNoRows + } + return status, nil +} + +func (s *cleanupRepoStub) 
UpdateTaskProgress(ctx context.Context, taskID int64, deletedRows int64) error { + s.mu.Lock() + defer s.mu.Unlock() + s.progressCalls = append(s.progressCalls, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows}) + if s.updateErr != nil { + return s.updateErr + } + return nil +} + +func (s *cleanupRepoStub) CancelTask(ctx context.Context, taskID int64, canceledBy int64) (bool, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.cancelCalls = append(s.cancelCalls, taskID) + if s.cancelErr != nil { + return false, s.cancelErr + } + if s.cancelResult != nil { + ok := *s.cancelResult + if ok { + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[taskID] = UsageCleanupStatusCanceled + } + return ok, nil + } + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + status := s.statusByID[taskID] + if status != UsageCleanupStatusPending && status != UsageCleanupStatusRunning { + return false, nil + } + s.statusByID[taskID] = UsageCleanupStatusCanceled + return true, nil +} + +func (s *cleanupRepoStub) MarkTaskSucceeded(ctx context.Context, taskID int64, deletedRows int64) error { + s.mu.Lock() + defer s.mu.Unlock() + s.markSucceeded = append(s.markSucceeded, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows}) + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[taskID] = UsageCleanupStatusSucceeded + return nil +} + +func (s *cleanupRepoStub) MarkTaskFailed(ctx context.Context, taskID int64, deletedRows int64, errorMsg string) error { + s.mu.Lock() + defer s.mu.Unlock() + s.markFailed = append(s.markFailed, cleanupMarkCall{taskID: taskID, deletedRows: deletedRows, errMsg: errorMsg}) + if s.statusByID == nil { + s.statusByID = map[int64]string{} + } + s.statusByID[taskID] = UsageCleanupStatusFailed + if s.markFailedErr != nil { + return s.markFailedErr + } + return nil +} + +func (s *cleanupRepoStub) DeleteUsageLogsBatch(ctx context.Context, filters UsageCleanupFilters, limit int) (int64, error) { + s.mu.Lock() + defer s.mu.Unlock() + s.deleteCalls = append(s.deleteCalls, cleanupDeleteCall{filters: filters, limit: limit}) + if len(s.deleteQueue) == 0 { + return 0, nil + } + resp := s.deleteQueue[0] + s.deleteQueue = s.deleteQueue[1:] + return resp.deleted, resp.err +} + +func TestUsageCleanupServiceCreateTaskSanitizeFilters(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 31}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(24 * time.Hour) + userID := int64(-1) + apiKeyID := int64(10) + model := " gpt-4 " + billingType := int8(-2) + filters := UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + APIKeyID: &apiKeyID, + Model: &model, + BillingType: &billingType, + } + + task, err := svc.CreateTask(context.Background(), filters, 9) + require.NoError(t, err) + require.Equal(t, UsageCleanupStatusPending, task.Status) + require.Nil(t, task.Filters.UserID) + require.NotNil(t, task.Filters.APIKeyID) + require.Equal(t, apiKeyID, *task.Filters.APIKeyID) + require.NotNil(t, task.Filters.Model) + require.Equal(t, "gpt-4", *task.Filters.Model) + require.Nil(t, task.Filters.BillingType) + require.Equal(t, int64(9), task.CreatedBy) +} + +func TestUsageCleanupServiceCreateTaskInvalidCreator(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := 
NewUsageCleanupService(repo, nil, nil, cfg) + + filters := UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + } + _, err := svc.CreateTask(context.Background(), filters, 0) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_INVALID_CREATOR", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskDisabled(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + filters := UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + } + _, err := svc.CreateTask(context.Background(), filters, 1) + require.Error(t, err) + require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err)) + require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskRangeTooLarge(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, MaxRangeDays: 1}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(48 * time.Hour) + filters := UsageCleanupFilters{StartTime: start, EndTime: end} + + _, err := svc.CreateTask(context.Background(), filters, 1) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_RANGE_TOO_LARGE", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskMissingRange(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + _, err := svc.CreateTask(context.Background(), UsageCleanupFilters{}, 1) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_MISSING_RANGE", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCreateTaskRepoError(t *testing.T) { + repo := &cleanupRepoStub{createErr: errors.New("db down")} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + filters := UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + } + _, err := svc.CreateTask(context.Background(), filters, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "create cleanup task") +} + +func TestUsageCleanupServiceRunOnceSuccess(t *testing.T) { + start := time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) + end := start.Add(2 * time.Hour) + repo := &cleanupRepoStub{ + claimQueue: []*UsageCleanupTask{ + {ID: 5, Filters: UsageCleanupFilters{StartTime: start, EndTime: end}}, + }, + deleteQueue: []cleanupDeleteResponse{ + {deleted: 2}, + {deleted: 2}, + {deleted: 1}, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2, TaskTimeoutSeconds: 30}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + svc.runOnce() + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.deleteCalls, 3) + require.Len(t, repo.markSucceeded, 1) + require.Empty(t, repo.markFailed) + require.Equal(t, int64(5), repo.markSucceeded[0].taskID) + require.Equal(t, int64(5), repo.markSucceeded[0].deletedRows) + require.Equal(t, 2, repo.deleteCalls[0].limit) + require.Equal(t, start, repo.deleteCalls[0].filters.StartTime) + require.Equal(t, end, repo.deleteCalls[0].filters.EndTime) +} + +func TestUsageCleanupServiceRunOnceClaimError(t *testing.T) { + repo := &cleanupRepoStub{claimErr: errors.New("claim failed")} + 
cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + svc.runOnce() + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Empty(t, repo.markSucceeded) + require.Empty(t, repo.markFailed) +} + +func TestUsageCleanupServiceRunOnceAlreadyRunning(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + svc.running = 1 + svc.runOnce() +} + +func TestUsageCleanupServiceExecuteTaskFailed(t *testing.T) { + longMsg := strings.Repeat("x", 600) + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {err: errors.New(longMsg)}, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 3}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 11, + Filters: UsageCleanupFilters{ + StartTime: time.Now(), + EndTime: time.Now().Add(24 * time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.markFailed, 1) + require.Equal(t, int64(11), repo.markFailed[0].taskID) + require.Equal(t, 500, len(repo.markFailed[0].errMsg)) +} + +func TestUsageCleanupServiceExecuteTaskProgressError(t *testing.T) { + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {deleted: 2}, + {deleted: 0}, + }, + updateErr: errors.New("update failed"), + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 8, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.markSucceeded, 1) + require.Empty(t, repo.markFailed) + require.Len(t, repo.progressCalls, 1) +} + +func TestUsageCleanupServiceExecuteTaskDeleteCanceled(t *testing.T) { + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {err: context.Canceled}, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 12, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Empty(t, repo.markSucceeded) + require.Empty(t, repo.markFailed) +} + +func TestUsageCleanupServiceExecuteTaskContextCanceled(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 9, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + svc.executeTask(ctx, task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Empty(t, repo.markSucceeded) + require.Empty(t, repo.markFailed) + require.Empty(t, repo.deleteCalls) +} + +func TestUsageCleanupServiceExecuteTaskMarkFailedUpdateError(t *testing.T) { + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {err: errors.New("boom")}, + }, + markFailedErr: errors.New("update 
failed"), + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 13, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.markFailed, 1) + require.Equal(t, int64(13), repo.markFailed[0].taskID) +} + +func TestUsageCleanupServiceExecuteTaskDashboardRecomputeError(t *testing.T) { + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {deleted: 0}, + }, + } + dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ + DashboardAgg: config.DashboardAggregationConfig{Enabled: false}, + }) + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, dashboard, cfg) + task := &UsageCleanupTask{ + ID: 14, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.markSucceeded, 1) +} + +func TestUsageCleanupServiceExecuteTaskDashboardRecomputeSuccess(t *testing.T) { + repo := &cleanupRepoStub{ + deleteQueue: []cleanupDeleteResponse{ + {deleted: 0}, + }, + } + dashboard := NewDashboardAggregationService(&dashboardRepoStub{}, nil, &config.Config{ + DashboardAgg: config.DashboardAggregationConfig{Enabled: true}, + }) + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, dashboard, cfg) + task := &UsageCleanupTask{ + ID: 15, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Len(t, repo.markSucceeded, 1) +} + +func TestUsageCleanupServiceExecuteTaskCanceled(t *testing.T) { + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 3: UsageCleanupStatusCanceled, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, BatchSize: 2}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + task := &UsageCleanupTask{ + ID: 3, + Filters: UsageCleanupFilters{ + StartTime: time.Now().UTC(), + EndTime: time.Now().UTC().Add(time.Hour), + }, + } + + svc.executeTask(context.Background(), task) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Empty(t, repo.deleteCalls) + require.Empty(t, repo.markSucceeded) + require.Empty(t, repo.markFailed) +} + +func TestUsageCleanupServiceCancelTaskSuccess(t *testing.T) { + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 5: UsageCleanupStatusPending, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 5, 9) + require.NoError(t, err) + + repo.mu.Lock() + defer repo.mu.Unlock() + require.Equal(t, UsageCleanupStatusCanceled, repo.statusByID[5]) + require.Len(t, repo.cancelCalls, 1) +} + +func TestUsageCleanupServiceCancelTaskDisabled(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 1, 2) 
+ require.Error(t, err) + require.Equal(t, http.StatusServiceUnavailable, infraerrors.Code(err)) + require.Equal(t, "USAGE_CLEANUP_DISABLED", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCancelTaskNotFound(t *testing.T) { + repo := &cleanupRepoStub{} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 999, 1) + require.Error(t, err) + require.Equal(t, http.StatusNotFound, infraerrors.Code(err)) + require.Equal(t, "USAGE_CLEANUP_TASK_NOT_FOUND", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCancelTaskStatusError(t *testing.T) { + repo := &cleanupRepoStub{statusErr: errors.New("status broken")} + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 7, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "status broken") +} + +func TestUsageCleanupServiceCancelTaskConflict(t *testing.T) { + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 7: UsageCleanupStatusSucceeded, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 7, 1) + require.Error(t, err) + require.Equal(t, http.StatusConflict, infraerrors.Code(err)) + require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCancelTaskRepoConflict(t *testing.T) { + shouldCancel := false + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 7: UsageCleanupStatusPending, + }, + cancelResult: &shouldCancel, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 7, 1) + require.Error(t, err) + require.Equal(t, http.StatusConflict, infraerrors.Code(err)) + require.Equal(t, "USAGE_CLEANUP_CANCEL_CONFLICT", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceCancelTaskRepoError(t *testing.T) { + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 7: UsageCleanupStatusPending, + }, + cancelErr: errors.New("cancel failed"), + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 7, 1) + require.Error(t, err) + require.Contains(t, err.Error(), "cancel failed") +} + +func TestUsageCleanupServiceCancelTaskInvalidCanceller(t *testing.T) { + repo := &cleanupRepoStub{ + statusByID: map[int64]string{ + 7: UsageCleanupStatusRunning, + }, + } + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svc := NewUsageCleanupService(repo, nil, nil, cfg) + + err := svc.CancelTask(context.Background(), 7, 0) + require.Error(t, err) + require.Equal(t, "USAGE_CLEANUP_INVALID_CANCELLER", infraerrors.Reason(err)) +} + +func TestUsageCleanupServiceListTasks(t *testing.T) { + repo := &cleanupRepoStub{ + listTasks: []UsageCleanupTask{{ID: 1}, {ID: 2}}, + listResult: &pagination.PaginationResult{ + Total: 2, + Page: 1, + PageSize: 20, + Pages: 1, + }, + } + svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}) + + tasks, result, err := svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + 
require.NoError(t, err) + require.Len(t, tasks, 2) + require.Equal(t, int64(2), result.Total) +} + +func TestUsageCleanupServiceListTasksNotReady(t *testing.T) { + var nilSvc *UsageCleanupService + _, _, err := nilSvc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) + + svc := NewUsageCleanupService(nil, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}) + _, _, err = svc.ListTasks(context.Background(), pagination.PaginationParams{Page: 1, PageSize: 20}) + require.Error(t, err) +} + +func TestUsageCleanupServiceDefaultsAndLifecycle(t *testing.T) { + var nilSvc *UsageCleanupService + require.Equal(t, 31, nilSvc.maxRangeDays()) + require.Equal(t, 5000, nilSvc.batchSize()) + require.Equal(t, 10*time.Second, nilSvc.workerInterval()) + require.Equal(t, 30*time.Minute, nilSvc.taskTimeout()) + nilSvc.Start() + nilSvc.Stop() + + repo := &cleanupRepoStub{} + cfgDisabled := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: false}} + svcDisabled := NewUsageCleanupService(repo, nil, nil, cfgDisabled) + svcDisabled.Start() + svcDisabled.Stop() + + timingWheel, err := NewTimingWheelService() + require.NoError(t, err) + + cfg := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true, WorkerIntervalSeconds: 5}} + svc := NewUsageCleanupService(repo, timingWheel, nil, cfg) + require.Equal(t, 5*time.Second, svc.workerInterval()) + svc.Start() + svc.Stop() + + cfgFallback := &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}} + svcFallback := NewUsageCleanupService(repo, timingWheel, nil, cfgFallback) + require.Equal(t, 31, svcFallback.maxRangeDays()) + require.Equal(t, 5000, svcFallback.batchSize()) + require.Equal(t, 10*time.Second, svcFallback.workerInterval()) + + svcMissingDeps := NewUsageCleanupService(nil, nil, nil, cfgFallback) + svcMissingDeps.Start() +} + +func TestSanitizeUsageCleanupFiltersModelEmpty(t *testing.T) { + model := " " + apiKeyID := int64(-5) + accountID := int64(-1) + groupID := int64(-2) + filters := UsageCleanupFilters{ + UserID: &apiKeyID, + APIKeyID: &apiKeyID, + AccountID: &accountID, + GroupID: &groupID, + Model: &model, + } + + sanitizeUsageCleanupFilters(&filters) + require.Nil(t, filters.UserID) + require.Nil(t, filters.APIKeyID) + require.Nil(t, filters.AccountID) + require.Nil(t, filters.GroupID) + require.Nil(t, filters.Model) +} + +func TestDescribeUsageCleanupFiltersAllFields(t *testing.T) { + start := time.Date(2024, 2, 1, 10, 0, 0, 0, time.UTC) + end := start.Add(2 * time.Hour) + userID := int64(1) + apiKeyID := int64(2) + accountID := int64(3) + groupID := int64(4) + model := " gpt-4 " + stream := true + billingType := int8(2) + filters := UsageCleanupFilters{ + StartTime: start, + EndTime: end, + UserID: &userID, + APIKeyID: &apiKeyID, + AccountID: &accountID, + GroupID: &groupID, + Model: &model, + Stream: &stream, + BillingType: &billingType, + } + + desc := describeUsageCleanupFilters(filters) + require.Equal(t, "start=2024-02-01T10:00:00Z end=2024-02-01T12:00:00Z user_id=1 api_key_id=2 account_id=3 group_id=4 model=gpt-4 stream=true billing_type=2", desc) +} + +func TestUsageCleanupServiceIsTaskCanceledNotFound(t *testing.T) { + repo := &cleanupRepoStub{} + svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}) + + canceled, err := svc.isTaskCanceled(context.Background(), 9) + require.NoError(t, err) + require.False(t, canceled) +} + +func 
TestUsageCleanupServiceIsTaskCanceledError(t *testing.T) { + repo := &cleanupRepoStub{statusErr: errors.New("status err")} + svc := NewUsageCleanupService(repo, nil, nil, &config.Config{UsageCleanup: config.UsageCleanupConfig{Enabled: true}}) + + _, err := svc.isTaskCanceled(context.Background(), 9) + require.Error(t, err) + require.Contains(t, err.Error(), "status err") +} diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index acc0a5fb..b210286d 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -1,6 +1,7 @@ package service import ( + "context" "database/sql" "time" @@ -57,6 +58,13 @@ func ProvideDashboardAggregationService(repo DashboardAggregationRepository, tim return svc } +// ProvideUsageCleanupService 创建并启动使用记录清理任务服务 +func ProvideUsageCleanupService(repo UsageCleanupRepository, timingWheel *TimingWheelService, dashboardAgg *DashboardAggregationService, cfg *config.Config) *UsageCleanupService { + svc := NewUsageCleanupService(repo, timingWheel, dashboardAgg, cfg) + svc.Start() + return svc +} + // ProvideAccountExpiryService creates and starts AccountExpiryService. func ProvideAccountExpiryService(accountRepo AccountRepository) *AccountExpiryService { svc := NewAccountExpiryService(accountRepo, time.Minute) @@ -189,6 +197,8 @@ func ProvideOpsScheduledReportService( // ProvideAPIKeyAuthCacheInvalidator 提供 API Key 认证缓存失效能力 func ProvideAPIKeyAuthCacheInvalidator(apiKeyService *APIKeyService) APIKeyAuthCacheInvalidator { + // Start Pub/Sub subscriber for L1 cache invalidation across instances + apiKeyService.StartAuthCacheInvalidationSubscriber(context.Background()) return apiKeyService } @@ -248,6 +258,7 @@ var ProviderSet = wire.NewSet( ProvideAccountExpiryService, ProvideTimingWheelService, ProvideDashboardAggregationService, + ProvideUsageCleanupService, ProvideDeferredService, NewAntigravityQuotaFetcher, NewUserAttributeService, diff --git a/backend/migrations/006_add_users_allowed_groups_compat.sql b/backend/migrations/006_add_users_allowed_groups_compat.sql new file mode 100644 index 00000000..262945d4 --- /dev/null +++ b/backend/migrations/006_add_users_allowed_groups_compat.sql @@ -0,0 +1,15 @@ +-- 兼容旧库:若尚未创建 user_allowed_groups,则确保 users.allowed_groups 存在,避免 007 迁移回填失败。 +DO $$ +BEGIN + IF to_regclass('public.user_allowed_groups') IS NULL THEN + IF EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name = 'users' + ) THEN + ALTER TABLE users + ADD COLUMN IF NOT EXISTS allowed_groups BIGINT[] DEFAULT NULL; + END IF; + END IF; +END $$; diff --git a/backend/migrations/006b_guard_users_allowed_groups.sql b/backend/migrations/006b_guard_users_allowed_groups.sql new file mode 100644 index 00000000..79771bf5 --- /dev/null +++ b/backend/migrations/006b_guard_users_allowed_groups.sql @@ -0,0 +1,27 @@ +-- 兼容缺失 users.allowed_groups 的老库,确保 007 回填可执行。 +DO $$ +BEGIN + IF EXISTS ( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = 'public' + AND table_name = 'users' + ) THEN + IF NOT EXISTS ( + SELECT 1 + FROM information_schema.columns + WHERE table_schema = 'public' + AND table_name = 'users' + AND column_name = 'allowed_groups' + ) THEN + IF NOT EXISTS ( + SELECT 1 + FROM schema_migrations + WHERE filename = '014_drop_legacy_allowed_groups.sql' + ) THEN + ALTER TABLE users + ADD COLUMN IF NOT EXISTS allowed_groups BIGINT[] DEFAULT NULL; + END IF; + END IF; + END IF; +END $$; diff --git a/backend/migrations/042_add_usage_cleanup_tasks.sql 
b/backend/migrations/042_add_usage_cleanup_tasks.sql new file mode 100644 index 00000000..ce4be91f --- /dev/null +++ b/backend/migrations/042_add_usage_cleanup_tasks.sql @@ -0,0 +1,21 @@ +-- 042_add_usage_cleanup_tasks.sql +-- 使用记录清理任务表 + +CREATE TABLE IF NOT EXISTS usage_cleanup_tasks ( + id BIGSERIAL PRIMARY KEY, + status VARCHAR(20) NOT NULL, + filters JSONB NOT NULL, + created_by BIGINT NOT NULL REFERENCES users(id) ON DELETE RESTRICT, + deleted_rows BIGINT NOT NULL DEFAULT 0, + error_message TEXT, + started_at TIMESTAMPTZ, + finished_at TIMESTAMPTZ, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_status_created_at + ON usage_cleanup_tasks(status, created_at DESC); + +CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_created_at + ON usage_cleanup_tasks(created_at DESC); diff --git a/backend/migrations/043_add_usage_cleanup_cancel_audit.sql b/backend/migrations/043_add_usage_cleanup_cancel_audit.sql new file mode 100644 index 00000000..42ca6696 --- /dev/null +++ b/backend/migrations/043_add_usage_cleanup_cancel_audit.sql @@ -0,0 +1,10 @@ +-- 043_add_usage_cleanup_cancel_audit.sql +-- usage_cleanup_tasks 取消任务审计字段 + +ALTER TABLE usage_cleanup_tasks + ADD COLUMN IF NOT EXISTS canceled_by BIGINT REFERENCES users(id) ON DELETE SET NULL, + ADD COLUMN IF NOT EXISTS canceled_at TIMESTAMPTZ; + +CREATE INDEX IF NOT EXISTS idx_usage_cleanup_tasks_canceled_at + ON usage_cleanup_tasks(canceled_at DESC); + diff --git a/config.yaml b/config.yaml index 424ce9eb..5e7513fb 100644 --- a/config.yaml +++ b/config.yaml @@ -251,6 +251,27 @@ dashboard_aggregation: # 日聚合保留天数 daily_days: 730 +# ============================================================================= +# Usage Cleanup Task Configuration +# 使用记录清理任务配置(重启生效) +# ============================================================================= +usage_cleanup: + # Enable cleanup task worker + # 启用清理任务执行器 + enabled: true + # Max date range (days) per task + # 单次任务最大时间跨度(天) + max_range_days: 31 + # Batch delete size + # 单批删除数量 + batch_size: 5000 + # Worker interval (seconds) + # 执行器轮询间隔(秒) + worker_interval_seconds: 10 + # Task execution timeout (seconds) + # 单次任务最大执行时长(秒) + task_timeout_seconds: 1800 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/deploy/README.md b/deploy/README.md index f697247d..ed4ea721 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -401,3 +401,60 @@ sudo systemctl status redis 2. **Database connection failed**: Check PostgreSQL is running and credentials are correct 3. **Redis connection failed**: Check Redis is running and password is correct 4. **Permission denied**: Ensure proper file ownership for binary install + +--- + +## TLS Fingerprint Configuration + +Sub2API supports TLS fingerprint simulation to make requests appear as if they come from the official Claude CLI (Node.js client). + +> **💡 Tip:** Visit **[tls.sub2api.org](https://tls.sub2api.org/)** to get TLS fingerprint information for different devices and browsers. 
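+A quick sketch of the selection rule documented under Default Behavior below (`accountID % profileCount`). The `Profile` struct, the key-sorting step, and the function name are illustrative assumptions only — the actual gateway types and ordering logic may differ:
+
+```go
+package main
+
+import (
+	"fmt"
+	"sort"
+)
+
+// Profile mirrors one entry under gateway.tls_fingerprint.profiles (assumed shape).
+type Profile struct {
+	Name         string
+	CipherSuites []uint16
+}
+
+// selectProfile picks a profile deterministically per account, so the same
+// account always presents the same TLS fingerprint across requests.
+func selectProfile(profiles map[string]Profile, accountID int64) Profile {
+	keys := make([]string, 0, len(profiles))
+	for k := range profiles {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys) // fix an order: Go map iteration (like YAML maps) is unordered
+	return profiles[keys[int(accountID)%len(keys)]]
+}
+
+func main() {
+	profiles := map[string]Profile{
+		"profile_1": {Name: "Profile 1"},
+		"profile_2": {Name: "Profile 2", CipherSuites: []uint16{4866, 4867, 4865}},
+	}
+	fmt.Println(selectProfile(profiles, 42).Name) // 42 % 2 == 0 -> "Profile 1"
+}
+```
+
+Sorting the keys first matters: without a fixed ordering, the modulo would map the same account to different profiles across runs.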
+ +### Default Behavior + +- Built-in `claude_cli_v2` profile simulates Node.js 20.x + OpenSSL 3.x +- JA3 Hash: `1a28e69016765d92e3b381168d68922c` +- JA4: `t13d5911h1_a33745022dd6_1f22a2ca17c4` +- Profile selection: `accountID % profileCount` + +### Configuration + +```yaml +gateway: + tls_fingerprint: + enabled: true # Global switch + profiles: + # Simple profile (uses default cipher suites) + profile_1: + name: "Profile 1" + + # Profile with custom cipher suites (use compact array format) + profile_2: + name: "Profile 2" + cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] + curves: [29, 23, 24] + point_formats: [0] + + # Another custom profile + profile_3: + name: "Profile 3" + cipher_suites: [4865, 4866, 4867, 49199, 49200] + curves: [29, 23, 24, 25] +``` + +### Profile Fields + +| Field | Type | Description | +|-------|------|-------------| +| `name` | string | Display name (required) | +| `cipher_suites` | []uint16 | Cipher suites in decimal. Empty = default | +| `curves` | []uint16 | Elliptic curves in decimal. Empty = default | +| `point_formats` | []uint8 | EC point formats. Empty = default | + +### Common Values Reference + +**Cipher Suites (TLS 1.3):** `4865` (AES_128_GCM), `4866` (AES_256_GCM), `4867` (CHACHA20) + +**Cipher Suites (TLS 1.2):** `49195`, `49196`, `49199`, `49200` (ECDHE variants) + +**Curves:** `29` (X25519), `23` (P-256), `24` (P-384), `25` (P-521) diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 9e85d1ff..558b8ef0 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -210,6 +210,19 @@ gateway: outbox_backlog_rebuild_rows: 10000 # 全量重建周期(秒),0 表示禁用 full_rebuild_interval_seconds: 300 + # TLS fingerprint simulation / TLS 指纹伪装 + # Default profile "claude_cli_v2" simulates Node.js 20.x + # 默认模板 "claude_cli_v2" 模拟 Node.js 20.x 指纹 + tls_fingerprint: + enabled: true + # profiles: + # profile_1: + # name: "Custom Profile 1" + # profile_2: + # name: "Custom Profile 2" + # cipher_suites: [4866, 4867, 4865, 49199, 49195, 49200, 49196] + # curves: [29, 23, 24] + # point_formats: [0] # ============================================================================= # API Key Auth Cache Configuration @@ -292,6 +305,27 @@ dashboard_aggregation: # 日聚合保留天数 daily_days: 730 +# ============================================================================= +# Usage Cleanup Task Configuration +# 使用记录清理任务配置(重启生效) +# ============================================================================= +usage_cleanup: + # Enable cleanup task worker + # 启用清理任务执行器 + enabled: true + # Max date range (days) per task + # 单次任务最大时间跨度(天) + max_range_days: 31 + # Batch delete size + # 单批删除数量 + batch_size: 5000 + # Worker interval (seconds) + # 执行器轮询间隔(秒) + worker_interval_seconds: 10 + # Task execution timeout (seconds) + # 单次任务最大执行时长(秒) + task_timeout_seconds: 1800 + # ============================================================================= # Concurrency Wait Configuration # 并发等待配置 diff --git a/frontend/src/api/admin/dashboard.ts b/frontend/src/api/admin/dashboard.ts index 9b338788..ae48bec2 100644 --- a/frontend/src/api/admin/dashboard.ts +++ b/frontend/src/api/admin/dashboard.ts @@ -50,6 +50,7 @@ export interface TrendParams { account_id?: number group_id?: number stream?: boolean + billing_type?: number | null } export interface TrendResponse { @@ -78,6 +79,7 @@ export interface ModelStatsParams { account_id?: number group_id?: number stream?: boolean + billing_type?: number | null } export interface ModelStatsResponse { diff --git 
a/frontend/src/api/admin/groups.ts b/frontend/src/api/admin/groups.ts index 44eebc99..4d2b10ef 100644 --- a/frontend/src/api/admin/groups.ts +++ b/frontend/src/api/admin/groups.ts @@ -5,7 +5,7 @@ import { apiClient } from '../client' import type { - Group, + AdminGroup, GroupPlatform, CreateGroupRequest, UpdateGroupRequest, @@ -31,8 +31,8 @@ export async function list( options?: { signal?: AbortSignal } -): Promise> { - const { data } = await apiClient.get>('/admin/groups', { +): Promise> { + const { data } = await apiClient.get>('/admin/groups', { params: { page, page_size: pageSize, @@ -48,8 +48,8 @@ export async function list( * @param platform - Optional platform filter * @returns List of all active groups */ -export async function getAll(platform?: GroupPlatform): Promise { - const { data } = await apiClient.get('/admin/groups/all', { +export async function getAll(platform?: GroupPlatform): Promise { + const { data } = await apiClient.get('/admin/groups/all', { params: platform ? { platform } : undefined }) return data @@ -60,7 +60,7 @@ export async function getAll(platform?: GroupPlatform): Promise { * @param platform - Platform to filter by * @returns List of groups for the specified platform */ -export async function getByPlatform(platform: GroupPlatform): Promise { +export async function getByPlatform(platform: GroupPlatform): Promise { return getAll(platform) } @@ -69,8 +69,8 @@ export async function getByPlatform(platform: GroupPlatform): Promise { * @param id - Group ID * @returns Group details */ -export async function getById(id: number): Promise { - const { data } = await apiClient.get(`/admin/groups/${id}`) +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/groups/${id}`) return data } @@ -79,8 +79,8 @@ export async function getById(id: number): Promise { * @param groupData - Group data * @returns Created group */ -export async function create(groupData: CreateGroupRequest): Promise { - const { data } = await apiClient.post('/admin/groups', groupData) +export async function create(groupData: CreateGroupRequest): Promise { + const { data } = await apiClient.post('/admin/groups', groupData) return data } @@ -90,8 +90,8 @@ export async function create(groupData: CreateGroupRequest): Promise { * @param updates - Fields to update * @returns Updated group */ -export async function update(id: number, updates: UpdateGroupRequest): Promise { - const { data } = await apiClient.put(`/admin/groups/${id}`, updates) +export async function update(id: number, updates: UpdateGroupRequest): Promise { + const { data } = await apiClient.put(`/admin/groups/${id}`, updates) return data } @@ -111,7 +111,7 @@ export async function deleteGroup(id: number): Promise<{ message: string }> { * @param status - New status * @returns Updated group */ -export async function toggleStatus(id: number, status: 'active' | 'inactive'): Promise { +export async function toggleStatus(id: number, status: 'active' | 'inactive'): Promise { return update(id, { status }) } diff --git a/frontend/src/api/admin/settings.ts b/frontend/src/api/admin/settings.ts index fc72be8d..c9a09e7d 100644 --- a/frontend/src/api/admin/settings.ts +++ b/frontend/src/api/admin/settings.ts @@ -23,6 +23,7 @@ export interface SystemSettings { contact_info: string doc_url: string home_content: string + hide_ccs_import_button: boolean // SMTP settings smtp_host: string smtp_port: number @@ -72,6 +73,7 @@ export interface UpdateSettingsRequest { contact_info?: string doc_url?: string home_content?: 
string + hide_ccs_import_button?: boolean smtp_host?: string smtp_port?: number smtp_username?: string diff --git a/frontend/src/api/admin/usage.ts b/frontend/src/api/admin/usage.ts index dd85fc24..94f7b57b 100644 --- a/frontend/src/api/admin/usage.ts +++ b/frontend/src/api/admin/usage.ts @@ -4,7 +4,7 @@ */ import { apiClient } from '../client' -import type { UsageLog, UsageQueryParams, PaginatedResponse } from '@/types' +import type { AdminUsageLog, UsageQueryParams, PaginatedResponse } from '@/types' // ==================== Types ==================== @@ -31,6 +31,46 @@ export interface SimpleApiKey { user_id: number } +export interface UsageCleanupFilters { + start_time: string + end_time: string + user_id?: number + api_key_id?: number + account_id?: number + group_id?: number + model?: string | null + stream?: boolean | null + billing_type?: number | null +} + +export interface UsageCleanupTask { + id: number + status: string + filters: UsageCleanupFilters + created_by: number + deleted_rows: number + error_message?: string | null + canceled_by?: number | null + canceled_at?: string | null + started_at?: string | null + finished_at?: string | null + created_at: string + updated_at: string +} + +export interface CreateUsageCleanupTaskRequest { + start_date: string + end_date: string + user_id?: number + api_key_id?: number + account_id?: number + group_id?: number + model?: string | null + stream?: boolean | null + billing_type?: number | null + timezone?: string +} + export interface AdminUsageQueryParams extends UsageQueryParams { user_id?: number } @@ -45,8 +85,8 @@ export interface AdminUsageQueryParams extends UsageQueryParams { export async function list( params: AdminUsageQueryParams, options?: { signal?: AbortSignal } -): Promise> { - const { data } = await apiClient.get>('/admin/usage', { +): Promise> { + const { data } = await apiClient.get>('/admin/usage', { params, signal: options?.signal }) @@ -108,11 +148,51 @@ export async function searchApiKeys(userId?: number, keyword?: string): Promise< return data } +/** + * List usage cleanup tasks (admin only) + * @param params - Query parameters for pagination + * @returns Paginated list of cleanup tasks + */ +export async function listCleanupTasks( + params: { page?: number; page_size?: number }, + options?: { signal?: AbortSignal } +): Promise> { + const { data } = await apiClient.get>('/admin/usage/cleanup-tasks', { + params, + signal: options?.signal + }) + return data +} + +/** + * Create a usage cleanup task (admin only) + * @param payload - Cleanup task parameters + * @returns Created cleanup task + */ +export async function createCleanupTask(payload: CreateUsageCleanupTaskRequest): Promise { + const { data } = await apiClient.post('/admin/usage/cleanup-tasks', payload) + return data +} + +/** + * Cancel a usage cleanup task (admin only) + * @param taskId - Task ID to cancel + */ +export async function cancelCleanupTask(taskId: number): Promise<{ id: number; status: string }> { + const { data } = await apiClient.post<{ id: number; status: string }>( + `/admin/usage/cleanup-tasks/${taskId}/cancel` + ) + return data +} + export const adminUsageAPI = { list, getStats, searchUsers, - searchApiKeys + searchApiKeys, + listCleanupTasks, + createCleanupTask, + cancelCleanupTask } export default adminUsageAPI diff --git a/frontend/src/api/admin/users.ts b/frontend/src/api/admin/users.ts index 44963cf9..734e3ac7 100644 --- a/frontend/src/api/admin/users.ts +++ b/frontend/src/api/admin/users.ts @@ -4,7 +4,7 @@ */ import { apiClient } 
from '../client' -import type { User, UpdateUserRequest, PaginatedResponse } from '@/types' +import type { AdminUser, UpdateUserRequest, PaginatedResponse } from '@/types' /** * List all users with pagination @@ -26,7 +26,7 @@ export async function list( options?: { signal?: AbortSignal } -): Promise> { +): Promise> { // Build params with attribute filters in attr[id]=value format const params: Record = { page, @@ -44,8 +44,7 @@ export async function list( } } } - - const { data } = await apiClient.get>('/admin/users', { + const { data } = await apiClient.get>('/admin/users', { params, signal: options?.signal }) @@ -57,8 +56,8 @@ export async function list( * @param id - User ID * @returns User details */ -export async function getById(id: number): Promise { - const { data } = await apiClient.get(`/admin/users/${id}`) +export async function getById(id: number): Promise { + const { data } = await apiClient.get(`/admin/users/${id}`) return data } @@ -73,8 +72,8 @@ export async function create(userData: { balance?: number concurrency?: number allowed_groups?: number[] | null -}): Promise { - const { data } = await apiClient.post('/admin/users', userData) +}): Promise { + const { data } = await apiClient.post('/admin/users', userData) return data } @@ -84,8 +83,8 @@ export async function create(userData: { * @param updates - Fields to update * @returns Updated user */ -export async function update(id: number, updates: UpdateUserRequest): Promise { - const { data } = await apiClient.put(`/admin/users/${id}`, updates) +export async function update(id: number, updates: UpdateUserRequest): Promise { + const { data } = await apiClient.put(`/admin/users/${id}`, updates) return data } @@ -112,8 +111,8 @@ export async function updateBalance( balance: number, operation: 'set' | 'add' | 'subtract' = 'set', notes?: string -): Promise { - const { data } = await apiClient.post(`/admin/users/${id}/balance`, { +): Promise { + const { data } = await apiClient.post(`/admin/users/${id}/balance`, { balance, operation, notes: notes || '' @@ -127,7 +126,7 @@ export async function updateBalance( * @param concurrency - New concurrency limit * @returns Updated user */ -export async function updateConcurrency(id: number, concurrency: number): Promise { +export async function updateConcurrency(id: number, concurrency: number): Promise { return update(id, { concurrency }) } @@ -137,7 +136,7 @@ export async function updateConcurrency(id: number, concurrency: number): Promis * @param status - New status * @returns Updated user */ -export async function toggleStatus(id: number, status: 'active' | 'disabled'): Promise { +export async function toggleStatus(id: number, status: 'active' | 'disabled'): Promise { return update(id, { status }) } diff --git a/frontend/src/components/account/AccountTestModal.vue b/frontend/src/components/account/AccountTestModal.vue index 42f3c1b9..dfa1503e 100644 --- a/frontend/src/components/account/AccountTestModal.vue +++ b/frontend/src/components/account/AccountTestModal.vue @@ -292,8 +292,11 @@ const loadAvailableModels = async () => { if (availableModels.value.length > 0) { if (props.account.platform === 'gemini') { const preferred = + availableModels.value.find((m) => m.id === 'gemini-2.0-flash') || + availableModels.value.find((m) => m.id === 'gemini-2.5-flash') || availableModels.value.find((m) => m.id === 'gemini-2.5-pro') || - availableModels.value.find((m) => m.id === 'gemini-3-pro') + availableModels.value.find((m) => m.id === 'gemini-3-flash-preview') || + 
availableModels.value.find((m) => m.id === 'gemini-3-pro-preview') selectedModelId.value = preferred?.id || availableModels.value[0].id } else { // Try to select Sonnet as default, otherwise use first model diff --git a/frontend/src/components/account/BulkEditAccountModal.vue b/frontend/src/components/account/BulkEditAccountModal.vue index fb776e96..1f6b487b 100644 --- a/frontend/src/components/account/BulkEditAccountModal.vue +++ b/frontend/src/components/account/BulkEditAccountModal.vue @@ -648,7 +648,7 @@ import { ref, watch, computed } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { adminAPI } from '@/api/admin' -import type { Proxy, Group } from '@/types' +import type { Proxy, AdminGroup } from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' import Select from '@/components/common/Select.vue' import ProxySelector from '@/components/common/ProxySelector.vue' @@ -659,7 +659,7 @@ interface Props { show: boolean accountIds: number[] proxies: Proxy[] - groups: Group[] + groups: AdminGroup[] } const props = defineProps() diff --git a/frontend/src/components/account/CreateAccountModal.vue b/frontend/src/components/account/CreateAccountModal.vue index c81de00e..144241ff 100644 --- a/frontend/src/components/account/CreateAccountModal.vue +++ b/frontend/src/components/account/CreateAccountModal.vue @@ -1191,6 +1191,190 @@ + +
[Vue template hunk garbled in extraction: markup tags were stripped. Recoverable content: ~184 added template lines introducing a "Quota Control" panel — section title and hint (admin.accounts.quotaControl.title/.hint); a window-cost-limit group with two $-denominated inputs (windowCost.hint, .limitHint, .stickyReserveHint); a session-limit group with max-sessions and idle-timeout inputs, the latter in minutes (sessionLimit.hint, .maxSessionsHint, .idleTimeoutHint, common.minutes); a TLS-fingerprint toggle (tlsFingerprint.hint); and a session-ID-masking toggle (sessionIdMasking.hint).]
@@ -1214,7 +1398,7 @@
[One-line template change garbled in extraction: the billing-rate-multiplier input was modified; its hint line (admin.accounts.billingRateMultiplierHint) appears as unchanged context.]
@@ -1632,7 +1816,7 @@ import { import { useOpenAIOAuth } from '@/composables/useOpenAIOAuth' import { useGeminiOAuth } from '@/composables/useGeminiOAuth' import { useAntigravityOAuth } from '@/composables/useAntigravityOAuth' -import type { Proxy, Group, AccountPlatform, AccountType } from '@/types' +import type { Proxy, AdminGroup, AccountPlatform, AccountType } from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' import Icon from '@/components/icons/Icon.vue' import ProxySelector from '@/components/common/ProxySelector.vue' @@ -1678,7 +1862,7 @@ const apiKeyHint = computed(() => { interface Props { show: boolean proxies: Proxy[] - groups: Group[] + groups: AdminGroup[] } const props = defineProps() @@ -1763,6 +1947,16 @@ const geminiAIStudioOAuthEnabled = ref(false) const showAdvancedOAuth = ref(false) const showGeminiHelpDialog = ref(false) +// Quota control state (Anthropic OAuth/SetupToken only) +const windowCostEnabled = ref(false) +const windowCostLimit = ref(null) +const windowCostStickyReserve = ref(null) +const sessionLimitEnabled = ref(false) +const maxSessions = ref(null) +const sessionIdleTimeout = ref(null) +const tlsFingerprintEnabled = ref(false) +const sessionIdMaskingEnabled = ref(false) + // Gemini tier selection (used as fallback when auto-detection is unavailable/fails) const geminiTierGoogleOne = ref<'google_one_free' | 'google_ai_pro' | 'google_ai_ultra'>('google_one_free') const geminiTierGcp = ref<'gcp_standard' | 'gcp_enterprise'>('gcp_standard') @@ -2140,6 +2334,15 @@ const resetForm = () => { customErrorCodeInput.value = null interceptWarmupRequests.value = false autoPauseOnExpired.value = true + // Reset quota control state + windowCostEnabled.value = false + windowCostLimit.value = null + windowCostStickyReserve.value = null + sessionLimitEnabled.value = false + maxSessions.value = null + sessionIdleTimeout.value = null + tlsFingerprintEnabled.value = false + sessionIdMaskingEnabled.value = false tempUnschedEnabled.value = false tempUnschedRules.value = [] geminiOAuthType.value = 'code_assist' @@ -2407,7 +2610,32 @@ const handleAnthropicExchange = async (authCode: string) => { ...proxyConfig }) - const extra = oauth.buildExtraInfo(tokenInfo) + // Build extra with quota control settings + const baseExtra = oauth.buildExtraInfo(tokenInfo) || {} + const extra: Record = { ...baseExtra } + + // Add window cost limit settings + if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) { + extra.window_cost_limit = windowCostLimit.value + extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10 + } + + // Add session limit settings + if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) { + extra.max_sessions = maxSessions.value + extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 + } + + // Add TLS fingerprint settings + if (tlsFingerprintEnabled.value) { + extra.enable_tls_fingerprint = true + } + + // Add session ID masking settings + if (sessionIdMaskingEnabled.value) { + extra.session_id_masking_enabled = true + } + const credentials = { ...tokenInfo, ...(interceptWarmupRequests.value ? 
{ intercept_warmup_requests: true } : {}) @@ -2475,7 +2703,32 @@ const handleCookieAuth = async (sessionKey: string) => { ...proxyConfig }) - const extra = oauth.buildExtraInfo(tokenInfo) + // Build extra with quota control settings + const baseExtra = oauth.buildExtraInfo(tokenInfo) || {} + const extra: Record = { ...baseExtra } + + // Add window cost limit settings + if (windowCostEnabled.value && windowCostLimit.value != null && windowCostLimit.value > 0) { + extra.window_cost_limit = windowCostLimit.value + extra.window_cost_sticky_reserve = windowCostStickyReserve.value ?? 10 + } + + // Add session limit settings + if (sessionLimitEnabled.value && maxSessions.value != null && maxSessions.value > 0) { + extra.max_sessions = maxSessions.value + extra.session_idle_timeout_minutes = sessionIdleTimeout.value ?? 5 + } + + // Add TLS fingerprint settings + if (tlsFingerprintEnabled.value) { + extra.enable_tls_fingerprint = true + } + + // Add session ID masking settings + if (sessionIdMaskingEnabled.value) { + extra.session_id_masking_enabled = true + } + const accountName = keys.length > 1 ? `${form.name} #${i + 1}` : form.name // Merge interceptWarmupRequests into credentials diff --git a/frontend/src/components/account/EditAccountModal.vue b/frontend/src/components/account/EditAccountModal.vue index d27364f1..0dd855ef 100644 --- a/frontend/src/components/account/EditAccountModal.vue +++ b/frontend/src/components/account/EditAccountModal.vue @@ -566,7 +566,7 @@
[One-line template change garbled in extraction: the billing-rate-multiplier input was modified; its hint line (admin.accounts.billingRateMultiplierHint) appears as unchanged context.]
@@ -732,6 +732,60 @@
[Vue template hunk garbled in extraction: markup tags were stripped. Recoverable content: ~54 added template lines introducing the TLS-fingerprint toggle (admin.accounts.quotaControl.tlsFingerprint.hint) and the session-ID-masking toggle (admin.accounts.quotaControl.sessionIdMasking.hint) in the edit modal.]
@@ -829,7 +883,7 @@ import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' import { useAuthStore } from '@/stores/auth' import { adminAPI } from '@/api/admin' -import type { Account, Proxy, Group } from '@/types' +import type { Account, Proxy, AdminGroup } from '@/types' import BaseDialog from '@/components/common/BaseDialog.vue' import Select from '@/components/common/Select.vue' import Icon from '@/components/icons/Icon.vue' @@ -847,7 +901,7 @@ interface Props { show: boolean account: Account | null proxies: Proxy[] - groups: Group[] + groups: AdminGroup[] } const props = defineProps() @@ -904,6 +958,8 @@ const windowCostStickyReserve = ref(null) const sessionLimitEnabled = ref(false) const maxSessions = ref(null) const sessionIdleTimeout = ref(null) +const tlsFingerprintEnabled = ref(false) +const sessionIdMaskingEnabled = ref(false) // Computed: current preset mappings based on platform const presetMappings = computed(() => getPresetMappingsByPlatform(props.account?.platform || 'anthropic')) @@ -1237,6 +1293,8 @@ function loadQuotaControlSettings(account: Account) { sessionLimitEnabled.value = false maxSessions.value = null sessionIdleTimeout.value = null + tlsFingerprintEnabled.value = false + sessionIdMaskingEnabled.value = false // Only applies to Anthropic OAuth/SetupToken accounts if (account.platform !== 'anthropic' || (account.type !== 'oauth' && account.type !== 'setup-token')) { @@ -1255,6 +1313,16 @@ function loadQuotaControlSettings(account: Account) { maxSessions.value = account.max_sessions sessionIdleTimeout.value = account.session_idle_timeout_minutes ?? 5 } + + // Load TLS fingerprint setting + if (account.enable_tls_fingerprint === true) { + tlsFingerprintEnabled.value = true + } + + // Load session ID masking setting + if (account.session_id_masking_enabled === true) { + sessionIdMaskingEnabled.value = true + } } function formatTempUnschedKeywords(value: unknown) { @@ -1407,6 +1475,20 @@ const handleSubmit = async () => { delete newExtra.session_idle_timeout_minutes } + // TLS fingerprint setting + if (tlsFingerprintEnabled.value) { + newExtra.enable_tls_fingerprint = true + } else { + delete newExtra.enable_tls_fingerprint + } + + // Session ID masking setting + if (sessionIdMaskingEnabled.value) { + newExtra.session_id_masking_enabled = true + } else { + delete newExtra.session_id_masking_enabled + } + updatePayload.extra = newExtra } diff --git a/frontend/src/components/admin/account/AccountTableActions.vue b/frontend/src/components/admin/account/AccountTableActions.vue index 96fceaa0..8dffd6d1 100644 --- a/frontend/src/components/admin/account/AccountTableActions.vue +++ b/frontend/src/components/admin/account/AccountTableActions.vue @@ -1,8 +1,10 @@