Merge pull request #317 from IanShaw027/fix/gemini-issue

fix(gemini,group): 更新 Gemini 模型配置并补齐 SIMPLE 默认分组
This commit is contained in:
Wesley Liddick
2026-01-18 14:30:42 +08:00
committed by GitHub
12 changed files with 226 additions and 37 deletions

View File

@@ -16,14 +16,11 @@ type ModelsListResponse struct {
func DefaultModels() []Model {
methods := []string{"generateContent", "streamGenerateContent"}
return []Model{
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.0-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-1.5-flash-8b", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-flash", SupportedGenerationMethods: methods},
{Name: "models/gemini-2.5-pro", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-flash-preview", SupportedGenerationMethods: methods},
{Name: "models/gemini-3-pro-preview", SupportedGenerationMethods: methods},
}
}

View File

@@ -12,10 +12,10 @@ type Model struct {
// DefaultModels is the curated Gemini model list used by the admin UI "test account" flow.
var DefaultModels = []Model{
{ID: "gemini-2.0-flash", Type: "model", DisplayName: "Gemini 2.0 Flash", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-2.5-flash", Type: "model", DisplayName: "Gemini 2.5 Flash", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
{ID: "gemini-2.5-pro", Type: "model", DisplayName: "Gemini 2.5 Pro", CreatedAt: ""},
{ID: "gemini-3-flash-preview", Type: "model", DisplayName: "Gemini 3 Flash Preview", CreatedAt: ""},
{ID: "gemini-3-pro-preview", Type: "model", DisplayName: "Gemini 3 Pro Preview", CreatedAt: ""},
}
// DefaultTestModel is the default model to preselect in test flows.

View File

@@ -65,5 +65,18 @@ func InitEnt(cfg *config.Config) (*ent.Client, *sql.DB, error) {
// 创建 Ent 客户端,绑定到已配置的数据库驱动。
client := ent.NewClient(ent.Driver(drv))
// SIMPLE 模式:启动时补齐各平台默认分组。
// - anthropic/openai/gemini: 确保存在 <platform>-default
// - antigravity: 仅要求存在 >=2 个未软删除分组(用于 claude/gemini 混合调度场景)
if cfg.RunMode == config.RunModeSimple {
seedCtx, seedCancel := context.WithTimeout(context.Background(), 30*time.Second)
defer seedCancel()
if err := ensureSimpleModeDefaultGroups(seedCtx, client); err != nil {
_ = client.Close()
return nil, nil, err
}
}
return client, drv.DB(), nil
}

View File

@@ -0,0 +1,82 @@
package repository
import (
"context"
"fmt"
dbent "github.com/Wei-Shaw/sub2api/ent"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
)
func ensureSimpleModeDefaultGroups(ctx context.Context, client *dbent.Client) error {
if client == nil {
return fmt.Errorf("nil ent client")
}
requiredByPlatform := map[string]int{
service.PlatformAnthropic: 1,
service.PlatformOpenAI: 1,
service.PlatformGemini: 1,
service.PlatformAntigravity: 2,
}
for platform, minCount := range requiredByPlatform {
count, err := client.Group.Query().
Where(group.PlatformEQ(platform), group.DeletedAtIsNil()).
Count(ctx)
if err != nil {
return fmt.Errorf("count groups for platform %s: %w", platform, err)
}
if platform == service.PlatformAntigravity {
if count < minCount {
for i := count; i < minCount; i++ {
name := fmt.Sprintf("%s-default-%d", platform, i+1)
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
return err
}
}
}
continue
}
// Non-antigravity platforms: ensure <platform>-default exists.
name := platform + "-default"
if err := createGroupIfNotExists(ctx, client, name, platform); err != nil {
return err
}
}
return nil
}
func createGroupIfNotExists(ctx context.Context, client *dbent.Client, name, platform string) error {
exists, err := client.Group.Query().
Where(group.NameEQ(name), group.DeletedAtIsNil()).
Exist(ctx)
if err != nil {
return fmt.Errorf("check group exists %s: %w", name, err)
}
if exists {
return nil
}
_, err = client.Group.Create().
SetName(name).
SetDescription("Auto-created default group").
SetPlatform(platform).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1.0).
SetIsExclusive(false).
Save(ctx)
if err != nil {
if dbent.IsConstraintError(err) {
// Concurrent server startups may race on creation; treat as success.
return nil
}
return fmt.Errorf("create default group %s: %w", name, err)
}
return nil
}

View File

@@ -0,0 +1,84 @@
//go:build integration
package repository
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/ent/group"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
func TestEnsureSimpleModeDefaultGroups_CreatesMissingDefaults(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
client := tx.Client()
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
assertGroupExists := func(name string) {
exists, err := client.Group.Query().Where(group.NameEQ(name), group.DeletedAtIsNil()).Exist(seedCtx)
require.NoError(t, err)
require.True(t, exists, "expected group %s to exist", name)
}
assertGroupExists(service.PlatformAnthropic + "-default")
assertGroupExists(service.PlatformOpenAI + "-default")
assertGroupExists(service.PlatformGemini + "-default")
assertGroupExists(service.PlatformAntigravity + "-default-1")
assertGroupExists(service.PlatformAntigravity + "-default-2")
}
func TestEnsureSimpleModeDefaultGroups_IgnoresSoftDeletedGroups(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
client := tx.Client()
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
// Create and then soft-delete an anthropic default group.
g, err := client.Group.Create().
SetName(service.PlatformAnthropic + "-default").
SetPlatform(service.PlatformAnthropic).
SetStatus(service.StatusActive).
SetSubscriptionType(service.SubscriptionTypeStandard).
SetRateMultiplier(1.0).
SetIsExclusive(false).
Save(seedCtx)
require.NoError(t, err)
_, err = client.Group.Delete().Where(group.IDEQ(g.ID)).Exec(seedCtx)
require.NoError(t, err)
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
// New active one should exist.
count, err := client.Group.Query().Where(group.NameEQ(service.PlatformAnthropic+"-default"), group.DeletedAtIsNil()).Count(seedCtx)
require.NoError(t, err)
require.Equal(t, 1, count)
}
func TestEnsureSimpleModeDefaultGroups_AntigravityNeedsTwoGroupsOnlyByCount(t *testing.T) {
ctx := context.Background()
tx := testEntTx(t)
client := tx.Client()
seedCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-1-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
mustCreateGroup(t, client, &service.Group{Name: "ag-custom-2-" + time.Now().Format(time.RFC3339Nano), Platform: service.PlatformAntigravity})
require.NoError(t, ensureSimpleModeDefaultGroups(seedCtx, client))
count, err := client.Group.Query().Where(group.PlatformEQ(service.PlatformAntigravity), group.DeletedAtIsNil()).Count(seedCtx)
require.NoError(t, err)
require.GreaterOrEqual(t, count, 2)
}

View File

@@ -30,7 +30,7 @@ func TestIsAntigravityModelSupported(t *testing.T) {
{"可映射 - claude-3-haiku-20240307", "claude-3-haiku-20240307", true},
// Gemini 前缀透传
{"Gemini前缀 - gemini-1.5-pro", "gemini-1.5-pro", true},
{"Gemini前缀 - gemini-2.5-pro", "gemini-2.5-pro", true},
{"Gemini前缀 - gemini-unknown-model", "gemini-unknown-model", true},
{"Gemini前缀 - gemini-future-version", "gemini-future-version", true},
@@ -142,10 +142,10 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "gemini-2.5-flash",
},
{
name: "Gemini透传 - gemini-1.5-pro",
requestedModel: "gemini-1.5-pro",
name: "Gemini透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-1.5-pro",
expected: "gemini-2.5-pro",
},
{
name: "Gemini透传 - gemini-future-model",

View File

@@ -599,7 +599,7 @@ func TestGeminiMessagesCompatService_isModelSupportedByAccount(t *testing.T) {
name: "Gemini平台-有映射配置-只支持配置的模型",
account: &Account{
Platform: PlatformGemini,
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-1.5-pro": "x"}},
Credentials: map[string]any{"model_mapping": map[string]any{"gemini-2.5-pro": "x"}},
},
model: "gemini-2.5-flash",
expected: false,

View File

@@ -531,8 +531,8 @@ func (s *PricingService) buildModelLookupCandidates(modelLower string) []string
func normalizeModelNameForPricing(model string) string {
// Common Gemini/VertexAI forms:
// - models/gemini-2.0-flash-exp
// - publishers/google/models/gemini-1.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-1.5-pro
// - publishers/google/models/gemini-2.5-pro
// - projects/.../locations/.../publishers/google/models/gemini-2.5-pro
model = strings.TrimSpace(model)
model = strings.TrimLeft(model, "/")
model = strings.TrimPrefix(model, "models/")