fix: use upstream versions of shared files and remove only Sora code
Restore gateway_service.go, setting_handler.go, routes/admin.go, dto/settings.go, group_repo.go, api_key_repo.go, and wire_gen.go to their upstream/main versions, then surgically remove only the Sora references. This preserves upstream-only features (RequireOauthOnly, RequirePrivacySet, GroupResolution, etc.) that were missing when the release-branch versions of these files were used.
This commit is contained in:
@@ -604,20 +604,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &service.User{
|
return &service.User{
|
||||||
ID: u.ID,
|
ID: u.ID,
|
||||||
Email: u.Email,
|
Email: u.Email,
|
||||||
Username: u.Username,
|
Username: u.Username,
|
||||||
Notes: u.Notes,
|
Notes: u.Notes,
|
||||||
PasswordHash: u.PasswordHash,
|
PasswordHash: u.PasswordHash,
|
||||||
Role: u.Role,
|
Role: u.Role,
|
||||||
Balance: u.Balance,
|
Balance: u.Balance,
|
||||||
Concurrency: u.Concurrency,
|
Concurrency: u.Concurrency,
|
||||||
Status: u.Status,
|
Status: u.Status,
|
||||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||||
TotpEnabled: u.TotpEnabled,
|
TotpEnabled: u.TotpEnabled,
|
||||||
TotpEnabledAt: u.TotpEnabledAt,
|
TotpEnabledAt: u.TotpEnabledAt,
|
||||||
CreatedAt: u.CreatedAt,
|
CreatedAt: u.CreatedAt,
|
||||||
UpdatedAt: u.UpdatedAt,
|
UpdatedAt: u.UpdatedAt,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -651,6 +651,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
|||||||
SupportedModelScopes: g.SupportedModelScopes,
|
SupportedModelScopes: g.SupportedModelScopes,
|
||||||
SortOrder: g.SortOrder,
|
SortOrder: g.SortOrder,
|
||||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||||
|
RequireOAuthOnly: g.RequireOauthOnly,
|
||||||
|
RequirePrivacySet: g.RequirePrivacySet,
|
||||||
DefaultMappedModel: g.DefaultMappedModel,
|
DefaultMappedModel: g.DefaultMappedModel,
|
||||||
CreatedAt: g.CreatedAt,
|
CreatedAt: g.CreatedAt,
|
||||||
UpdatedAt: g.UpdatedAt,
|
UpdatedAt: g.UpdatedAt,
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
|
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||||
|
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||||
|
|
||||||
// 设置模型路由配置
|
// 设置模型路由配置
|
||||||
@@ -120,6 +122,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
|||||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||||
|
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||||
|
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||||
|
|
||||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||||
|
|||||||
@@ -60,6 +60,7 @@ const (
|
|||||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||||
type forceCacheBillingKeyType struct{}
|
type forceCacheBillingKeyType struct{}
|
||||||
@@ -503,6 +504,7 @@ type ForwardResult struct {
|
|||||||
// 图片生成计费字段(图片生成模型使用)
|
// 图片生成计费字段(图片生成模型使用)
|
||||||
ImageCount int // 生成的图片数量
|
ImageCount int // 生成的图片数量
|
||||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||||
@@ -1330,11 +1332,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
|
|
||||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
|
||||||
accountByID := make(map[int64]*Account, len(accounts))
|
|
||||||
for i := range accounts {
|
|
||||||
accountByID[accounts[i].ID] = &accounts[i]
|
|
||||||
}
|
|
||||||
isExcluded := func(accountID int64) bool {
|
isExcluded := func(accountID int64) bool {
|
||||||
if excludedIDs == nil {
|
if excludedIDs == nil {
|
||||||
return false
|
return false
|
||||||
@@ -1343,6 +1340,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
return excluded
|
return excluded
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||||
|
accountByID := make(map[int64]*Account, len(accounts))
|
||||||
|
for i := range accounts {
|
||||||
|
accountByID[accounts[i].ID] = &accounts[i]
|
||||||
|
}
|
||||||
|
|
||||||
// 获取模型路由配置(仅 anthropic 平台)
|
// 获取模型路由配置(仅 anthropic 平台)
|
||||||
var routingAccountIDs []int64
|
var routingAccountIDs []int64
|
||||||
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
||||||
@@ -1430,19 +1433,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||||
// 粘性账号在路由列表中,优先使用
|
// 粘性账号在路由列表中,优先使用
|
||||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||||
if s.isAccountSchedulableForSelection(stickyAccount) &&
|
var stickyCacheMissReason string
|
||||||
|
|
||||||
|
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
|
||||||
|
|
||||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
|
||||||
|
|
||||||
|
if rpmPass { // 粘性会话窗口费用+RPM 检查
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
// 会话数量限制检查
|
// 会话数量限制检查
|
||||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||||
result.ReleaseFunc() // 释放槽位
|
result.ReleaseFunc() // 释放槽位
|
||||||
|
stickyCacheMissReason = "session_limit"
|
||||||
// 继续到负载感知选择
|
// 继续到负载感知选择
|
||||||
} else {
|
} else {
|
||||||
if s.debugModelRoutingEnabled() {
|
if s.debugModelRoutingEnabled() {
|
||||||
@@ -1456,27 +1464,49 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
if stickyCacheMissReason == "" {
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||||
// 会话限制已满,继续到负载感知选择
|
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||||
|
stickyCacheMissReason = "session_limit"
|
||||||
|
// 会话限制已满,继续到负载感知选择
|
||||||
|
} else {
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: stickyAccount,
|
||||||
|
WaitPlan: &AccountWaitPlan{
|
||||||
|
AccountID: stickyAccountID,
|
||||||
|
MaxConcurrency: stickyAccount.Concurrency,
|
||||||
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
return &AccountSelectionResult{
|
stickyCacheMissReason = "wait_queue_full"
|
||||||
Account: stickyAccount,
|
|
||||||
WaitPlan: &AccountWaitPlan{
|
|
||||||
AccountID: stickyAccountID,
|
|
||||||
MaxConcurrency: stickyAccount.Concurrency,
|
|
||||||
Timeout: cfg.StickySessionWaitTimeout,
|
|
||||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
|
||||||
},
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||||
|
} else if !gatePass {
|
||||||
|
stickyCacheMissReason = "gate_check"
|
||||||
|
} else {
|
||||||
|
stickyCacheMissReason = "rpm_red"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录粘性缓存未命中的结构化日志
|
||||||
|
if stickyCacheMissReason != "" {
|
||||||
|
baseRPM := stickyAccount.GetBaseRPM()
|
||||||
|
var currentRPM int
|
||||||
|
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
|
||||||
|
currentRPM = count
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
|
||||||
|
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
|
||||||
|
stickyAccountID, shortSessionHash(sessionHash))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1582,6 +1612,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
account, ok := accountByID[accountID]
|
account, ok := accountByID[accountID]
|
||||||
if ok {
|
if ok {
|
||||||
// 检查账户是否需要清理粘性会话绑定
|
// 检查账户是否需要清理粘性会话绑定
|
||||||
|
// Check if the account needs sticky session cleanup
|
||||||
clearSticky := shouldClearStickySession(account, requestedModel)
|
clearSticky := shouldClearStickySession(account, requestedModel)
|
||||||
if clearSticky {
|
if clearSticky {
|
||||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
@@ -1597,6 +1628,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
// 会话数量限制检查
|
// 会话数量限制检查
|
||||||
|
// Session count limit check
|
||||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||||
} else {
|
} else {
|
||||||
@@ -1611,8 +1643,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||||
|
// Session count limit check (wait plan also requires session quota)
|
||||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||||
// 会话限制已满,继续到 Layer 2
|
// 会话限制已满,继续到 Layer 2
|
||||||
|
// Session limit full, continue to Layer 2
|
||||||
} else {
|
} else {
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
Account: account,
|
Account: account,
|
||||||
@@ -2673,6 +2707,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
preferOAuth := platform == PlatformGemini
|
preferOAuth := platform == PlatformGemini
|
||||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||||
|
|
||||||
|
// require_privacy_set: 获取分组信息
|
||||||
|
var schedGroup *Group
|
||||||
|
if groupID != nil && s.groupRepo != nil {
|
||||||
|
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||||
|
}
|
||||||
|
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
accountsLoaded := false
|
accountsLoaded := false
|
||||||
|
|
||||||
@@ -2696,7 +2736,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if clearSticky {
|
if clearSticky {
|
||||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
}
|
}
|
||||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||||
if s.debugModelRoutingEnabled() {
|
if s.debugModelRoutingEnabled() {
|
||||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||||
}
|
}
|
||||||
@@ -2744,6 +2784,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2849,6 +2895,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2915,6 +2967,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
preferOAuth := nativePlatform == PlatformGemini
|
preferOAuth := nativePlatform == PlatformGemini
|
||||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||||
|
|
||||||
|
// require_privacy_set: 获取分组信息
|
||||||
|
var schedGroup *Group
|
||||||
|
if groupID != nil && s.groupRepo != nil {
|
||||||
|
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||||
|
}
|
||||||
|
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
accountsLoaded := false
|
accountsLoaded := false
|
||||||
|
|
||||||
@@ -2982,6 +3040,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
@@ -3051,7 +3115,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if clearSticky {
|
if clearSticky {
|
||||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
}
|
}
|
||||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||||
return account, nil
|
return account, nil
|
||||||
}
|
}
|
||||||
@@ -3075,6 +3139,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
|
|
||||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||||
|
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
@@ -3087,6 +3152,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
@@ -3254,8 +3325,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
return selectionFailureDiagnosis{Category: "excluded"}
|
return selectionFailureDiagnosis{Category: "excluded"}
|
||||||
}
|
}
|
||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
detail := "generic_unschedulable"
|
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
|
|
||||||
}
|
}
|
||||||
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
||||||
return selectionFailureDiagnosis{
|
return selectionFailureDiagnosis{
|
||||||
@@ -3279,7 +3349,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
return selectionFailureDiagnosis{Category: "eligible"}
|
return selectionFailureDiagnosis{Category: "eligible"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取账号凭证
|
|
||||||
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
||||||
if acc == nil {
|
if acc == nil {
|
||||||
return true
|
return true
|
||||||
@@ -3362,10 +3431,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
|||||||
_, ok := ResolveBedrockModelID(account, requestedModel)
|
_, ok := ResolveBedrockModelID(account, requestedModel)
|
||||||
return ok
|
return ok
|
||||||
}
|
}
|
||||||
// OpenAI 透传模式:仅替换认证,允许所有模型
|
|
||||||
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||||
@@ -7083,7 +7148,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
|||||||
// RecordUsageInput 记录使用量的输入参数
|
// RecordUsageInput 记录使用量的输入参数
|
||||||
type RecordUsageInput struct {
|
type RecordUsageInput struct {
|
||||||
Result *ForwardResult
|
Result *ForwardResult
|
||||||
ParsedRequest *ParsedRequest
|
|
||||||
APIKey *APIKey
|
APIKey *APIKey
|
||||||
User *User
|
User *User
|
||||||
Account *Account
|
Account *Account
|
||||||
@@ -7242,9 +7306,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
|||||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||||
cmd.ImageCount = usageLog.ImageCount
|
cmd.ImageCount = usageLog.ImageCount
|
||||||
if usageLog.MediaType != nil {
|
|
||||||
cmd.MediaType = *usageLog.MediaType
|
|
||||||
}
|
|
||||||
if usageLog.ServiceTier != nil {
|
if usageLog.ServiceTier != nil {
|
||||||
cmd.ServiceTier = *usageLog.ServiceTier
|
cmd.ServiceTier = *usageLog.ServiceTier
|
||||||
}
|
}
|
||||||
@@ -7395,11 +7456,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
|||||||
|
|
||||||
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||||
type recordUsageOpts struct {
|
type recordUsageOpts struct {
|
||||||
// ParsedRequest(可选,仅 Claude 路径传入)
|
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||||
ParsedRequest *ParsedRequest
|
ParsedRequest *ParsedRequest
|
||||||
|
|
||||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||||
// - MediaType 字段写入使用日志
|
// - Claude Max 缓存计费策略
|
||||||
EnableClaudePath bool
|
EnableClaudePath bool
|
||||||
|
|
||||||
// 长上下文计费(仅 Gemini 路径需要)
|
// 长上下文计费(仅 Gemini 路径需要)
|
||||||
@@ -7424,7 +7485,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
APIKeyService: input.APIKeyService,
|
APIKeyService: input.APIKeyService,
|
||||||
ChannelUsageFields: input.ChannelUsageFields,
|
ChannelUsageFields: input.ChannelUsageFields,
|
||||||
}, &recordUsageOpts{
|
}, &recordUsageOpts{
|
||||||
ParsedRequest: input.ParsedRequest,
|
|
||||||
EnableClaudePath: true,
|
EnableClaudePath: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -7490,6 +7550,7 @@ type recordUsageCoreInput struct {
|
|||||||
|
|
||||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||||
// opts 中的字段控制两者之间的差异行为:
|
// opts 中的字段控制两者之间的差异行为:
|
||||||
|
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
@@ -7748,13 +7809,12 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
AccountRateMultiplier: &accountRateMultiplier,
|
AccountRateMultiplier: &accountRateMultiplier,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
BillingMode: resolveBillingMode(opts, result, cost),
|
BillingMode: resolveBillingMode(result, cost),
|
||||||
Stream: result.Stream,
|
Stream: result.Stream,
|
||||||
DurationMs: &durationMs,
|
DurationMs: &durationMs,
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
ImageCount: result.ImageCount,
|
ImageCount: result.ImageCount,
|
||||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||||
MediaType: resolveMediaType(opts, result),
|
|
||||||
CacheTTLOverridden: cacheTTLOverridden,
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||||
@@ -7778,7 +7838,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
|
||||||
var mode string
|
var mode string
|
||||||
switch {
|
switch {
|
||||||
case cost != nil && cost.BillingMode != "":
|
case cost != nil && cost.BillingMode != "":
|
||||||
@@ -7791,10 +7851,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
|
|||||||
return &mode
|
return &mode
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||||
if subscription != nil {
|
if subscription != nil {
|
||||||
return &subscription.ID
|
return &subscription.ID
|
||||||
@@ -7899,19 +7955,6 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
|
|||||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||||
}
|
}
|
||||||
|
|
||||||
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
|
|
||||||
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
|
|
||||||
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
|
|
||||||
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
|
|
||||||
if groupID == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||||
// 特点:不记录使用量、仅支持非流式响应
|
// 特点:不记录使用量、仅支持非流式响应
|
||||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||||
|
|||||||
Reference in New Issue
Block a user