fix: use upstream versions of shared files and remove only Sora code
Restore gateway_service.go, setting_handler.go, routes/admin.go, dto/settings.go, group_repo.go, api_key_repo.go, wire_gen.go to upstream/main versions and surgically remove only Sora references. This preserves upstream-only features (RequireOauthOnly, RequirePrivacySet, GroupResolution, etc.) that were missing when using release branch versions.
This commit is contained in:
@@ -604,20 +604,20 @@ func userEntityToService(u *dbent.User) *service.User {
|
||||
return nil
|
||||
}
|
||||
return &service.User{
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
ID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Notes: u.Notes,
|
||||
PasswordHash: u.PasswordHash,
|
||||
Role: u.Role,
|
||||
Balance: u.Balance,
|
||||
Concurrency: u.Concurrency,
|
||||
Status: u.Status,
|
||||
TotpSecretEncrypted: u.TotpSecretEncrypted,
|
||||
TotpEnabled: u.TotpEnabled,
|
||||
TotpEnabledAt: u.TotpEnabledAt,
|
||||
CreatedAt: u.CreatedAt,
|
||||
UpdatedAt: u.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -651,6 +651,8 @@ func groupEntityToService(g *dbent.Group) *service.Group {
|
||||
SupportedModelScopes: g.SupportedModelScopes,
|
||||
SortOrder: g.SortOrder,
|
||||
AllowMessagesDispatch: g.AllowMessagesDispatch,
|
||||
RequireOAuthOnly: g.RequireOauthOnly,
|
||||
RequirePrivacySet: g.RequirePrivacySet,
|
||||
DefaultMappedModel: g.DefaultMappedModel,
|
||||
CreatedAt: g.CreatedAt,
|
||||
UpdatedAt: g.UpdatedAt,
|
||||
|
||||
@@ -56,6 +56,8 @@ func (r *groupRepository) Create(ctx context.Context, groupIn *service.Group) er
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 设置模型路由配置
|
||||
@@ -120,6 +122,8 @@ func (r *groupRepository) Update(ctx context.Context, groupIn *service.Group) er
|
||||
SetModelRoutingEnabled(groupIn.ModelRoutingEnabled).
|
||||
SetMcpXMLInject(groupIn.MCPXMLInject).
|
||||
SetAllowMessagesDispatch(groupIn.AllowMessagesDispatch).
|
||||
SetRequireOauthOnly(groupIn.RequireOAuthOnly).
|
||||
SetRequirePrivacySet(groupIn.RequirePrivacySet).
|
||||
SetDefaultMappedModel(groupIn.DefaultMappedModel)
|
||||
|
||||
// 显式处理可空字段:nil 需要 clear,非 nil 需要 set。
|
||||
|
||||
@@ -60,6 +60,7 @@ const (
|
||||
claudeMimicDebugInfoKey = "claude_mimic_debug_info"
|
||||
)
|
||||
|
||||
|
||||
// ForceCacheBillingContextKey 强制缓存计费上下文键
|
||||
// 用于粘性会话切换时,将 input_tokens 转为 cache_read_input_tokens 计费
|
||||
type forceCacheBillingKeyType struct{}
|
||||
@@ -503,6 +504,7 @@ type ForwardResult struct {
|
||||
// 图片生成计费字段(图片生成模型使用)
|
||||
ImageCount int // 生成的图片数量
|
||||
ImageSize string // 图片尺寸 "1K", "2K", "4K"
|
||||
|
||||
}
|
||||
|
||||
// UpstreamFailoverError indicates an upstream error that should trigger account failover.
|
||||
@@ -1330,11 +1332,6 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
ctx = s.withWindowCostPrefetch(ctx, accounts)
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
isExcluded := func(accountID int64) bool {
|
||||
if excludedIDs == nil {
|
||||
return false
|
||||
@@ -1343,6 +1340,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
return excluded
|
||||
}
|
||||
|
||||
// 提前构建 accountByID(供 Layer 1 和 Layer 1.5 使用)
|
||||
accountByID := make(map[int64]*Account, len(accounts))
|
||||
for i := range accounts {
|
||||
accountByID[accounts[i].ID] = &accounts[i]
|
||||
}
|
||||
|
||||
// 获取模型路由配置(仅 anthropic 平台)
|
||||
var routingAccountIDs []int64
|
||||
if group != nil && requestedModel != "" && group.Platform == PlatformAnthropic {
|
||||
@@ -1430,19 +1433,24 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||
// 粘性账号在路由列表中,优先使用
|
||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||
if s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||
var stickyCacheMissReason string
|
||||
|
||||
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
|
||||
|
||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
||||
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
|
||||
|
||||
if rpmPass { // 粘性会话窗口费用+RPM 检查
|
||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 继续到负载感知选择
|
||||
} else {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
@@ -1456,27 +1464,49 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
}
|
||||
}
|
||||
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
if stickyCacheMissReason == "" {
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||
stickyCacheMissReason = "session_limit"
|
||||
// 会话限制已满,继续到负载感知选择
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: stickyAccount,
|
||||
WaitPlan: &AccountWaitPlan{
|
||||
AccountID: stickyAccountID,
|
||||
MaxConcurrency: stickyAccount.Concurrency,
|
||||
Timeout: cfg.StickySessionWaitTimeout,
|
||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||
},
|
||||
}, nil
|
||||
stickyCacheMissReason = "wait_queue_full"
|
||||
}
|
||||
}
|
||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||
} else if !gatePass {
|
||||
stickyCacheMissReason = "gate_check"
|
||||
} else {
|
||||
stickyCacheMissReason = "rpm_red"
|
||||
}
|
||||
|
||||
// 记录粘性缓存未命中的结构化日志
|
||||
if stickyCacheMissReason != "" {
|
||||
baseRPM := stickyAccount.GetBaseRPM()
|
||||
var currentRPM int
|
||||
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
|
||||
currentRPM = count
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
|
||||
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
|
||||
}
|
||||
} else {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
|
||||
stickyAccountID, shortSessionHash(sessionHash))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1582,6 +1612,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
account, ok := accountByID[accountID]
|
||||
if ok {
|
||||
// 检查账户是否需要清理粘性会话绑定
|
||||
// Check if the account needs sticky session cleanup
|
||||
clearSticky := shouldClearStickySession(account, requestedModel)
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
@@ -1597,6 +1628,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
result, err := s.tryAcquireAccountSlot(ctx, accountID, account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
// 会话数量限制检查
|
||||
// Session count limit check
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||
} else {
|
||||
@@ -1611,8 +1643,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, accountID)
|
||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||
// Session count limit check (wait plan also requires session quota)
|
||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||
// 会话限制已满,继续到 Layer 2
|
||||
// Session limit full, continue to Layer 2
|
||||
} else {
|
||||
return &AccountSelectionResult{
|
||||
Account: account,
|
||||
@@ -2673,6 +2707,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
preferOAuth := platform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@@ -2696,7 +2736,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && account.Platform == platform && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if s.debugModelRoutingEnabled() {
|
||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] legacy routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), accountID)
|
||||
}
|
||||
@@ -2744,6 +2784,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2849,6 +2895,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||
continue
|
||||
}
|
||||
@@ -2915,6 +2967,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||
|
||||
// require_privacy_set: 获取分组信息
|
||||
var schedGroup *Group
|
||||
if groupID != nil && s.groupRepo != nil {
|
||||
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||
}
|
||||
|
||||
var accounts []Account
|
||||
accountsLoaded := false
|
||||
|
||||
@@ -2982,6 +3040,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@@ -3051,7 +3115,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if clearSticky {
|
||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||
}
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) && !s.isStickyAccountUpstreamRestricted(ctx, groupID, account, requestedModel) {
|
||||
if !clearSticky && s.isAccountInGroup(account, groupID) && (requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, account, requestedModel)) && s.isAccountSchedulableForModelSelection(ctx, account, requestedModel) && s.isAccountSchedulableForQuota(account) && s.isAccountSchedulableForWindowCost(ctx, account, true) && s.isAccountSchedulableForRPM(ctx, account, true) {
|
||||
if account.Platform == nativePlatform || (account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()) {
|
||||
return account, nil
|
||||
}
|
||||
@@ -3075,6 +3139,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||
|
||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||
var selected *Account
|
||||
for i := range accounts {
|
||||
@@ -3087,6 +3152,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
continue
|
||||
}
|
||||
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||
continue
|
||||
}
|
||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||
continue
|
||||
@@ -3254,8 +3325,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
return selectionFailureDiagnosis{Category: "excluded"}
|
||||
}
|
||||
if !s.isAccountSchedulableForSelection(acc) {
|
||||
detail := "generic_unschedulable"
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
|
||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||
}
|
||||
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
||||
return selectionFailureDiagnosis{
|
||||
@@ -3279,7 +3349,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
||||
return selectionFailureDiagnosis{Category: "eligible"}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取账号凭证
|
||||
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
||||
if acc == nil {
|
||||
return true
|
||||
@@ -3362,10 +3431,6 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
|
||||
_, ok := ResolveBedrockModelID(account, requestedModel)
|
||||
return ok
|
||||
}
|
||||
// OpenAI 透传模式:仅替换认证,允许所有模型
|
||||
if account.Platform == PlatformOpenAI && account.IsOpenAIPassthroughEnabled() {
|
||||
return true
|
||||
}
|
||||
// OAuth/SetupToken 账号使用 Anthropic 标准映射(短ID → 长ID)
|
||||
if account.Platform == PlatformAnthropic && account.Type != AccountTypeAPIKey {
|
||||
requestedModel = claude.NormalizeModelID(requestedModel)
|
||||
@@ -7083,7 +7148,6 @@ func (s *GatewayService) getUserGroupRateMultiplier(ctx context.Context, userID,
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ParsedRequest *ParsedRequest
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
@@ -7242,9 +7306,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||
cmd.ImageCount = usageLog.ImageCount
|
||||
if usageLog.MediaType != nil {
|
||||
cmd.MediaType = *usageLog.MediaType
|
||||
}
|
||||
if usageLog.ServiceTier != nil {
|
||||
cmd.ServiceTier = *usageLog.ServiceTier
|
||||
}
|
||||
@@ -7395,11 +7456,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
||||
|
||||
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||
type recordUsageOpts struct {
|
||||
// ParsedRequest(可选,仅 Claude 路径传入)
|
||||
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||
ParsedRequest *ParsedRequest
|
||||
|
||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||
// - MediaType 字段写入使用日志
|
||||
// - Claude Max 缓存计费策略
|
||||
EnableClaudePath bool
|
||||
|
||||
// 长上下文计费(仅 Gemini 路径需要)
|
||||
@@ -7424,7 +7485,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
APIKeyService: input.APIKeyService,
|
||||
ChannelUsageFields: input.ChannelUsageFields,
|
||||
}, &recordUsageOpts{
|
||||
ParsedRequest: input.ParsedRequest,
|
||||
EnableClaudePath: true,
|
||||
})
|
||||
}
|
||||
@@ -7490,6 +7550,7 @@ type recordUsageCoreInput struct {
|
||||
|
||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||
// opts 中的字段控制两者之间的差异行为:
|
||||
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||
result := input.Result
|
||||
@@ -7748,13 +7809,12 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
RateMultiplier: multiplier,
|
||||
AccountRateMultiplier: &accountRateMultiplier,
|
||||
BillingType: billingType,
|
||||
BillingMode: resolveBillingMode(opts, result, cost),
|
||||
BillingMode: resolveBillingMode(result, cost),
|
||||
Stream: result.Stream,
|
||||
DurationMs: &durationMs,
|
||||
FirstTokenMs: result.FirstTokenMs,
|
||||
ImageCount: result.ImageCount,
|
||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||
MediaType: resolveMediaType(opts, result),
|
||||
CacheTTLOverridden: cacheTTLOverridden,
|
||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||
@@ -7778,7 +7838,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
||||
}
|
||||
|
||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
||||
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
|
||||
var mode string
|
||||
switch {
|
||||
case cost != nil && cost.BillingMode != "":
|
||||
@@ -7791,10 +7851,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
|
||||
return &mode
|
||||
}
|
||||
|
||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||
if subscription != nil {
|
||||
return &subscription.ID
|
||||
@@ -7899,19 +7955,6 @@ func (s *GatewayService) needsUpstreamChannelRestrictionCheck(ctx context.Contex
|
||||
return ch.BillingModelSource == BillingModelSourceUpstream
|
||||
}
|
||||
|
||||
// isStickyAccountUpstreamRestricted 检查粘性会话命中的账号是否受 upstream 渠道限制。
|
||||
// 合并 needsUpstreamChannelRestrictionCheck + isUpstreamModelRestrictedByChannel 两步调用,
|
||||
// 供 sticky session 条件链使用,避免内联多个函数调用导致行过长。
|
||||
func (s *GatewayService) isStickyAccountUpstreamRestricted(ctx context.Context, groupID *int64, account *Account, requestedModel string) bool {
|
||||
if groupID == nil {
|
||||
return false
|
||||
}
|
||||
if !s.needsUpstreamChannelRestrictionCheck(ctx, groupID) {
|
||||
return false
|
||||
}
|
||||
return s.isUpstreamModelRestrictedByChannel(ctx, *groupID, account, requestedModel)
|
||||
}
|
||||
|
||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||
// 特点:不记录使用量、仅支持非流式响应
|
||||
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
|
||||
|
||||
Reference in New Issue
Block a user