diff --git a/backend/internal/repository/ent.go b/backend/internal/repository/ent.go index 8005f114..d7d574e8 100644 --- a/backend/internal/repository/ent.go +++ b/backend/internal/repository/ent.go @@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) { // 创建 Ent 客户端,绑定到已配置的数据库驱动。 client := ent.NewClient(ent.Driver(drv)) + + // SIMPLE 模式:启动时补齐各平台默认分组。 + // - anthropic/openai/gemini: 确保存在 -default + // - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景) + if cfg.RunMode == config.RunModeSimple { + seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer seedCancel() + if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil { + _ = client.Close() + return nil, nil, err + } + } + return client, drv.DB(), nil } diff --git a/backend/internal/repository/simple_mode_default_groups.go b/backend/internal/repository/simple_mode_default_groups.go new file mode 100644 index 00000000..56309184 --- /dev/null +++ b/backend/internal/repository/simple_mode_default_groups.go @@ -0,0 +1,82 @@ +package repository + +import ( + "context" + "fmt" + + dbent "github.com/Wei-Shaw/sub2api/ent" + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/service" +) + +func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error { + if client == nil { + return fmt.Errorf("nil ent client") + } + + requiredByPlatform := map[string]int{ + service.PlatformAnthropic: 1, + service.PlatformOpenAI: 1, + service.PlatformGemini: 1, + service.PlatformAntigravity: 2, + } + + for platform, minCount := range requiredByPlatform { + count, err := client.Group.Query(). + Where(group.PlatformEQ(platform), group.DeletedAtIsNil()). + Count(ctx) + if err != nil { + return fmt.Errorf("count groups for platform %s: %w", platform, err) + } + + if platform == service.PlatformAntigravity { + if count < minCount { + for i := count; i < minCount; i++ { + name := fmt.Sprintf("%s-default-%d", platform, i+1) + if err := createGroupIfNotExists(ctx, client, name, platform); err != nil { + return err + } + } + } + continue + } + + // Non-antigravity platforms: ensure -default exists. + name := platform + "-default" + if err := createGroupIfNotExists(ctx, client, name, platform); err != nil { + return err + } + } + + return nil +} + +func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error { + exists, err := client.Group.Query(). + Where(group.NameEQ(name), group.DeletedAtIsNil()). + Exist(ctx) + if err != nil { + return fmt.Errorf("check group exists %s: %w", name, err) + } + if exists { + return nil + } + + _, err = client.Group.Create(). + SetName(name). + SetDescription("Auto-created default group"). + SetPlatform(platform). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1.0). + SetIsExclusive(false). + Save(ctx) + if err != nil { + if dbent.IsConstraintError(err) { + // Concurrent server startups may race on creation; treat as success. + return nil + } + return fmt.Errorf("create default group %s: %w", name, err) + } + return nil +} diff --git a/backend/internal/repository/simple_mode_default_groups_integration_test.go b/backend/internal/repository/simple_mode_default_groups_integration_test.go new file mode 100644 index 00000000..3327257b --- /dev/null +++ b/backend/internal/repository/simple_mode_default_groups_integration_test.go @@ -0,0 +1,84 @@ +//go:build integration + +package repository + +import ( + "context" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/ent/group" + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + assertGroupExists := func(name string) { + exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx) + require.NoError(t, err) + require.True(t, exists, "expected group %s to exist", name) + } + + assertGroupExists(service.PlatformAnthropic + "-default") + assertGroupExists(service.PlatformOpenAI + "-default") + assertGroupExists(service.PlatformGemini + "-default") + assertGroupExists(service.PlatformAntigravity + "-default-1") + assertGroupExists(service.PlatformAntigravity + "-default-2") +} + +func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + // Create and then soft-delete an anthropic default group. + g, err := client.Group.Create(). + SetName(service.PlatformAnthropic + "-default"). + SetPlatform(service.PlatformAnthropic). + SetStatus(service.StatusActive). + SetSubscriptionType(service.SubscriptionTypeStandard). + SetRateMultiplier(1.0). + SetIsExclusive(false). + Save(seedCtx) + require.NoError(t, err) + + _, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx) + require.NoError(t, err) + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + // New active one should exist. + count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx) + require.NoError(t, err) + require.Equal(t, 1, count) +} + +func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) { + ctx := context.Background() + tx := testEntTx(t) + client := tx.Client() + + seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity}) + mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity}) + + require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client)) + + count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx) + require.NoError(t, err) + require.GreaterOrEqual(t, count, 2) +}