feat(group): 添加分组级别模型路由配置功能
支持为分组配置模型路由规则,可以指定特定模型模式优先使用的账号列表。 - 新增 model_routing 字段存储路由配置(JSONB格式,支持通配符匹配) - 新增 model_routing_enabled 字段控制是否启用路由 - 更新后端 handler/service/repository 支持路由配置的增删改查 - 更新前端 GroupsView 添加路由配置界面 - 添加数据库迁移脚本 040/041
This commit is contained in:
@@ -40,6 +40,9 @@ type CreateGroupRequest struct {
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// UpdateGroupRequest represents update group request
|
||||
@@ -60,6 +63,9 @@ type UpdateGroupRequest struct {
|
||||
ImagePrice4K *float64 `json:"image_price_4k"`
|
||||
ClaudeCodeOnly *bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled *bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// List handles listing all groups with pagination
|
||||
@@ -149,20 +155,22 @@ func (h *GroupHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.CreateGroup(c.Request.Context(), &service.CreateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
@@ -188,21 +196,23 @@ func (h *GroupHandler) Update(c *gin.Context) {
|
||||
}
|
||||
|
||||
group, err := h.adminService.UpdateGroup(c.Request.Context(), groupID, &service.UpdateGroupInput{
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
Name: req.Name,
|
||||
Description: req.Description,
|
||||
Platform: req.Platform,
|
||||
RateMultiplier: req.RateMultiplier,
|
||||
IsExclusive: req.IsExclusive,
|
||||
Status: req.Status,
|
||||
SubscriptionType: req.SubscriptionType,
|
||||
DailyLimitUSD: req.DailyLimitUSD,
|
||||
WeeklyLimitUSD: req.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: req.MonthlyLimitUSD,
|
||||
ImagePrice1K: req.ImagePrice1K,
|
||||
ImagePrice2K: req.ImagePrice2K,
|
||||
ImagePrice4K: req.ImagePrice4K,
|
||||
ClaudeCodeOnly: req.ClaudeCodeOnly,
|
||||
FallbackGroupID: req.FallbackGroupID,
|
||||
ModelRouting: req.ModelRouting,
|
||||
ModelRoutingEnabled: req.ModelRoutingEnabled,
|
||||
})
|
||||
if err != nil {
|
||||
response.ErrorFrom(c, err)
|
||||
|
||||
@@ -87,9 +87,11 @@ func GroupFromServiceShallow(g *service.Group) *Group {
|
||||
ImagePrice1K: g.ImagePrice1K,
|
||||
ImagePrice2K: g.ImagePrice2K,
|
||||
ImagePrice4K: g.ImagePrice4K,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
CreatedAt: g.CreatedAt,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
AccountCount: g.AccountCount,
|
||||
}
|
||||
|
||||
@@ -58,6 +58,10 @@ type Group struct {
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id"`
|
||||
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64 `json:"model_routing"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
|
||||
|
||||
@@ -136,6 +136,8 @@ func (r *apiKeyRepository) GetByKeyForAuth(ctx context.Context, key string) (*se
|
||||
group.FieldImagePrice4k,
|
||||
group.FieldClaudeCodeOnly,
|
||||
group.FieldFallbackGroupID,
|
||||
group.FieldModelRoutingEnabled,
|
||||
group.FieldModelRouting,
|
||||
)
|
||||
}).
|
||||
Only(ctx)
|
||||
@@ -422,6 +424,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
DefaultValidityDays: g.DefaultValidityDays,
|
||||
ClaudeCodeOnly: g.ClaudeCodeOnly,
|
||||
FallbackGroupID: g.FallbackGroupID,
|
||||
ModelRouting: g.ModelRouting,
|
||||
ModelRoutingEnabled: g.ModelRoutingEnabled,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
}
|
||||
|
||||
@@ -49,7 +49,13 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID)
|
||||
SetNillableFallbackGroupID(groupIn.FallbackGroupID).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
|
||||
// 设置模型路由配置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
}
|
||||
|
||||
created, err := builder.Save(ctx)
|
||||
if err == nil {
|
||||
@@ -101,7 +107,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetNillableImagePrice2k(groupIn.ImagePrice2K).
|
||||
SetNillableImagePrice4k(groupIn.ImagePrice4K).
|
||||
SetDefaultValidityDays(groupIn.DefaultValidityDays).
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly)
|
||||
SetClaudeCodeOnly(groupIn.ClaudeCodeOnly).
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled)
|
||||
|
||||
// 处理 FallbackGroupID:nil 时清除,否则设置
|
||||
if groupIn.FallbackGroupID != nil {
|
||||
@@ -110,6 +117,13 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
builder = builder.ClearFallbackGroupID()
|
||||
}
|
||||
|
||||
// 处理 ModelRouting:nil 时清除,否则设置
|
||||
if groupIn.ModelRouting != nil {
|
||||
builder = builder.SetModelRouting(groupIn.ModelRouting)
|
||||
} else {
|
||||
builder = builder.ClearModelRouting()
|
||||
}
|
||||
|
||||
updated, err := builder.Save(ctx)
|
||||
if err != nil {
|
||||
return translatePersistenceError(err, service.ErrGroupNotFound, service.ErrGroupExists)
|
||||
|
||||
@@ -106,6 +106,9 @@ type CreateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool // 是否启用模型路由
|
||||
}
|
||||
|
||||
type UpdateGroupInput struct {
|
||||
@@ -125,6 +128,9 @@ type UpdateGroupInput struct {
|
||||
ImagePrice4K *float64
|
||||
ClaudeCodeOnly *bool // 仅允许 Claude Code 客户端
|
||||
FallbackGroupID *int64 // 降级分组 ID
|
||||
// 模型路由配置(仅 anthropic 平台使用)
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled *bool // 是否启用模型路由
|
||||
}
|
||||
|
||||
type CreateAccountInput struct {
|
||||
@@ -581,6 +587,7 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
|
||||
ImagePrice4K: imagePrice4K,
|
||||
ClaudeCodeOnly: input.ClaudeCodeOnly,
|
||||
FallbackGroupID: input.FallbackGroupID,
|
||||
ModelRouting: input.ModelRouting,
|
||||
}
|
||||
if err := s.groupRepo.Create(ctx, group); err != nil {
|
||||
return nil, err
|
||||
@@ -709,6 +716,14 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
|
||||
}
|
||||
}
|
||||
|
||||
// 模型路由配置
|
||||
if input.ModelRouting != nil {
|
||||
group.ModelRouting = input.ModelRouting
|
||||
}
|
||||
if input.ModelRoutingEnabled != nil {
|
||||
group.ModelRoutingEnabled = *input.ModelRoutingEnabled
|
||||
}
|
||||
|
||||
if err := s.groupRepo.Update(ctx, group); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -37,6 +37,11 @@ type APIKeyAuthGroupSnapshot struct {
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
|
||||
@@ -221,6 +221,8 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
@@ -263,6 +265,8 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
}
|
||||
}
|
||||
return apiKey
|
||||
|
||||
@@ -178,6 +178,10 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
RateMultiplier: 1,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-opus-*": {1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -190,6 +194,8 @@ func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
require.Equal(t, int64(1), apiKey.ID)
|
||||
require.Equal(t, int64(2), apiKey.User.ID)
|
||||
require.Equal(t, groupID, apiKey.Group.ID)
|
||||
require.True(t, apiKey.Group.ModelRoutingEnabled)
|
||||
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
|
||||
@@ -1053,6 +1053,60 @@ func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
|
||||
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
|
||||
})
|
||||
|
||||
t.Run("模型路由-无ConcurrencyService也生效", func(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
sessionHash := "sticky"
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
|
||||
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5, AccountGroups: []AccountGroup{{GroupID: groupID}}},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
}
|
||||
for i := range repo.accounts {
|
||||
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
|
||||
}
|
||||
|
||||
cache := &mockGatewayCacheForPlatform{
|
||||
sessionBindings: map[string]int64{sessionHash: 1},
|
||||
}
|
||||
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-a": {1},
|
||||
"claude-b": {2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
cfg := testConfig()
|
||||
cfg.Gateway.Scheduling.LoadBatchEnabled = true
|
||||
|
||||
svc := &GatewayService{
|
||||
accountRepo: repo,
|
||||
groupRepo: groupRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
concurrencyService: nil, // legacy path
|
||||
}
|
||||
|
||||
result, err := svc.SelectAccountWithLoadAwareness(ctx, &groupID, sessionHash, "claude-b", nil)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.NotNil(t, result.Account)
|
||||
require.Equal(t, int64(2), result.Account.ID, "切换到 claude-b 时应按模型路由切换账号")
|
||||
require.Equal(t, int64(2), cache.sessionBindings[sessionHash], "粘性绑定应更新为路由选择的账号")
|
||||
})
|
||||
|
||||
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
@@ -1341,6 +1395,7 @@ func TestGatewayService_GroupResolution_IgnoresInvalidContextGroup(t *testing.T)
|
||||
ID: groupID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
groupRepo := &mockGroupRepoForGateway{
|
||||
groups: map[int64]*Group{groupID: group},
|
||||
@@ -1398,6 +1453,7 @@ func TestGatewayService_GroupResolution_FallbackUsesLiteOnce(t *testing.T) {
|
||||
ID: fallbackID,
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
Hydrated: true,
|
||||
}
|
||||
ctx = context.WithValue(ctx, ctxkey.Group, group)
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
@@ -38,6 +39,21 @@ const (
|
||||
maxCacheControlBlocks = 4 // Anthropic API 允许的最大 cache_control 块数量
|
||||
)
|
||||
|
||||
func (s *GatewayService) debugModelRoutingEnabled() bool {
|
||||
v := strings.ToLower(strings.TrimSpace(os.Getenv("SUB2API_DEBUG_MODEL_ROUTING")))
|
||||
return v == "1" || v == "true" || v == "yes" || v == "on"
|
||||
}
|
||||
|
||||
func shortSessionHash(sessionHash string) string {
|
||||
if sessionHash == "" {
|
||||
return ""
|
||||
}
|
||||
if len(sessionHash) <= 8 {
|
||||
return sessionHash
|
||||
}
|
||||
return sessionHash[:8]
|
||||
}
|
||||
|
||||
// sseDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var (
|
||||
@@ -407,6 +423,15 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
ctx = s.withGroupContext(ctx, group)
|
||||
|
||||
if s.debugModelRoutingEnabled() && requestedModel != "" {
|
||||
groupPlatform := ""
|
||||
if group != nil {
|
||||
groupPlatform = group.Platform
|
||||
}
|
||||
log.Printf("[ModelRoutingDebug] select entry: group_id=%v group_platform=%s model=%s session=%s sticky_account=%d load_batch=%v concurrency=%v",
|
||||
derefGroupID(groupID), groupPlatform, requestedModel, shortSessionHash(sessionHash), stickyAccountID, cfg.LoadBatchEnabled, s.concurrencyService != nil)
|
||||
}
|
||||
|
||||
if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
|
||||
account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
|
||||
if err != nil {
|
||||
@@ -450,6 +475,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return nil, err
|
||||
}
|
||||
preferOAuth := platform == PlatformGemini
|
||||
if s.debugModelRoutingEnabled() && platform == PlatformAnthropic && requestedModel != "" {
|
||||
log.Printf("[ModelRoutingDebug] load-aware enabled: group_id=%v model=%s session=%s platform=%s", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), platform)
|
||||
}
|
||||
|
||||
accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
@@ -467,15 +495,206 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return excluded
|
||||
}
|
||||
|
||||
// ============ Layer 1: 粘性会话优先 ============
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
|
||||
// 获取模型路由配置(仅 anthropic 平台)
|
||||
var routingAccountIDs []int64
|
||||
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
||||
routingAccountIDs = group.GetRoutingAccountIDs(requestedModel)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] context group routing: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v session=%s sticky_account=%d",
|
||||
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), routingAccountIDs, shortSessionHash(sessionHash), stickyAccountID)
|
||||
if len(routingAccountIDs) == 0 && group.ModelRoutingEnabled && len(group.ModelRouting) > 0 {
|
||||
keys := make([]string, 0, len(group.ModelRouting))
|
||||
for k := range group.ModelRouting {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
const maxKeys = 20
|
||||
if len(keys) > maxKeys {
|
||||
keys = keys[:maxKeys]
|
||||
}
|
||||
log.Printf("[ModelRoutingDebug] context group routing miss: group_id=%d model=%s patterns(sample)=%v", group.ID, requestedModel, keys)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 1: 模型路由优先选择(优先级高于粘性会话) ============
|
||||
if len(routingAccountIDs) > 0 && s.concurrencyService != nil {
|
||||
// 1. 过滤出路由列表中可调度的账号
|
||||
var routingCandidates []*Account
|
||||
var filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping int
|
||||
for _, routingAccountID := range routingAccountIDs {
|
||||
if isExcluded(routingAccountID) {
|
||||
filteredExcluded++
|
||||
continue
|
||||
}
|
||||
account, ok := accountByID[routingAccountID]
|
||||
if !ok || !account.IsSchedulable() {
|
||||
if !ok {
|
||||
filteredMissing++
|
||||
} else {
|
||||
filteredUnsched++
|
||||
}
|
||||
continue
|
||||
}
|
||||
if !s.isAccountAllowedForPlatform(account, platform, useMixed) {
|
||||
filteredPlatform++
|
||||
continue
|
||||
}
|
||||
if !account.IsSchedulableForModel(requestedModel) {
|
||||
filteredModelScope++
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(account, requestedModel) {
|
||||
filteredModelMapping++
|
||||
continue
|
||||
}
|
||||
routingCandidates = append(routingCandidates, account)
|
||||
}
|
||||
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed candidates: group_id=%v model=%s routed=%d candidates=%d filtered(excluded=%d missing=%d unsched=%d platform=%d model_scope=%d model_mapping=%d)",
|
||||
derefGroupID(groupID), requestedModel, len(routingAccountIDs), len(routingCandidates),
|
||||
filteredExcluded, filteredMissing, filteredUnsched, filteredPlatform, filteredModelScope, filteredModelMapping)
|
||||
}
|
||||
|
||||
if len(routingCandidates) > 0 {
|
||||
// 1.5. 在路由账号范围内检查粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
stickyAccountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && stickyAccountID > 0 && containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
if stickyAccount.IsSchedulable() &&
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
stickyAccount.IsSchedulableForModel(requestedModel) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccount(stickyAccount, requestedModel)) {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 批量获取负载信息
|
||||
routingLoads := make([]AccountWithConcurrency, 0, len(routingCandidates))
|
||||
for _, acc := range routingCandidates {
|
||||
routingLoads = append(routingLoads, AccountWithConcurrency{
|
||||
ID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
})
|
||||
}
|
||||
routingLoadMap, _ := s.concurrencyService.GetAccountsLoadBatch(ctx, routingLoads)
|
||||
|
||||
// 3. 按负载感知排序
|
||||
type accountWithLoad struct {
|
||||
account *Account
|
||||
loadInfo *AccountLoadInfo
|
||||
}
|
||||
var routingAvailable []accountWithLoad
|
||||
for _, acc := range routingCandidates {
|
||||
loadInfo := routingLoadMap[acc.ID]
|
||||
if loadInfo == nil {
|
||||
loadInfo = &AccountLoadInfo{AccountID: acc.ID}
|
||||
}
|
||||
if loadInfo.LoadRate < 100 {
|
||||
routingAvailable = append(routingAvailable, accountWithLoad{account: acc, loadInfo: loadInfo})
|
||||
}
|
||||
}
|
||||
|
||||
if len(routingAvailable) > 0 {
|
||||
// 排序:优先级 > 负载率 > 最后使用时间
|
||||
sort.SliceStable(routingAvailable, func(i, j int) bool {
|
||||
a, b := routingAvailable[i], routingAvailable[j]
|
||||
if a.account.Priority != b.account.Priority {
|
||||
return a.account.Priority < b.account.Priority
|
||||
}
|
||||
if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
|
||||
return a.loadInfo.LoadRate < b.loadInfo.LoadRate
|
||||
}
|
||||
switch {
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
|
||||
return true
|
||||
case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
|
||||
return false
|
||||
default:
|
||||
return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
|
||||
}
|
||||
})
|
||||
|
||||
// 4. 尝试获取槽位
|
||||
for _, item := range routingAvailable {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: item.account,
|
||||
Acquired: true,
|
||||
ReleaseFunc: result.ReleaseFunc,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 5. 所有路由账号槽位满,返回等待计划(选择负载最低的)
|
||||
acc := routingAvailable[0].account
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), acc.ID)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
Account: acc,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: acc.ID,
|
||||
MaxConcurrency: acc.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
// 路由列表中的账号都不可用(负载率 >= 100),继续到 Layer 2 回退
|
||||
log.Printf("[ModelRouting] All routed accounts unavailable for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
}
|
||||
|
||||
// ============ Layer 1.5: 粘性会话(仅在无模型路由配置时生效) ============
|
||||
if len(routingAccountIDs) == 0 && sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||
// 粘性命中仅在当前可调度候选集中生效。
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
account, ok := accountByID[accountID]
|
||||
if ok && s.isAccountInGroup(account, groupID) &&
|
||||
s.isAccountAllowedForPlatform(account, platform, useMixed) &&
|
||||
@@ -687,6 +906,32 @@ func (s *GatewayService) resolveGroupByID(ctx context.Context, groupID int64) (*
|
||||
return group, nil
|
||||
}
|
||||
|
||||
func (s *GatewayService) routingAccountIDsForRequest(ctx context.Context, groupID *int64, requestedModel string, platform string) []int64 {
|
||||
if groupID == nil || requestedModel == "" || platform != PlatformAnthropic {
|
||||
return nil
|
||||
}
|
||||
group, err := s.resolveGroupByID(ctx, *groupID)
|
||||
if err != nil || group == nil {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] resolve group failed: group_id=%v model=%s platform=%s err=%v", derefGroupID(groupID), requestedModel, platform, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Preserve existing behavior: model routing only applies to anthropic groups.
|
||||
if group.Platform != PlatformAnthropic {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] skip: non-anthropic group platform: group_id=%d group_platform=%s model=%s", group.ID, group.Platform, requestedModel)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
ids := group.GetRoutingAccountIDs(requestedModel)
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] routing lookup: group_id=%d model=%s enabled=%v rules=%d matched_ids=%v",
|
||||
group.ID, requestedModel, group.ModelRoutingEnabled, len(group.ModelRouting), ids)
|
||||
}
|
||||
return ids
|
||||
}
|
||||
|
||||
func (s *GatewayService) resolveGatewayGroup(ctx context.Context, groupID *int64) (*Group, *int64, error) {
|
||||
if groupID == nil {
|
||||
return nil, nil, nil
|
||||
@@ -868,6 +1113,116 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
// ============ Model Routing (legacy path): apply before sticky session ============
|
||||
// When load-awareness is disabled (e.g. concurrency service not configured), we still honor model routing
|
||||
// so switching model can switch upstream account within the same sticky session.
|
||||
if len(routingAccountIDs) > 0 {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, platform, shortSessionHash(sessionHash), routingAccountIDs)
|
||||
}
|
||||
// 1) Sticky session only applies if the bound account is within the routing set.
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和平台匹配(确保粘性会话不会跨分组或跨平台)
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.Platform == platform && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Select an account from the routed candidates.
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
accountsLoaded = true
|
||||
|
||||
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
|
||||
for _, id := range routingAccountIDs {
|
||||
if id > 0 {
|
||||
routingSet[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, ok := routingSet[acc.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected != nil {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
@@ -886,13 +1241,16 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表(单平台)
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
if !accountsLoaded {
|
||||
forcePlatform, hasForcePlatform := ctx.Value(ctxkey.ForcePlatform).(string)
|
||||
if hasForcePlatform && forcePlatform == "" {
|
||||
hasForcePlatform = false
|
||||
}
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持)
|
||||
@@ -958,6 +1316,115 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
|
||||
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
// ============ Model Routing (legacy path): apply before sticky session ============
|
||||
if len(routingAccountIDs) > 0 {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed begin: group_id=%v model=%s platform=%s session=%s routed_ids=%v",
|
||||
derefGroupID(groupID), requestedModel, nativePlatform, shortSessionHash(sessionHash), routingAccountIDs)
|
||||
}
|
||||
// 1) Sticky session only applies if the bound account is within the routing set.
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
if err == nil && accountID > 0 && containsInt64(routingAccountIDs, accountID) {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
account, err := s.getSchedulableAccount(ctx, accountID)
|
||||
// 检查账号分组归属和有效性:原生平台直接匹配,antigravity 需要启用混合调度
|
||||
if err == nil && s.isAccountInGroup(account, groupID) && account.IsSchedulableForModel(requestedModel) && (requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
if err := s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL); err != nil {
|
||||
log.Printf("refresh session ttl failed: session=%s err=%v", sessionHash, err)
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Select an account from the routed candidates.
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
accountsLoaded = true
|
||||
|
||||
routingSet := make(map[int64]struct{}, len(routingAccountIDs))
|
||||
for _, id := range routingAccountIDs {
|
||||
if id > 0 {
|
||||
routingSet[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
acc := &accounts[i]
|
||||
if _, ok := routingSet[acc.ID]; !ok {
|
||||
continue
|
||||
}
|
||||
if _, excluded := excludedIDs[acc.ID]; excluded {
|
||||
continue
|
||||
}
|
||||
// Scheduler snapshots can be temporarily stale; re-check schedulability here to
|
||||
// avoid selecting accounts that were recently rate-limited/overloaded.
|
||||
if !acc.IsSchedulable() {
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
}
|
||||
if !acc.IsSchedulableForModel(requestedModel) {
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
if selected == nil {
|
||||
selected = acc
|
||||
continue
|
||||
}
|
||||
if acc.Priority < selected.Priority {
|
||||
selected = acc
|
||||
} else if acc.Priority == selected.Priority {
|
||||
switch {
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt != nil:
|
||||
selected = acc
|
||||
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
|
||||
// keep selected (never used is preferred)
|
||||
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
|
||||
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
|
||||
selected = acc
|
||||
}
|
||||
default:
|
||||
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
|
||||
selected = acc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if selected != nil {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
}
|
||||
if s.debugModelRoutingEnabled() {
|
||||
log.Printf("[ModelRoutingDebug] legacy mixed routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), selected.ID)
|
||||
}
|
||||
return selected, nil
|
||||
}
|
||||
log.Printf("[ModelRouting] No routed accounts available for model=%s, falling back to normal selection", requestedModel)
|
||||
}
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
@@ -979,9 +1446,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
|
||||
// 2. 获取可调度账号列表
|
||||
accounts, _, err := s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
if !accountsLoaded {
|
||||
var err error
|
||||
accounts, _, err = s.listSchedulableAccounts(ctx, groupID, nativePlatform, false)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query accounts failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Group struct {
|
||||
ID int64
|
||||
@@ -27,6 +30,12 @@ type Group struct {
|
||||
ClaudeCodeOnly bool
|
||||
FallbackGroupID *int64
|
||||
|
||||
// 模型路由配置
|
||||
// key: 模型匹配模式(支持 * 通配符,如 "claude-opus-*")
|
||||
// value: 优先账号 ID 列表
|
||||
ModelRouting map[string][]int64
|
||||
ModelRoutingEnabled bool
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
@@ -90,3 +99,41 @@ func IsGroupContextValid(group *Group) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetRoutingAccountIDs 根据请求模型获取路由账号 ID 列表
|
||||
// 返回匹配的优先账号 ID 列表,如果没有匹配规则则返回 nil
|
||||
func (g *Group) GetRoutingAccountIDs(requestedModel string) []int64 {
|
||||
if !g.ModelRoutingEnabled || len(g.ModelRouting) == 0 || requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
// 1. 精确匹配优先
|
||||
if accountIDs, ok := g.ModelRouting[requestedModel]; ok && len(accountIDs) > 0 {
|
||||
return accountIDs
|
||||
}
|
||||
|
||||
// 2. 通配符匹配(前缀匹配)
|
||||
for pattern, accountIDs := range g.ModelRouting {
|
||||
if matchModelPattern(pattern, requestedModel) && len(accountIDs) > 0 {
|
||||
return accountIDs
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// matchModelPattern 检查模型是否匹配模式
|
||||
// 支持 * 通配符,如 "claude-opus-*" 匹配 "claude-opus-4-20250514"
|
||||
func matchModelPattern(pattern, model string) bool {
|
||||
if pattern == model {
|
||||
return true
|
||||
}
|
||||
|
||||
// 处理 * 通配符(仅支持末尾通配符)
|
||||
if strings.HasSuffix(pattern, "*") {
|
||||
prefix := strings.TrimSuffix(pattern, "*")
|
||||
return strings.HasPrefix(model, prefix)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user