fix: round-2 audit fixes — security, code quality, and UI improvements
Security (HIGH): - Normalize all Redis cache keys to lowercase (verifyCode, passwordReset) - Fix verify code TTL renewal on failed attempts: use remaining TTL via ExpiresAt field instead of resetting to full 15-minute window - Add 3 missing fields to diffSettings audit log (promo_code, invitation_code, custom_endpoints) Code quality (MEDIUM): - Extract filterVerifiedEmails shared helper (balance_notify_service.go) - Add Pricing array non-empty validation for channel pricing rules - Add platform token semantics comment in gateway_service.go - Complete validatePlanPatch test coverage (+10 test cases) - Replace string types with QuotaThresholdType/QuotaResetMode across frontend - Remove duplicate getPlatformTextColor/getRateBadgeClass in ChannelsView - Return EMAIL_NOT_FOUND error on RemoveNotifyEmail miss UI improvements: - Reorder cost tooltip: user billing above separator, account billing below - Add NaN guard to accountBilled function - Move timezone selector inline into reset-mode row (no longer standalone)
This commit is contained in:
@@ -357,6 +357,11 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(r.Pricing) == 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
|
||||||
|
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
|
||||||
|
return
|
||||||
|
}
|
||||||
rule := accountStatsPricingRuleRequestToService(r)
|
rule := accountStatsPricingRuleRequestToService(r)
|
||||||
rule.SortOrder = i
|
rule.SortOrder = i
|
||||||
statsRules = append(statsRules, rule)
|
statsRules = append(statsRules, rule)
|
||||||
@@ -420,6 +425,11 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
fmt.Sprintf("pricing rule #%d must have at least one group or account", i+1)))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(r.Pricing) == 0 {
|
||||||
|
response.ErrorFrom(c, infraerrors.BadRequest("PRICING_RULE_EMPTY_PRICING",
|
||||||
|
fmt.Sprintf("pricing rule #%d must have at least one pricing entry", i+1)))
|
||||||
|
return
|
||||||
|
}
|
||||||
rule := accountStatsPricingRuleRequestToService(r)
|
rule := accountStatsPricingRuleRequestToService(r)
|
||||||
rule.SortOrder = i
|
rule.SortOrder = i
|
||||||
statsRules = append(statsRules, rule)
|
statsRules = append(statsRules, rule)
|
||||||
|
|||||||
@@ -1138,6 +1138,12 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
if !equalStringSlice(before.RegistrationEmailSuffixWhitelist, after.RegistrationEmailSuffixWhitelist) {
|
||||||
changed = append(changed, "registration_email_suffix_whitelist")
|
changed = append(changed, "registration_email_suffix_whitelist")
|
||||||
}
|
}
|
||||||
|
if before.PromoCodeEnabled != after.PromoCodeEnabled {
|
||||||
|
changed = append(changed, "promo_code_enabled")
|
||||||
|
}
|
||||||
|
if before.InvitationCodeEnabled != after.InvitationCodeEnabled {
|
||||||
|
changed = append(changed, "invitation_code_enabled")
|
||||||
|
}
|
||||||
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
if before.PasswordResetEnabled != after.PasswordResetEnabled {
|
||||||
changed = append(changed, "password_reset_enabled")
|
changed = append(changed, "password_reset_enabled")
|
||||||
}
|
}
|
||||||
@@ -1348,6 +1354,9 @@ func diffSettings(before *service.SystemSettings, after *service.SystemSettings,
|
|||||||
if before.CustomMenuItems != after.CustomMenuItems {
|
if before.CustomMenuItems != after.CustomMenuItems {
|
||||||
changed = append(changed, "custom_menu_items")
|
changed = append(changed, "custom_menu_items")
|
||||||
}
|
}
|
||||||
|
if before.CustomEndpoints != after.CustomEndpoints {
|
||||||
|
changed = append(changed, "custom_endpoints")
|
||||||
|
}
|
||||||
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
|
if before.EnableFingerprintUnification != after.EnableFingerprintUnification {
|
||||||
changed = append(changed, "enable_fingerprint_unification")
|
changed = append(changed, "enable_fingerprint_unification")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,8 +20,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// verifyCodeKey generates the Redis key for email verification code.
|
// verifyCodeKey generates the Redis key for email verification code.
|
||||||
|
// Email is lowercased for case-insensitive consistency.
|
||||||
func verifyCodeKey(email string) string {
|
func verifyCodeKey(email string) string {
|
||||||
return verifyCodeKeyPrefix + email
|
return verifyCodeKeyPrefix + strings.ToLower(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// notifyVerifyKey generates the Redis key for notify email verification code.
|
// notifyVerifyKey generates the Redis key for notify email verification code.
|
||||||
@@ -33,12 +34,12 @@ func notifyVerifyKey(email string) string {
|
|||||||
|
|
||||||
// passwordResetKey generates the Redis key for password reset token.
|
// passwordResetKey generates the Redis key for password reset token.
|
||||||
func passwordResetKey(email string) string {
|
func passwordResetKey(email string) string {
|
||||||
return passwordResetKeyPrefix + email
|
return passwordResetKeyPrefix + strings.ToLower(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
// passwordResetSentAtKey generates the Redis key for password reset email sent timestamp.
|
||||||
func passwordResetSentAtKey(email string) string {
|
func passwordResetSentAtKey(email string) string {
|
||||||
return passwordResetSentAtKeyPrefix + email
|
return passwordResetSentAtKeyPrefix + strings.ToLower(email)
|
||||||
}
|
}
|
||||||
|
|
||||||
type emailCache struct {
|
type emailCache struct {
|
||||||
|
|||||||
@@ -283,6 +283,20 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return filterVerifiedEmails(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSiteName reads site name from settings with fallback.
|
||||||
|
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
|
||||||
|
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
||||||
|
if err != nil || name == "" {
|
||||||
|
return defaultSiteName
|
||||||
|
}
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterVerifiedEmails returns deduplicated, non-disabled, verified emails.
|
||||||
|
func filterVerifiedEmails(entries []NotifyEmailEntry) []string {
|
||||||
var recipients []string
|
var recipients []string
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
for _, entry := range entries {
|
for _, entry := range entries {
|
||||||
@@ -303,38 +317,10 @@ func (s *BalanceNotifyService) getAccountQuotaNotifyEmails(ctx context.Context)
|
|||||||
return recipients
|
return recipients
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSiteName reads site name from settings with fallback.
|
|
||||||
func (s *BalanceNotifyService) getSiteName(ctx context.Context) string {
|
|
||||||
name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName)
|
|
||||||
if err != nil || name == "" {
|
|
||||||
return defaultSiteName
|
|
||||||
}
|
|
||||||
return name
|
|
||||||
}
|
|
||||||
|
|
||||||
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
|
// collectBalanceNotifyRecipients returns verified, non-disabled email recipients.
|
||||||
// Only emails with verified=true and disabled=false are included.
|
// Only emails with verified=true and disabled=false are included.
|
||||||
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
|
func (s *BalanceNotifyService) collectBalanceNotifyRecipients(user *User) []string {
|
||||||
var recipients []string
|
return filterVerifiedEmails(user.BalanceNotifyExtraEmails)
|
||||||
seen := make(map[string]bool)
|
|
||||||
|
|
||||||
for _, entry := range user.BalanceNotifyExtraEmails {
|
|
||||||
if entry.Disabled || !entry.Verified {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
email := strings.TrimSpace(entry.Email)
|
|
||||||
if email == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
lower := strings.ToLower(email)
|
|
||||||
if seen[lower] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
seen[lower] = true
|
|
||||||
recipients = append(recipients, email)
|
|
||||||
}
|
|
||||||
|
|
||||||
return recipients
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendEmails sends an email to all recipients with shared timeout and error logging.
|
// sendEmails sends an email to all recipients with shared timeout and error logging.
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ type VerificationCodeData struct {
|
|||||||
Code string
|
Code string
|
||||||
Attempts int
|
Attempts int
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
|
ExpiresAt time.Time // absolute expiry; used to preserve remaining TTL when updating attempts
|
||||||
}
|
}
|
||||||
|
|
||||||
// PasswordResetTokenData represents password reset token data
|
// PasswordResetTokenData represents password reset token data
|
||||||
@@ -263,6 +264,7 @@ func (s *EmailService) SendVerifyCode(ctx context.Context, email, siteName strin
|
|||||||
Code: code,
|
Code: code,
|
||||||
Attempts: 0,
|
Attempts: 0,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(verifyCodeTTL),
|
||||||
}
|
}
|
||||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||||
return fmt.Errorf("save verify code: %w", err)
|
return fmt.Errorf("save verify code: %w", err)
|
||||||
@@ -295,7 +297,11 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
|||||||
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
// 验证码不匹配 (constant-time comparison to prevent timing attacks)
|
||||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||||
data.Attempts++
|
data.Attempts++
|
||||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
remaining := time.Until(data.ExpiresAt)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return ErrInvalidVerifyCode
|
||||||
|
}
|
||||||
|
if err := s.cache.SetVerificationCode(ctx, email, data, remaining); err != nil {
|
||||||
slog.Error("failed to update verification attempt count", "email", email, "error", err)
|
slog.Error("failed to update verification attempt count", "email", email, "error", err)
|
||||||
}
|
}
|
||||||
if data.Attempts >= maxVerifyCodeAttempts {
|
if data.Attempts >= maxVerifyCodeAttempts {
|
||||||
|
|||||||
@@ -1194,12 +1194,20 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
|
|||||||
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
// anthropic/gemini 分组支持混合调度(包含启用了 mixed_scheduling 的 antigravity 账户)
|
||||||
// 注意:强制平台模式不走混合调度
|
// 注意:强制平台模式不走混合调度
|
||||||
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
if (platform == PlatformAnthropic || platform == PlatformGemini) && !hasForcePlatform {
|
||||||
return s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
account, err := s.selectAccountWithMixedScheduling(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.hydrateSelectedAccount(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
// antigravity 分组、强制平台模式或无分组使用单平台选择
|
||||||
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
|
// 注意:强制平台模式也必须遵守分组限制,不再回退到全平台查询
|
||||||
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
account, err := s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return s.hydrateSelectedAccount(ctx, account)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
// SelectAccountWithLoadAwareness selects account with load-awareness and wait plan.
|
||||||
@@ -1275,11 +1283,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
localExcluded[account.ID] = struct{}{} // 排除此账号
|
localExcluded[account.ID] = struct{}{} // 排除此账号
|
||||||
continue // 重新选择
|
continue // 重新选择
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||||
Account: account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 对于等待计划的情况,也需要先检查会话限制
|
// 对于等待计划的情况,也需要先检查会话限制
|
||||||
@@ -1291,26 +1295,20 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||||
Account: account,
|
|
||||||
WaitPlan: &AccountWaitPlan{
|
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
MaxConcurrency: account.Concurrency,
|
MaxConcurrency: account.Concurrency,
|
||||||
Timeout: cfg.StickySessionWaitTimeout,
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
},
|
})
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||||
Account: account,
|
|
||||||
WaitPlan: &AccountWaitPlan{
|
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
MaxConcurrency: account.Concurrency,
|
MaxConcurrency: account.Concurrency,
|
||||||
Timeout: cfg.FallbackWaitTimeout,
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
},
|
})
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1433,36 +1431,39 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
if containsInt64(routingAccountIDs, stickyAccountID) && !isExcluded(stickyAccountID) {
|
||||||
// 粘性账号在路由列表中,优先使用
|
// 粘性账号在路由列表中,优先使用
|
||||||
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
if stickyAccount, ok := accountByID[stickyAccountID]; ok {
|
||||||
if s.isAccountSchedulableForSelection(stickyAccount) &&
|
var stickyCacheMissReason string
|
||||||
|
|
||||||
|
gatePass := s.isAccountSchedulableForSelection(stickyAccount) &&
|
||||||
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
s.isAccountAllowedForPlatform(stickyAccount, platform, useMixed) &&
|
||||||
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
(requestedModel == "" || s.isModelSupportedByAccountWithContext(ctx, stickyAccount, requestedModel)) &&
|
||||||
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
s.isAccountSchedulableForModelSelection(ctx, stickyAccount, requestedModel) &&
|
||||||
s.isAccountSchedulableForQuota(stickyAccount) &&
|
s.isAccountSchedulableForQuota(stickyAccount) &&
|
||||||
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true) &&
|
s.isAccountSchedulableForWindowCost(ctx, stickyAccount, true)
|
||||||
|
|
||||||
s.isAccountSchedulableForRPM(ctx, stickyAccount, true) { // 粘性会话窗口费用+RPM 检查
|
rpmPass := gatePass && s.isAccountSchedulableForRPM(ctx, stickyAccount, true)
|
||||||
|
|
||||||
|
if rpmPass { // 粘性会话窗口费用+RPM 检查
|
||||||
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, stickyAccount.Concurrency)
|
||||||
if err == nil && result.Acquired {
|
if err == nil && result.Acquired {
|
||||||
// 会话数量限制检查
|
// 会话数量限制检查
|
||||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||||
result.ReleaseFunc() // 释放槽位
|
result.ReleaseFunc() // 释放槽位
|
||||||
|
stickyCacheMissReason = "session_limit"
|
||||||
// 继续到负载感知选择
|
// 继续到负载感知选择
|
||||||
} else {
|
} else {
|
||||||
if s.debugModelRoutingEnabled() {
|
if s.debugModelRoutingEnabled() {
|
||||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed sticky hit: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), stickyAccountID)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, stickyAccount, true, result.ReleaseFunc, nil)
|
||||||
Account: stickyAccount,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if stickyCacheMissReason == "" {
|
||||||
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
|
||||||
if waitingCount < cfg.StickySessionMaxWaiting {
|
if waitingCount < cfg.StickySessionMaxWaiting {
|
||||||
// 会话数量限制检查(等待计划也需要占用会话配额)
|
// 会话数量限制检查(等待计划也需要占用会话配额)
|
||||||
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, stickyAccount, sessionHash) {
|
||||||
|
stickyCacheMissReason = "session_limit"
|
||||||
// 会话限制已满,继续到负载感知选择
|
// 会话限制已满,继续到负载感知选择
|
||||||
} else {
|
} else {
|
||||||
return &AccountSelectionResult{
|
return &AccountSelectionResult{
|
||||||
@@ -1475,11 +1476,31 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
},
|
},
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
stickyCacheMissReason = "wait_queue_full"
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
// 粘性账号槽位满且等待队列已满,继续使用负载感知选择
|
||||||
|
} else if !gatePass {
|
||||||
|
stickyCacheMissReason = "gate_check"
|
||||||
|
} else {
|
||||||
|
stickyCacheMissReason = "rpm_red"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录粘性缓存未命中的结构化日志
|
||||||
|
if stickyCacheMissReason != "" {
|
||||||
|
baseRPM := stickyAccount.GetBaseRPM()
|
||||||
|
var currentRPM int
|
||||||
|
if count, ok := rpmFromPrefetchContext(ctx, stickyAccount.ID); ok {
|
||||||
|
currentRPM = count
|
||||||
|
}
|
||||||
|
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=%s account_id=%d session=%s current_rpm=%d base_rpm=%d",
|
||||||
|
stickyCacheMissReason, stickyAccountID, shortSessionHash(sessionHash), currentRPM, baseRPM)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
_ = s.cache.DeleteSessionAccountID(ctx, derefGroupID(groupID), sessionHash)
|
||||||
|
logger.LegacyPrintf("service.gateway", "[StickyCacheMiss] reason=account_cleared account_id=%d session=%s current_rpm=0 base_rpm=0",
|
||||||
|
stickyAccountID, shortSessionHash(sessionHash))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1544,11 +1565,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if s.debugModelRoutingEnabled() {
|
if s.debugModelRoutingEnabled() {
|
||||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed select: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, item.account, true, result.ReleaseFunc, nil)
|
||||||
Account: item.account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1561,15 +1578,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if s.debugModelRoutingEnabled() {
|
if s.debugModelRoutingEnabled() {
|
||||||
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
logger.LegacyPrintf("service.gateway", "[ModelRoutingDebug] routed wait: group_id=%v model=%s session=%s account=%d", derefGroupID(groupID), requestedModel, shortSessionHash(sessionHash), item.account.ID)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, item.account, false, nil, &AccountWaitPlan{
|
||||||
Account: item.account,
|
|
||||||
WaitPlan: &AccountWaitPlan{
|
|
||||||
AccountID: item.account.ID,
|
AccountID: item.account.ID,
|
||||||
MaxConcurrency: item.account.Concurrency,
|
MaxConcurrency: item.account.Concurrency,
|
||||||
Timeout: cfg.StickySessionWaitTimeout,
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
},
|
})
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
// 所有路由账号会话限制都已满,继续到 Layer 2 回退
|
||||||
}
|
}
|
||||||
@@ -1603,11 +1617,10 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||||
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
result.ReleaseFunc() // 释放槽位,继续到 Layer 2
|
||||||
} else {
|
} else {
|
||||||
return &AccountSelectionResult{
|
if s.cache != nil {
|
||||||
Account: account,
|
_ = s.cache.RefreshSessionTTL(ctx, derefGroupID(groupID), sessionHash, stickySessionTTL)
|
||||||
Acquired: true,
|
}
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
return s.newSelectionResult(ctx, account, true, result.ReleaseFunc, nil)
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1617,15 +1630,12 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, account, sessionHash) {
|
||||||
// 会话限制已满,继续到 Layer 2
|
// 会话限制已满,继续到 Layer 2
|
||||||
} else {
|
} else {
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, account, false, nil, &AccountWaitPlan{
|
||||||
Account: account,
|
|
||||||
WaitPlan: &AccountWaitPlan{
|
|
||||||
AccountID: accountID,
|
AccountID: accountID,
|
||||||
MaxConcurrency: account.Concurrency,
|
MaxConcurrency: account.Concurrency,
|
||||||
Timeout: cfg.StickySessionWaitTimeout,
|
Timeout: cfg.StickySessionWaitTimeout,
|
||||||
MaxWaiting: cfg.StickySessionMaxWaiting,
|
MaxWaiting: cfg.StickySessionMaxWaiting,
|
||||||
},
|
})
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1684,7 +1694,9 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
|
|
||||||
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); ok {
|
if result, ok, legacyErr := s.tryAcquireByLegacyOrder(ctx, candidates, groupID, sessionHash, preferOAuth); legacyErr != nil {
|
||||||
|
return nil, legacyErr
|
||||||
|
} else if ok {
|
||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -1723,11 +1735,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, selected.account.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, selected.account, true, result.ReleaseFunc, nil)
|
||||||
Account: selected.account,
|
|
||||||
Acquired: true,
|
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1750,20 +1758,17 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
if !s.checkAndRegisterSession(ctx, acc, sessionHash) {
|
||||||
continue // 会话限制已满,尝试下一个账号
|
continue // 会话限制已满,尝试下一个账号
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
return s.newSelectionResult(ctx, acc, false, nil, &AccountWaitPlan{
|
||||||
Account: acc,
|
|
||||||
WaitPlan: &AccountWaitPlan{
|
|
||||||
AccountID: acc.ID,
|
AccountID: acc.ID,
|
||||||
MaxConcurrency: acc.Concurrency,
|
MaxConcurrency: acc.Concurrency,
|
||||||
Timeout: cfg.FallbackWaitTimeout,
|
Timeout: cfg.FallbackWaitTimeout,
|
||||||
MaxWaiting: cfg.FallbackMaxWaiting,
|
MaxWaiting: cfg.FallbackMaxWaiting,
|
||||||
},
|
})
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return nil, ErrNoAvailableAccounts
|
return nil, ErrNoAvailableAccounts
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
|
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, groupID *int64, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool, error) {
|
||||||
ordered := append([]*Account(nil), candidates...)
|
ordered := append([]*Account(nil), candidates...)
|
||||||
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)
|
||||||
|
|
||||||
@@ -1778,15 +1783,15 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
|||||||
if sessionHash != "" && s.cache != nil {
|
if sessionHash != "" && s.cache != nil {
|
||||||
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
_ = s.cache.SetSessionAccountID(ctx, derefGroupID(groupID), sessionHash, acc.ID, stickySessionTTL)
|
||||||
}
|
}
|
||||||
return &AccountSelectionResult{
|
selection, err := s.newSelectionResult(ctx, acc, true, result.ReleaseFunc, nil)
|
||||||
Account: acc,
|
if err != nil {
|
||||||
Acquired: true,
|
return nil, false, err
|
||||||
ReleaseFunc: result.ReleaseFunc,
|
}
|
||||||
}, true
|
return selection, true, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, false
|
return nil, false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
|
||||||
@@ -2401,6 +2406,33 @@ func (s *GatewayService) getSchedulableAccount(ctx context.Context, accountID in
|
|||||||
return s.accountRepo.GetByID(ctx, accountID)
|
return s.accountRepo.GetByID(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) hydrateSelectedAccount(ctx context.Context, account *Account) (*Account, error) {
|
||||||
|
if account == nil || s.schedulerSnapshot == nil {
|
||||||
|
return account, nil
|
||||||
|
}
|
||||||
|
hydrated, err := s.schedulerSnapshot.GetAccount(ctx, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if hydrated == nil {
|
||||||
|
return nil, fmt.Errorf("selected gateway account %d not found during hydration", account.ID)
|
||||||
|
}
|
||||||
|
return hydrated, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GatewayService) newSelectionResult(ctx context.Context, account *Account, acquired bool, release func(), waitPlan *AccountWaitPlan) (*AccountSelectionResult, error) {
|
||||||
|
hydrated, err := s.hydrateSelectedAccount(ctx, account)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &AccountSelectionResult{
|
||||||
|
Account: hydrated,
|
||||||
|
Acquired: acquired,
|
||||||
|
ReleaseFunc: release,
|
||||||
|
WaitPlan: waitPlan,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// filterByMinPriority 过滤出优先级最小的账号集合
|
// filterByMinPriority 过滤出优先级最小的账号集合
|
||||||
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
|
func filterByMinPriority(accounts []accountWithLoad) []accountWithLoad {
|
||||||
if len(accounts) == 0 {
|
if len(accounts) == 0 {
|
||||||
@@ -2676,6 +2708,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
preferOAuth := platform == PlatformGemini
|
preferOAuth := platform == PlatformGemini
|
||||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, platform)
|
||||||
|
|
||||||
|
// require_privacy_set: 获取分组信息
|
||||||
|
var schedGroup *Group
|
||||||
|
if groupID != nil && s.groupRepo != nil {
|
||||||
|
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||||
|
}
|
||||||
|
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
accountsLoaded := false
|
accountsLoaded := false
|
||||||
|
|
||||||
@@ -2747,6 +2785,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2852,6 +2896,12 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
if requestedModel != "" && !s.isModelSupportedByAccountWithContext(ctx, acc, requestedModel) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2918,6 +2968,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
preferOAuth := nativePlatform == PlatformGemini
|
preferOAuth := nativePlatform == PlatformGemini
|
||||||
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
routingAccountIDs := s.routingAccountIDsForRequest(ctx, groupID, requestedModel, nativePlatform)
|
||||||
|
|
||||||
|
// require_privacy_set: 获取分组信息
|
||||||
|
var schedGroup *Group
|
||||||
|
if groupID != nil && s.groupRepo != nil {
|
||||||
|
schedGroup, _ = s.groupRepo.GetByID(ctx, *groupID)
|
||||||
|
}
|
||||||
|
|
||||||
var accounts []Account
|
var accounts []Account
|
||||||
accountsLoaded := false
|
accountsLoaded := false
|
||||||
|
|
||||||
@@ -2985,6 +3041,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
@@ -3078,6 +3140,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
ctx = s.withRPMPrefetch(ctx, accounts)
|
ctx = s.withRPMPrefetch(ctx, accounts)
|
||||||
|
|
||||||
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
// 3. 按优先级+最久未用选择(考虑模型支持和混合调度)
|
||||||
|
// needsUpstreamCheck 仅在主选择循环中使用;粘性会话命中时跳过此检查。
|
||||||
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
needsUpstreamCheck := s.needsUpstreamChannelRestrictionCheck(ctx, groupID)
|
||||||
var selected *Account
|
var selected *Account
|
||||||
for i := range accounts {
|
for i := range accounts {
|
||||||
@@ -3090,6 +3153,12 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
|||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
// require_privacy_set: 跳过 privacy 未设置的账号并标记异常
|
||||||
|
if schedGroup != nil && schedGroup.RequirePrivacySet && !acc.IsPrivacySet() {
|
||||||
|
_ = s.accountRepo.SetError(ctx, acc.ID,
|
||||||
|
fmt.Sprintf("Privacy not set, required by group [%s]", schedGroup.Name))
|
||||||
|
continue
|
||||||
|
}
|
||||||
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
// 过滤:原生平台直接通过,antigravity 需要启用混合调度
|
||||||
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
if acc.Platform == PlatformAntigravity && !acc.IsMixedSchedulingEnabled() {
|
||||||
continue
|
continue
|
||||||
@@ -3257,8 +3326,7 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
return selectionFailureDiagnosis{Category: "excluded"}
|
return selectionFailureDiagnosis{Category: "excluded"}
|
||||||
}
|
}
|
||||||
if !s.isAccountSchedulableForSelection(acc) {
|
if !s.isAccountSchedulableForSelection(acc) {
|
||||||
detail := "generic_unschedulable"
|
return selectionFailureDiagnosis{Category: "unschedulable", Detail: "generic_unschedulable"}
|
||||||
return selectionFailureDiagnosis{Category: "unschedulable", Detail: detail}
|
|
||||||
}
|
}
|
||||||
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
if isPlatformFilteredForSelection(acc, platform, allowMixedScheduling) {
|
||||||
return selectionFailureDiagnosis{
|
return selectionFailureDiagnosis{
|
||||||
@@ -3282,7 +3350,6 @@ func (s *GatewayService) diagnoseSelectionFailure(
|
|||||||
return selectionFailureDiagnosis{Category: "eligible"}
|
return selectionFailureDiagnosis{Category: "eligible"}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取账号凭证
|
|
||||||
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
func isPlatformFilteredForSelection(acc *Account, platform string, allowMixedScheduling bool) bool {
|
||||||
if acc == nil {
|
if acc == nil {
|
||||||
return true
|
return true
|
||||||
@@ -3653,6 +3720,86 @@ func injectClaudeCodePrompt(body []byte, system any) []byte {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rewriteSystemForNonClaudeCode 将非 Claude Code 客户端的 system prompt 迁移至 messages,
|
||||||
|
// system 字段仅保留 Claude Code 标识提示词。
|
||||||
|
// Anthropic 基于 system 参数内容检测第三方应用,仅前置追加 Claude Code 提示词
|
||||||
|
// 无法通过检测,因为后续内容仍为非 Claude Code 格式。
|
||||||
|
// 策略:将原始 system prompt 提取并注入为 user/assistant 消息对,system 仅保留 Claude Code 标识。
|
||||||
|
func rewriteSystemForNonClaudeCode(body []byte, system any) []byte {
|
||||||
|
system = normalizeSystemParam(system)
|
||||||
|
|
||||||
|
// 1. 提取原始 system prompt 文本
|
||||||
|
var originalSystemText string
|
||||||
|
switch v := system.(type) {
|
||||||
|
case string:
|
||||||
|
originalSystemText = strings.TrimSpace(v)
|
||||||
|
case []any:
|
||||||
|
var parts []string
|
||||||
|
for _, item := range v {
|
||||||
|
if m, ok := item.(map[string]any); ok {
|
||||||
|
if text, ok := m["text"].(string); ok && strings.TrimSpace(text) != "" {
|
||||||
|
parts = append(parts, text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
originalSystemText = strings.Join(parts, "\n\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 将 system 替换为 Claude Code 标准提示词(array 格式,与真实 Claude Code 一致)
|
||||||
|
// 真实 Claude Code 始终以 [{type: "text", text: "...", cache_control: {type: "ephemeral"}}] 发送 system。
|
||||||
|
// 使用 string 格式会被 Anthropic 检测为第三方应用。
|
||||||
|
claudeCodeSystemBlock := []map[string]any{
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": claudeCodeSystemPrompt,
|
||||||
|
"cache_control": map[string]string{"type": "ephemeral"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
out, ok := setJSONValueBytes(body, "system", claudeCodeSystemBlock)
|
||||||
|
if !ok {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Warning: failed to set Claude Code system prompt")
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 将原始 system prompt 作为 user/assistant 消息对注入到 messages 开头
|
||||||
|
// 模型仍通过 messages 接收完整指令,保留客户端功能
|
||||||
|
ccPromptTrimmed := strings.TrimSpace(claudeCodeSystemPrompt)
|
||||||
|
if originalSystemText != "" && originalSystemText != ccPromptTrimmed && !hasClaudeCodePrefix(originalSystemText) {
|
||||||
|
instrMsg, err1 := json.Marshal(map[string]any{
|
||||||
|
"role": "user",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{"type": "text", "text": "[System Instructions]\n" + originalSystemText},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
ackMsg, err2 := json.Marshal(map[string]any{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": []map[string]any{
|
||||||
|
{"type": "text", "text": "Understood. I will follow these instructions."},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err1 != nil || err2 != nil {
|
||||||
|
logger.LegacyPrintf("service.gateway", "Warning: failed to marshal system-to-messages injection")
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重建 messages 数组:[instruction, ack, ...originalMessages]
|
||||||
|
items := [][]byte{instrMsg, ackMsg}
|
||||||
|
messagesResult := gjson.GetBytes(out, "messages")
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messagesResult.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
items = append(items, []byte(msg.Raw))
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if next, setOk := setJSONRawBytes(out, "messages", buildJSONArrayRaw(items)); setOk {
|
||||||
|
out = next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
type cacheControlPath struct {
|
type cacheControlPath struct {
|
||||||
path string
|
path string
|
||||||
log string
|
log string
|
||||||
@@ -3819,7 +3966,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
// Beta policy: evaluate once; block check + cache filter set for buildUpstreamRequest.
|
||||||
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
// Always overwrite the cache to prevent stale values from a previous retry with a different account.
|
||||||
if account.Platform == PlatformAnthropic && c != nil {
|
if account.Platform == PlatformAnthropic && c != nil {
|
||||||
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account)
|
policy := s.evaluateBetaPolicy(ctx, c.GetHeader("anthropic-beta"), account, parsed.Model)
|
||||||
if policy.blockErr != nil {
|
if policy.blockErr != nil {
|
||||||
return nil, policy.blockErr
|
return nil, policy.blockErr
|
||||||
}
|
}
|
||||||
@@ -3849,19 +3996,24 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
|||||||
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
shouldMimicClaudeCode := account.IsOAuth() && !isClaudeCode
|
||||||
|
|
||||||
if shouldMimicClaudeCode {
|
if shouldMimicClaudeCode {
|
||||||
// 智能注入 Claude Code 系统提示词(仅 OAuth/SetupToken 账号需要)
|
// 非 Claude Code 客户端:将 system 替换为 Claude Code 标识,原始 system 迁移至 messages
|
||||||
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
// 条件:1) OAuth/SetupToken 账号 2) 不是 Claude Code 客户端 3) 不是 Haiku 模型 4) system 中还没有 Claude Code 提示词
|
||||||
|
systemRewritten := false
|
||||||
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
if !strings.Contains(strings.ToLower(reqModel), "haiku") &&
|
||||||
!systemIncludesClaudeCodePrompt(parsed.System) {
|
!systemIncludesClaudeCodePrompt(parsed.System) {
|
||||||
body = injectClaudeCodePrompt(body, parsed.System)
|
body = rewriteSystemForNonClaudeCode(body, parsed.System)
|
||||||
|
systemRewritten = true
|
||||||
}
|
}
|
||||||
|
|
||||||
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: true}
|
// system 被重写时保留 CC prompt 的 cache_control: ephemeral(匹配真实 Claude Code 行为);
|
||||||
|
// 未重写时(haiku / 已含 CC 前缀)剥离客户端 cache_control,与原有行为一致。
|
||||||
|
// 两种情况下 enforceCacheControlLimit 都会兜底处理上限。
|
||||||
|
normalizeOpts := claudeOAuthNormalizeOptions{stripSystemCacheControl: !systemRewritten}
|
||||||
if s.identityService != nil {
|
if s.identityService != nil {
|
||||||
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
fp, err := s.identityService.GetOrCreateFingerprint(ctx, account.ID, c.Request.Header)
|
||||||
if err == nil && fp != nil {
|
if err == nil && fp != nil {
|
||||||
// metadata 透传开启时跳过 metadata 注入
|
// metadata 透传开启时跳过 metadata 注入
|
||||||
_, mimicMPT := s.settingService.GetGatewayForwardingSettings(ctx)
|
_, mimicMPT, _ := s.settingService.GetGatewayForwardingSettings(ctx)
|
||||||
if !mimicMPT {
|
if !mimicMPT {
|
||||||
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
|
if metadataUserID := s.buildOAuthMetadataUserID(parsed, account, fp); metadataUserID != "" {
|
||||||
normalizeOpts.injectMetadata = true
|
normalizeOpts.injectMetadata = true
|
||||||
@@ -5407,9 +5559,9 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
|
|
||||||
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
|
// OAuth账号:应用统一指纹和metadata重写(受设置开关控制)
|
||||||
var fingerprint *Fingerprint
|
var fingerprint *Fingerprint
|
||||||
enableFP, enableMPT := true, false
|
enableFP, enableMPT, enableCCH := true, false, false
|
||||||
if s.settingService != nil {
|
if s.settingService != nil {
|
||||||
enableFP, enableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
|
enableFP, enableMPT, enableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||||
}
|
}
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && s.identityService != nil {
|
||||||
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
// 1. 获取或创建指纹(包含随机生成的ClientID)
|
||||||
@@ -5436,6 +5588,15 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
|
||||||
|
if fingerprint != nil {
|
||||||
|
body = syncBillingHeaderVersion(body, fingerprint.UserAgent)
|
||||||
|
}
|
||||||
|
// CCH 签名:将 cch=00000 占位符替换为 xxHash64 签名(需在所有 body 修改之后)
|
||||||
|
if enableCCH {
|
||||||
|
body = signBillingHeaderCCH(body)
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -5476,9 +5637,8 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
|
// Build effective drop set: merge static defaults with dynamic beta policy filter rules
|
||||||
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account)
|
policyFilterSet := s.getBetaPolicyFilterSet(ctx, c, account, modelID)
|
||||||
effectiveDropSet := mergeDropSets(policyFilterSet)
|
effectiveDropSet := mergeDropSets(policyFilterSet)
|
||||||
effectiveDropWithClaudeCodeSet := mergeDropSets(policyFilterSet, claude.BetaClaudeCode)
|
|
||||||
|
|
||||||
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
// 处理 anthropic-beta header(OAuth 账号需要包含 oauth beta)
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
@@ -5489,11 +5649,16 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
|||||||
applyClaudeCodeMimicHeaders(req, reqStream)
|
applyClaudeCodeMimicHeaders(req, reqStream)
|
||||||
|
|
||||||
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
incomingBeta := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
// Match real Claude CLI traffic (per mitmproxy reports):
|
// Claude Code OAuth credentials are scoped to Claude Code.
|
||||||
// messages requests typically use only oauth + interleaved-thinking.
|
// Non-haiku models MUST include claude-code beta for Anthropic to recognize
|
||||||
// Also drop claude-code beta if a downstream client added it.
|
// this as a legitimate Claude Code request; without it, the request is
|
||||||
|
// rejected as third-party ("out of extra usage").
|
||||||
|
// Haiku models are exempt from third-party detection and don't need it.
|
||||||
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
requiredBetas := []string{claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||||
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropWithClaudeCodeSet))
|
if !strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||||
|
requiredBetas = []string{claude.BetaClaudeCode, claude.BetaOAuth, claude.BetaInterleavedThinking}
|
||||||
|
}
|
||||||
|
setHeaderRaw(req.Header, "anthropic-beta", mergeAnthropicBetaDropping(requiredBetas, incomingBeta, effectiveDropSet))
|
||||||
} else {
|
} else {
|
||||||
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
// Claude Code 客户端:尽量透传原始 header,仅补齐 oauth beta
|
||||||
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
|
clientBetaHeader := getHeaderRaw(req.Header, "anthropic-beta")
|
||||||
@@ -5716,7 +5881,7 @@ type betaPolicyResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
|
// evaluateBetaPolicy loads settings once and evaluates all rules against the given request.
|
||||||
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account) betaPolicyResult {
|
func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader string, account *Account, model string) betaPolicyResult {
|
||||||
if s.settingService == nil {
|
if s.settingService == nil {
|
||||||
return betaPolicyResult{}
|
return betaPolicyResult{}
|
||||||
}
|
}
|
||||||
@@ -5731,10 +5896,11 @@ func (s *GatewayService) evaluateBetaPolicy(ctx context.Context, betaHeader stri
|
|||||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
switch rule.Action {
|
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
|
||||||
|
switch effectiveAction {
|
||||||
case BetaPolicyActionBlock:
|
case BetaPolicyActionBlock:
|
||||||
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
|
if result.blockErr == nil && betaHeader != "" && containsBetaToken(betaHeader, rule.BetaToken) {
|
||||||
msg := rule.ErrorMessage
|
msg := effectiveErrMsg
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||||
}
|
}
|
||||||
@@ -5776,7 +5942,7 @@ const betaPolicyFilterSetKey = "betaPolicyFilterSet"
|
|||||||
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
|
// In the /v1/messages path, Forward() evaluates the policy first and caches the result;
|
||||||
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
|
// buildUpstreamRequest reuses it (zero extra DB calls). In the count_tokens path, this
|
||||||
// evaluates on demand (one DB call).
|
// evaluates on demand (one DB call).
|
||||||
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account) map[string]struct{} {
|
func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Context, account *Account, model string) map[string]struct{} {
|
||||||
if c != nil {
|
if c != nil {
|
||||||
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
|
if v, ok := c.Get(betaPolicyFilterSetKey); ok {
|
||||||
if fs, ok := v.(map[string]struct{}); ok {
|
if fs, ok := v.(map[string]struct{}); ok {
|
||||||
@@ -5784,7 +5950,7 @@ func (s *GatewayService) getBetaPolicyFilterSet(ctx context.Context, c *gin.Cont
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return s.evaluateBetaPolicy(ctx, "", account).filterSet
|
return s.evaluateBetaPolicy(ctx, "", account, model).filterSet
|
||||||
}
|
}
|
||||||
|
|
||||||
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
// betaPolicyScopeMatches checks whether a rule's scope matches the current account type.
|
||||||
@@ -5803,6 +5969,33 @@ func betaPolicyScopeMatches(scope string, isOAuth bool, isBedrock bool) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// matchModelWhitelist checks if a model matches any pattern in the whitelist.
|
||||||
|
// Reuses matchModelPattern from group.go which supports exact and wildcard prefix matching.
|
||||||
|
func matchModelWhitelist(model string, whitelist []string) bool {
|
||||||
|
for _, pattern := range whitelist {
|
||||||
|
if matchModelPattern(pattern, model) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveRuleAction determines the effective action and error message for a rule given the request model.
|
||||||
|
// When ModelWhitelist is empty, the rule's primary Action/ErrorMessage applies unconditionally.
|
||||||
|
// When non-empty, Action applies to matching models; FallbackAction/FallbackErrorMessage applies to others.
|
||||||
|
func resolveRuleAction(rule BetaPolicyRule, model string) (action, errorMessage string) {
|
||||||
|
if len(rule.ModelWhitelist) == 0 {
|
||||||
|
return rule.Action, rule.ErrorMessage
|
||||||
|
}
|
||||||
|
if matchModelWhitelist(model, rule.ModelWhitelist) {
|
||||||
|
return rule.Action, rule.ErrorMessage
|
||||||
|
}
|
||||||
|
if rule.FallbackAction != "" {
|
||||||
|
return rule.FallbackAction, rule.FallbackErrorMessage
|
||||||
|
}
|
||||||
|
return BetaPolicyActionPass, "" // default fallback: pass (fail-open)
|
||||||
|
}
|
||||||
|
|
||||||
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
// droppedBetaSet returns claude.DroppedBetas as a set, with optional extra tokens.
|
||||||
func droppedBetaSet(extra ...string) map[string]struct{} {
|
func droppedBetaSet(extra ...string) map[string]struct{} {
|
||||||
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
|
m := make(map[string]struct{}, len(defaultDroppedBetasSet)+len(extra))
|
||||||
@@ -5849,7 +6042,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
|||||||
modelID string,
|
modelID string,
|
||||||
) ([]string, error) {
|
) ([]string, error) {
|
||||||
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
|
// 1. 对原始 header 中的 beta token 做 block 检查(快速失败)
|
||||||
policy := s.evaluateBetaPolicy(ctx, betaHeader, account)
|
policy := s.evaluateBetaPolicy(ctx, betaHeader, account, modelID)
|
||||||
if policy.blockErr != nil {
|
if policy.blockErr != nil {
|
||||||
return nil, policy.blockErr
|
return nil, policy.blockErr
|
||||||
}
|
}
|
||||||
@@ -5861,7 +6054,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
|||||||
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
// 例如:管理员 block 了 interleaved-thinking,客户端不在 header 中带该 token,
|
||||||
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
|
// 但请求体中包含 thinking 字段 → autoInjectBedrockBetaTokens 会自动补齐 →
|
||||||
// 如果不做此检查,block 规则会被绕过。
|
// 如果不做此检查,block 规则会被绕过。
|
||||||
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account); blockErr != nil {
|
if blockErr := s.checkBetaPolicyBlockForTokens(ctx, betaTokens, account, modelID); blockErr != nil {
|
||||||
return nil, blockErr
|
return nil, blockErr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5870,7 +6063,7 @@ func (s *GatewayService) resolveBedrockBetaTokensForRequest(
|
|||||||
|
|
||||||
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
|
// checkBetaPolicyBlockForTokens 检查 token 列表中是否有被管理员 block 规则命中的 token。
|
||||||
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
|
// 用于补充 evaluateBetaPolicy 对 header 的检查,覆盖 body 自动注入的 token。
|
||||||
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account) *BetaBlockedError {
|
func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, tokens []string, account *Account, model string) *BetaBlockedError {
|
||||||
if s.settingService == nil || len(tokens) == 0 {
|
if s.settingService == nil || len(tokens) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -5882,14 +6075,15 @@ func (s *GatewayService) checkBetaPolicyBlockForTokens(ctx context.Context, toke
|
|||||||
isBedrock := account.IsBedrock()
|
isBedrock := account.IsBedrock()
|
||||||
tokenSet := buildBetaTokenSet(tokens)
|
tokenSet := buildBetaTokenSet(tokens)
|
||||||
for _, rule := range settings.Rules {
|
for _, rule := range settings.Rules {
|
||||||
if rule.Action != BetaPolicyActionBlock {
|
effectiveAction, effectiveErrMsg := resolveRuleAction(rule, model)
|
||||||
|
if effectiveAction != BetaPolicyActionBlock {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
if !betaPolicyScopeMatches(rule.Scope, isOAuth, isBedrock) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if _, present := tokenSet[rule.BetaToken]; present {
|
if _, present := tokenSet[rule.BetaToken]; present {
|
||||||
msg := rule.ErrorMessage
|
msg := effectiveErrMsg
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
msg = "beta feature " + rule.BetaToken + " is not allowed"
|
||||||
}
|
}
|
||||||
@@ -7146,49 +7340,41 @@ func (p *postUsageBillingParams) shouldUpdateAccountQuota() bool {
|
|||||||
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
|
return p.Cost.TotalCost > 0 && p.Account.IsAPIKeyOrBedrock() && p.Account.HasAnyQuotaLimit()
|
||||||
}
|
}
|
||||||
|
|
||||||
// postUsageBilling 统一处理使用量记录后的扣费逻辑:
|
// postUsageBilling is the legacy fallback billing path used when the unified
|
||||||
// - 订阅/余额扣费
|
// billing repo is unavailable (nil). Production uses applyUsageBilling → repo.Apply
|
||||||
// - API Key 配额更新
|
// for atomic billing. This path only runs in tests or degraded mode.
|
||||||
// - API Key 限速用量更新
|
|
||||||
// - 账号配额用量更新(账号口径:TotalCost × 账号计费倍率)
|
|
||||||
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *billingDeps) {
|
||||||
billingCtx, cancel := detachedBillingContext(ctx)
|
billingCtx, cancel := detachedBillingContext(ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
cost := p.Cost
|
cost := p.Cost
|
||||||
|
|
||||||
// 1. 订阅 / 余额扣费
|
|
||||||
if p.IsSubscriptionBill {
|
if p.IsSubscriptionBill {
|
||||||
if cost.TotalCost > 0 {
|
if cost.TotalCost > 0 {
|
||||||
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
if err := deps.userSubRepo.IncrementUsage(billingCtx, p.Subscription.ID, cost.TotalCost); err != nil {
|
||||||
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
slog.Error("increment subscription usage failed", "subscription_id", p.Subscription.ID, "error", err)
|
||||||
}
|
}
|
||||||
deps.billingCacheService.QueueUpdateSubscriptionUsage(p.User.ID, *p.APIKey.GroupID, cost.TotalCost)
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if cost.ActualCost > 0 {
|
if cost.ActualCost > 0 {
|
||||||
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
|
if err := deps.userRepo.DeductBalance(billingCtx, p.User.ID, cost.ActualCost); err != nil {
|
||||||
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
slog.Error("deduct balance failed", "user_id", p.User.ID, "error", err)
|
||||||
}
|
}
|
||||||
deps.billingCacheService.QueueDeductBalance(p.User.ID, cost.ActualCost)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. API Key 配额
|
|
||||||
if p.shouldDeductAPIKeyQuota() {
|
if p.shouldDeductAPIKeyQuota() {
|
||||||
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
if err := p.APIKeyService.UpdateQuotaUsed(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||||
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
slog.Error("update api key quota failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. API Key 限速用量
|
|
||||||
if p.shouldUpdateRateLimits() {
|
if p.shouldUpdateRateLimits() {
|
||||||
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
if err := p.APIKeyService.UpdateRateLimitUsage(billingCtx, p.APIKey.ID, cost.ActualCost); err != nil {
|
||||||
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
slog.Error("update api key rate limit usage failed", "api_key_id", p.APIKey.ID, "error", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 账号配额用量(账号口径:TotalCost × 账号计费倍率)
|
|
||||||
if p.shouldUpdateAccountQuota() {
|
if p.shouldUpdateAccountQuota() {
|
||||||
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
accountCost := cost.TotalCost * p.AccountRateMultiplier
|
||||||
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
if err := deps.accountRepo.IncrementQuotaUsed(billingCtx, p.Account.ID, accountCost); err != nil {
|
||||||
@@ -7196,7 +7382,10 @@ func postUsageBilling(ctx context.Context, p *postUsageBillingParams, deps *bill
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
finalizePostUsageBilling(p, deps)
|
// NOTE: finalizePostUsageBilling is NOT called here to avoid double-queuing
|
||||||
|
// cache updates. The legacy path does DB writes directly; the finalize path
|
||||||
|
// does cache queue + notifications. Notifications are dispatched separately
|
||||||
|
// by the caller after recording the usage log.
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
func resolveUsageBillingRequestID(ctx context.Context, upstreamRequestID string) string {
|
||||||
@@ -7250,9 +7439,6 @@ func buildUsageBillingCommand(requestID string, usageLog *UsageLog, p *postUsage
|
|||||||
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
cmd.CacheCreationTokens = usageLog.CacheCreationTokens
|
||||||
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
cmd.CacheReadTokens = usageLog.CacheReadTokens
|
||||||
cmd.ImageCount = usageLog.ImageCount
|
cmd.ImageCount = usageLog.ImageCount
|
||||||
if usageLog.MediaType != nil {
|
|
||||||
cmd.MediaType = *usageLog.MediaType
|
|
||||||
}
|
|
||||||
if usageLog.ServiceTier != nil {
|
if usageLog.ServiceTier != nil {
|
||||||
cmd.ServiceTier = *usageLog.ServiceTier
|
cmd.ServiceTier = *usageLog.ServiceTier
|
||||||
}
|
}
|
||||||
@@ -7315,11 +7501,11 @@ func applyUsageBilling(ctx context.Context, requestID string, usageLog *UsageLog
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
finalizePostUsageBilling(p, deps)
|
finalizePostUsageBilling(p, deps, result)
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||||
if p == nil || p.Cost == nil || deps == nil {
|
if p == nil || p.Cost == nil || deps == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -7338,22 +7524,82 @@ func finalizePostUsageBilling(p *postUsageBillingParams, deps *billingDeps) {
|
|||||||
|
|
||||||
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
deps.deferredService.ScheduleLastUsedUpdate(p.Account.ID)
|
||||||
|
|
||||||
// Balance low notification — use real-time balance from billing cache (not stale snapshot)
|
// Notification checks run async — all parameters are already captured,
|
||||||
if !p.IsSubscriptionBill && p.Cost.ActualCost > 0 && p.User != nil && deps.balanceNotifyService != nil {
|
// no dependency on the request context or upstream connection.
|
||||||
oldBalance := p.User.Balance // fallback to snapshot
|
go notifyBalanceLow(p, deps, result)
|
||||||
if deps.billingCacheService != nil {
|
go notifyAccountQuota(p, deps, result)
|
||||||
if realBalance, err := deps.billingCacheService.GetUserBalance(context.Background(), p.User.ID); err == nil {
|
|
||||||
oldBalance = realBalance + p.Cost.ActualCost // DB already deducted, reconstruct pre-deduction balance
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// notifyBalanceLow sends balance low notification after deduction.
|
||||||
|
// When result.NewBalance is available (from DB transaction RETURNING), it is used directly
|
||||||
|
// to reconstruct oldBalance, avoiding stale Redis reads and concurrent-deduction races.
|
||||||
|
func notifyBalanceLow(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("panic in notifyBalanceLow", "recover", r)
|
||||||
}
|
}
|
||||||
|
}()
|
||||||
|
if p.IsSubscriptionBill || p.Cost.ActualCost <= 0 || p.User == nil || deps.balanceNotifyService == nil {
|
||||||
|
slog.Debug("notifyBalanceLow: skipped",
|
||||||
|
"is_subscription", p.IsSubscriptionBill,
|
||||||
|
"actual_cost", p.Cost.ActualCost,
|
||||||
|
"user_nil", p.User == nil,
|
||||||
|
"service_nil", deps.balanceNotifyService == nil,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
oldBalance := resolveOldBalance(p, result)
|
||||||
|
slog.Debug("notifyBalanceLow: calling CheckBalanceAfterDeduction",
|
||||||
|
"user_id", p.User.ID,
|
||||||
|
"old_balance", oldBalance,
|
||||||
|
"cost", p.Cost.ActualCost,
|
||||||
|
"notify_enabled", p.User.BalanceNotifyEnabled,
|
||||||
|
"threshold", p.User.BalanceNotifyThreshold,
|
||||||
|
"result_has_new_balance", result != nil && result.NewBalance != nil,
|
||||||
|
)
|
||||||
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
|
deps.balanceNotifyService.CheckBalanceAfterDeduction(context.Background(), p.User, oldBalance, p.Cost.ActualCost)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Account quota notification (use same cost formula as postUsageBilling)
|
// resolveOldBalance returns the pre-deduction balance.
|
||||||
if p.Cost.TotalCost > 0 && p.Account != nil && p.Account.IsAPIKeyOrBedrock() && deps.balanceNotifyService != nil {
|
// Prefers the DB transaction result (newBalance + cost) over snapshot.
|
||||||
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
|
func resolveOldBalance(p *postUsageBillingParams, result *UsageBillingApplyResult) float64 {
|
||||||
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost)
|
if result != nil && result.NewBalance != nil {
|
||||||
|
return *result.NewBalance + p.Cost.ActualCost
|
||||||
}
|
}
|
||||||
|
// Legacy fallback: snapshot balance from request context
|
||||||
|
return p.User.Balance
|
||||||
|
}
|
||||||
|
|
||||||
|
// notifyAccountQuota sends account quota threshold notification after increment.
|
||||||
|
// When result.QuotaState is available (from DB transaction RETURNING), it is passed directly
|
||||||
|
// to avoid a separate DB read that may see stale or concurrently-modified data.
|
||||||
|
func notifyAccountQuota(p *postUsageBillingParams, deps *billingDeps, result *UsageBillingApplyResult) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
slog.Error("panic in notifyAccountQuota", "recover", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
if p.Cost.TotalCost <= 0 || p.Account == nil || !p.Account.IsAPIKeyOrBedrock() || deps.balanceNotifyService == nil {
|
||||||
|
slog.Debug("notifyAccountQuota: skipped",
|
||||||
|
"total_cost", p.Cost.TotalCost,
|
||||||
|
"account_nil", p.Account == nil,
|
||||||
|
"is_apikey_or_bedrock", p.Account != nil && p.Account.IsAPIKeyOrBedrock(),
|
||||||
|
"service_nil", deps.balanceNotifyService == nil,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
accountCost := p.Cost.TotalCost * p.AccountRateMultiplier
|
||||||
|
var quotaState *AccountQuotaState
|
||||||
|
if result != nil {
|
||||||
|
quotaState = result.QuotaState
|
||||||
|
}
|
||||||
|
slog.Debug("notifyAccountQuota: calling CheckAccountQuotaAfterIncrement",
|
||||||
|
"account_id", p.Account.ID,
|
||||||
|
"account_cost", accountCost,
|
||||||
|
"has_quota_state", quotaState != nil,
|
||||||
|
)
|
||||||
|
deps.balanceNotifyService.CheckAccountQuotaAfterIncrement(context.Background(), p.Account, accountCost, quotaState)
|
||||||
}
|
}
|
||||||
|
|
||||||
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
func detachedBillingContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
||||||
@@ -7422,11 +7668,11 @@ func writeUsageLogBestEffort(ctx context.Context, repo UsageLogRepository, usage
|
|||||||
|
|
||||||
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
// recordUsageOpts 内部选项,参数化 RecordUsage 与 RecordUsageWithLongContext 的差异点。
|
||||||
type recordUsageOpts struct {
|
type recordUsageOpts struct {
|
||||||
// ParsedRequest(可选,仅 Claude 路径传入)
|
// Claude Max 策略所需的 ParsedRequest(可选,仅 Claude 路径传入)
|
||||||
ParsedRequest *ParsedRequest
|
ParsedRequest *ParsedRequest
|
||||||
|
|
||||||
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
// EnableClaudePath 启用 Claude 路径特有逻辑:
|
||||||
// - MediaType 字段写入使用日志
|
// - Claude Max 缓存计费策略
|
||||||
EnableClaudePath bool
|
EnableClaudePath bool
|
||||||
|
|
||||||
// 长上下文计费(仅 Gemini 路径需要)
|
// 长上下文计费(仅 Gemini 路径需要)
|
||||||
@@ -7451,7 +7697,6 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
APIKeyService: input.APIKeyService,
|
APIKeyService: input.APIKeyService,
|
||||||
ChannelUsageFields: input.ChannelUsageFields,
|
ChannelUsageFields: input.ChannelUsageFields,
|
||||||
}, &recordUsageOpts{
|
}, &recordUsageOpts{
|
||||||
ParsedRequest: input.ParsedRequest,
|
|
||||||
EnableClaudePath: true,
|
EnableClaudePath: true,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -7517,6 +7762,7 @@ type recordUsageCoreInput struct {
|
|||||||
|
|
||||||
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
// recordUsageCore 是 RecordUsage 和 RecordUsageWithLongContext 的统一实现。
|
||||||
// opts 中的字段控制两者之间的差异行为:
|
// opts 中的字段控制两者之间的差异行为:
|
||||||
|
// - ParsedRequest != nil → 启用 Claude Max 缓存计费策略
|
||||||
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
// - LongContextThreshold > 0 → Token 计费回退走 CalculateCostWithLongContext
|
||||||
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsageCoreInput, opts *recordUsageOpts) error {
|
||||||
result := input.Result
|
result := input.Result
|
||||||
@@ -7583,13 +7829,10 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
|
|
||||||
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
|
// 计算账号统计定价费用(使用最终上游模型匹配自定义规则)
|
||||||
if apiKey.GroupID != nil {
|
if apiKey.GroupID != nil {
|
||||||
upstreamModel := result.UpstreamModel
|
applyAccountStatsCost(ctx, usageLog, s.channelService, s.billingService,
|
||||||
if upstreamModel == "" {
|
account.ID, *apiKey.GroupID, result.UpstreamModel, result.Model,
|
||||||
upstreamModel = result.Model
|
// Anthropic's input_tokens excludes cache_read and cache_creation (billed separately);
|
||||||
}
|
// OpenAI gateway uses actualInputTokens which also excludes cache_read for the same reason.
|
||||||
usageLog.AccountStatsCost = resolveAccountStatsCost(
|
|
||||||
ctx, s.channelService, s.billingService,
|
|
||||||
account.ID, *apiKey.GroupID, upstreamModel,
|
|
||||||
UsageTokens{
|
UsageTokens{
|
||||||
InputTokens: result.Usage.InputTokens,
|
InputTokens: result.Usage.InputTokens,
|
||||||
OutputTokens: result.Usage.OutputTokens,
|
OutputTokens: result.Usage.OutputTokens,
|
||||||
@@ -7597,7 +7840,6 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage
|
|||||||
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
CacheReadTokens: result.Usage.CacheReadInputTokens,
|
||||||
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
ImageOutputTokens: result.Usage.ImageOutputTokens,
|
||||||
},
|
},
|
||||||
1, // requestCount
|
|
||||||
cost.TotalCost,
|
cost.TotalCost,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
@@ -7796,13 +8038,12 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
RateMultiplier: multiplier,
|
RateMultiplier: multiplier,
|
||||||
AccountRateMultiplier: &accountRateMultiplier,
|
AccountRateMultiplier: &accountRateMultiplier,
|
||||||
BillingType: billingType,
|
BillingType: billingType,
|
||||||
BillingMode: resolveBillingMode(opts, result, cost),
|
BillingMode: resolveBillingMode(result, cost),
|
||||||
Stream: result.Stream,
|
Stream: result.Stream,
|
||||||
DurationMs: &durationMs,
|
DurationMs: &durationMs,
|
||||||
FirstTokenMs: result.FirstTokenMs,
|
FirstTokenMs: result.FirstTokenMs,
|
||||||
ImageCount: result.ImageCount,
|
ImageCount: result.ImageCount,
|
||||||
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
ImageSize: optionalTrimmedStringPtr(result.ImageSize),
|
||||||
MediaType: resolveMediaType(opts, result),
|
|
||||||
CacheTTLOverridden: cacheTTLOverridden,
|
CacheTTLOverridden: cacheTTLOverridden,
|
||||||
ChannelID: optionalInt64Ptr(input.ChannelID),
|
ChannelID: optionalInt64Ptr(input.ChannelID),
|
||||||
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
ModelMappingChain: optionalTrimmedStringPtr(input.ModelMappingChain),
|
||||||
@@ -7826,7 +8067,7 @@ func (s *GatewayService) buildRecordUsageLog(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
// resolveBillingMode 根据计费结果和请求类型确定计费模式。
|
||||||
func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *CostBreakdown) *string {
|
func resolveBillingMode(result *ForwardResult, cost *CostBreakdown) *string {
|
||||||
var mode string
|
var mode string
|
||||||
switch {
|
switch {
|
||||||
case cost != nil && cost.BillingMode != "":
|
case cost != nil && cost.BillingMode != "":
|
||||||
@@ -7839,10 +8080,6 @@ func resolveBillingMode(opts *recordUsageOpts, result *ForwardResult, cost *Cost
|
|||||||
return &mode
|
return &mode
|
||||||
}
|
}
|
||||||
|
|
||||||
func resolveMediaType(opts *recordUsageOpts, result *ForwardResult) *string {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
func optionalSubscriptionID(subscription *UserSubscription) *int64 {
|
||||||
if subscription != nil {
|
if subscription != nil {
|
||||||
return &subscription.ID
|
return &subscription.ID
|
||||||
@@ -8349,9 +8586,9 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
|
|
||||||
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
|
// OAuth 账号:应用统一指纹和重写 userID(受设置开关控制)
|
||||||
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
// 如果启用了会话ID伪装,会在重写后替换 session 部分为固定值
|
||||||
ctEnableFP, ctEnableMPT := true, false
|
ctEnableFP, ctEnableMPT, ctEnableCCH := true, false, false
|
||||||
if s.settingService != nil {
|
if s.settingService != nil {
|
||||||
ctEnableFP, ctEnableMPT = s.settingService.GetGatewayForwardingSettings(ctx)
|
ctEnableFP, ctEnableMPT, ctEnableCCH = s.settingService.GetGatewayForwardingSettings(ctx)
|
||||||
}
|
}
|
||||||
var ctFingerprint *Fingerprint
|
var ctFingerprint *Fingerprint
|
||||||
if account.IsOAuth() && s.identityService != nil {
|
if account.IsOAuth() && s.identityService != nil {
|
||||||
@@ -8369,6 +8606,14 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 同步 billing header cc_version 与实际发送的 User-Agent 版本
|
||||||
|
if ctFingerprint != nil && ctEnableFP {
|
||||||
|
body = syncBillingHeaderVersion(body, ctFingerprint.UserAgent)
|
||||||
|
}
|
||||||
|
if ctEnableCCH {
|
||||||
|
body = signBillingHeaderCCH(body)
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
req, err := http.NewRequestWithContext(ctx, "POST", targetURL, bytes.NewReader(body))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -8409,7 +8654,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
|
// Build effective drop set for count_tokens: merge static defaults with dynamic beta policy filter rules
|
||||||
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account))
|
ctEffectiveDropSet := mergeDropSets(s.getBetaPolicyFilterSet(ctx, c, account, modelID))
|
||||||
|
|
||||||
// OAuth 账号:处理 anthropic-beta header
|
// OAuth 账号:处理 anthropic-beta header
|
||||||
if tokenType == "oauth" {
|
if tokenType == "oauth" {
|
||||||
|
|||||||
@@ -128,3 +128,66 @@ func TestValidatePlanPatch_NilOriginalPrice(t *testing.T) {
|
|||||||
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
|
err := validatePlanPatch(UpdatePlanRequest{OriginalPrice: nil})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- validatePlanPatch: other fields ---
|
||||||
|
|
||||||
|
func ptrStr(s string) *string { return &s }
|
||||||
|
func ptrInt(i int) *int { return &i }
|
||||||
|
func ptrInt64(i int64) *int64 { return &i }
|
||||||
|
func ptrFloat(f float64) *float64 { return &f }
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_EmptyName(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("")})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "plan name")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_ValidName(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{Name: ptrStr("Basic")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_ZeroGroupID(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{GroupID: ptrInt64(0)})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "group")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_NegativePrice(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(-1)})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "price")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_ZeroPrice(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(0)})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "price")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_ValidPrice(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{Price: ptrFloat(9.99)})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_ZeroValidityDays(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{ValidityDays: ptrInt(0)})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "validity days")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_EmptyValidityUnit(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("")})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "validity unit")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_ValidValidityUnit(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{ValidityUnit: ptrStr("days")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidatePlanPatch_AllNil(t *testing.T) {
|
||||||
|
err := validatePlanPatch(UpdatePlanRequest{})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|||||||
@@ -330,6 +330,7 @@ func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code str
|
|||||||
Code: code,
|
Code: code,
|
||||||
Attempts: 0,
|
Attempts: 0,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
|
ExpiresAt: time.Now().Add(verifyCodeTTL),
|
||||||
}
|
}
|
||||||
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||||
return fmt.Errorf("save verify code: %w", err)
|
return fmt.Errorf("save verify code: %w", err)
|
||||||
@@ -370,7 +371,11 @@ func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string)
|
|||||||
}
|
}
|
||||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||||
data.Attempts++
|
data.Attempts++
|
||||||
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
remaining := time.Until(data.ExpiresAt)
|
||||||
|
if remaining <= 0 {
|
||||||
|
return ErrInvalidVerifyCode
|
||||||
|
}
|
||||||
|
if err := cache.SetNotifyVerifyCode(ctx, email, data, remaining); err != nil {
|
||||||
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
|
slog.Error("failed to update notify verify code attempts", "email", email, "error", err)
|
||||||
}
|
}
|
||||||
if data.Attempts >= maxVerifyCodeAttempts {
|
if data.Attempts >= maxVerifyCodeAttempts {
|
||||||
@@ -418,11 +423,17 @@ func (s *UserService) RemoveNotifyEmail(ctx context.Context, userID int64, email
|
|||||||
}
|
}
|
||||||
|
|
||||||
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
|
filtered := make([]NotifyEmailEntry, 0, len(user.BalanceNotifyExtraEmails))
|
||||||
|
found := false
|
||||||
for _, e := range user.BalanceNotifyExtraEmails {
|
for _, e := range user.BalanceNotifyExtraEmails {
|
||||||
if !strings.EqualFold(e.Email, email) {
|
if strings.EqualFold(e.Email, email) {
|
||||||
|
found = true
|
||||||
|
} else {
|
||||||
filtered = append(filtered, e)
|
filtered = append(filtered, e)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !found {
|
||||||
|
return infraerrors.BadRequest("EMAIL_NOT_FOUND", "notification email not found")
|
||||||
|
}
|
||||||
user.BalanceNotifyExtraEmails = filtered
|
user.BalanceNotifyExtraEmails = filtered
|
||||||
return s.userRepo.Update(ctx, user)
|
return s.userRepo.Update(ctx, user)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import QuotaNotifyToggle from './QuotaNotifyToggle.vue'
|
import QuotaNotifyToggle from './QuotaNotifyToggle.vue'
|
||||||
|
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
@@ -11,9 +12,9 @@ const props = defineProps<{
|
|||||||
quotaNotifyGlobalEnabled: boolean
|
quotaNotifyGlobalEnabled: boolean
|
||||||
notifyEnabled: boolean | null
|
notifyEnabled: boolean | null
|
||||||
notifyThreshold: number | null
|
notifyThreshold: number | null
|
||||||
notifyThresholdType: string | null
|
notifyThresholdType: QuotaThresholdType | null
|
||||||
// Reset mode (only for daily/weekly, null for total)
|
// Reset mode (only for daily/weekly, null for total)
|
||||||
resetMode: 'rolling' | 'fixed' | null
|
resetMode: QuotaResetMode | null
|
||||||
resetHour: number | null
|
resetHour: number | null
|
||||||
resetDay: number | null // weekly only
|
resetDay: number | null // weekly only
|
||||||
resetTimezone: string | null
|
resetTimezone: string | null
|
||||||
@@ -22,14 +23,15 @@ const props = defineProps<{
|
|||||||
// Shared options passed from parent
|
// Shared options passed from parent
|
||||||
hourOptions: number[]
|
hourOptions: number[]
|
||||||
dayOptions: { value: number; key: string }[]
|
dayOptions: { value: number; key: string }[]
|
||||||
|
timezoneOptions?: string[]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
'update:limit': [value: number | null]
|
'update:limit': [value: number | null]
|
||||||
'update:notifyEnabled': [value: boolean | null]
|
'update:notifyEnabled': [value: boolean | null]
|
||||||
'update:notifyThreshold': [value: number | null]
|
'update:notifyThreshold': [value: number | null]
|
||||||
'update:notifyThresholdType': [value: string | null]
|
'update:notifyThresholdType': [value: QuotaThresholdType | null]
|
||||||
'update:resetMode': [value: 'rolling' | 'fixed' | null]
|
'update:resetMode': [value: QuotaResetMode | null]
|
||||||
'update:resetHour': [value: number | null]
|
'update:resetHour': [value: number | null]
|
||||||
'update:resetDay': [value: number | null]
|
'update:resetDay': [value: number | null]
|
||||||
'update:resetTimezone': [value: string | null]
|
'update:resetTimezone': [value: string | null]
|
||||||
@@ -43,7 +45,7 @@ const onLimitInput = (e: Event) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const onModeChange = (e: Event) => {
|
const onModeChange = (e: Event) => {
|
||||||
const val = (e.target as HTMLSelectElement).value as 'rolling' | 'fixed'
|
const val = (e.target as HTMLSelectElement).value as QuotaResetMode
|
||||||
emit('update:resetMode', val)
|
emit('update:resetMode', val)
|
||||||
if (val === 'fixed') {
|
if (val === 'fixed') {
|
||||||
if (props.resetHour == null) emit('update:resetHour', 0)
|
if (props.resetHour == null) emit('update:resetHour', 0)
|
||||||
@@ -51,6 +53,17 @@ const onModeChange = (e: Event) => {
|
|||||||
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
if (!props.resetTimezone) emit('update:resetTimezone', 'UTC')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function getTimezoneOffsetLabel(tz: string): string {
|
||||||
|
try {
|
||||||
|
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
|
||||||
|
const parts = dtf.formatToParts(new Date())
|
||||||
|
const tzPart = parts.find(p => p.type === 'timeZoneName')
|
||||||
|
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
|
||||||
|
} catch {
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
}
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<template>
|
<template>
|
||||||
@@ -95,6 +108,11 @@ const onModeChange = (e: Event) => {
|
|||||||
<select :value="resetHour ?? 0" @change="emit('update:resetHour', Number(($event.target as HTMLSelectElement).value))" class="input py-1 text-xs w-24">
|
<select :value="resetHour ?? 0" @change="emit('update:resetHour', Number(($event.target as HTMLSelectElement).value))" class="input py-1 text-xs w-24">
|
||||||
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
<option v-for="h in hourOptions" :key="h" :value="h">{{ String(h).padStart(2, '0') }}:00</option>
|
||||||
</select>
|
</select>
|
||||||
|
<template v-if="timezoneOptions && timezoneOptions.length > 0">
|
||||||
|
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input py-1 text-xs w-auto">
|
||||||
|
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
|
||||||
|
</select>
|
||||||
|
</template>
|
||||||
</template>
|
</template>
|
||||||
<span class="text-[11px] text-gray-500 dark:text-gray-400">
|
<span class="text-[11px] text-gray-500 dark:text-gray-400">
|
||||||
<template v-if="resetMode === 'fixed'">{{ hintFixed }}</template>
|
<template v-if="resetMode === 'fixed'">{{ hintFixed }}</template>
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
import { ref, watch, computed } from 'vue'
|
import { ref, watch, computed } from 'vue'
|
||||||
import { useI18n } from 'vue-i18n'
|
import { useI18n } from 'vue-i18n'
|
||||||
import QuotaDimensionRow from './QuotaDimensionRow.vue'
|
import QuotaDimensionRow from './QuotaDimensionRow.vue'
|
||||||
|
import type { QuotaThresholdType, QuotaResetMode } from '@/constants/account'
|
||||||
|
|
||||||
const { t } = useI18n()
|
const { t } = useI18n()
|
||||||
|
|
||||||
@@ -9,22 +10,22 @@ const props = withDefaults(defineProps<{
|
|||||||
totalLimit: number | null
|
totalLimit: number | null
|
||||||
dailyLimit: number | null
|
dailyLimit: number | null
|
||||||
weeklyLimit: number | null
|
weeklyLimit: number | null
|
||||||
dailyResetMode: 'rolling' | 'fixed' | null
|
dailyResetMode: QuotaResetMode | null
|
||||||
dailyResetHour: number | null
|
dailyResetHour: number | null
|
||||||
weeklyResetMode: 'rolling' | 'fixed' | null
|
weeklyResetMode: QuotaResetMode | null
|
||||||
weeklyResetDay: number | null
|
weeklyResetDay: number | null
|
||||||
weeklyResetHour: number | null
|
weeklyResetHour: number | null
|
||||||
resetTimezone: string | null
|
resetTimezone: string | null
|
||||||
quotaNotifyGlobalEnabled?: boolean
|
quotaNotifyGlobalEnabled?: boolean
|
||||||
quotaNotifyDailyEnabled?: boolean | null
|
quotaNotifyDailyEnabled?: boolean | null
|
||||||
quotaNotifyDailyThreshold?: number | null
|
quotaNotifyDailyThreshold?: number | null
|
||||||
quotaNotifyDailyThresholdType?: string | null
|
quotaNotifyDailyThresholdType?: QuotaThresholdType | null
|
||||||
quotaNotifyWeeklyEnabled?: boolean | null
|
quotaNotifyWeeklyEnabled?: boolean | null
|
||||||
quotaNotifyWeeklyThreshold?: number | null
|
quotaNotifyWeeklyThreshold?: number | null
|
||||||
quotaNotifyWeeklyThresholdType?: string | null
|
quotaNotifyWeeklyThresholdType?: QuotaThresholdType | null
|
||||||
quotaNotifyTotalEnabled?: boolean | null
|
quotaNotifyTotalEnabled?: boolean | null
|
||||||
quotaNotifyTotalThreshold?: number | null
|
quotaNotifyTotalThreshold?: number | null
|
||||||
quotaNotifyTotalThresholdType?: string | null
|
quotaNotifyTotalThresholdType?: QuotaThresholdType | null
|
||||||
}>(), {
|
}>(), {
|
||||||
quotaNotifyGlobalEnabled: false,
|
quotaNotifyGlobalEnabled: false,
|
||||||
quotaNotifyDailyEnabled: null,
|
quotaNotifyDailyEnabled: null,
|
||||||
@@ -42,21 +43,21 @@ const emit = defineEmits<{
|
|||||||
'update:totalLimit': [value: number | null]
|
'update:totalLimit': [value: number | null]
|
||||||
'update:dailyLimit': [value: number | null]
|
'update:dailyLimit': [value: number | null]
|
||||||
'update:weeklyLimit': [value: number | null]
|
'update:weeklyLimit': [value: number | null]
|
||||||
'update:dailyResetMode': [value: 'rolling' | 'fixed' | null]
|
'update:dailyResetMode': [value: QuotaResetMode | null]
|
||||||
'update:dailyResetHour': [value: number | null]
|
'update:dailyResetHour': [value: number | null]
|
||||||
'update:weeklyResetMode': [value: 'rolling' | 'fixed' | null]
|
'update:weeklyResetMode': [value: QuotaResetMode | null]
|
||||||
'update:weeklyResetDay': [value: number | null]
|
'update:weeklyResetDay': [value: number | null]
|
||||||
'update:weeklyResetHour': [value: number | null]
|
'update:weeklyResetHour': [value: number | null]
|
||||||
'update:resetTimezone': [value: string | null]
|
'update:resetTimezone': [value: string | null]
|
||||||
'update:quotaNotifyDailyEnabled': [value: boolean | null]
|
'update:quotaNotifyDailyEnabled': [value: boolean | null]
|
||||||
'update:quotaNotifyDailyThreshold': [value: number | null]
|
'update:quotaNotifyDailyThreshold': [value: number | null]
|
||||||
'update:quotaNotifyDailyThresholdType': [value: string | null]
|
'update:quotaNotifyDailyThresholdType': [value: QuotaThresholdType | null]
|
||||||
'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
|
'update:quotaNotifyWeeklyEnabled': [value: boolean | null]
|
||||||
'update:quotaNotifyWeeklyThreshold': [value: number | null]
|
'update:quotaNotifyWeeklyThreshold': [value: number | null]
|
||||||
'update:quotaNotifyWeeklyThresholdType': [value: string | null]
|
'update:quotaNotifyWeeklyThresholdType': [value: QuotaThresholdType | null]
|
||||||
'update:quotaNotifyTotalEnabled': [value: boolean | null]
|
'update:quotaNotifyTotalEnabled': [value: boolean | null]
|
||||||
'update:quotaNotifyTotalThreshold': [value: number | null]
|
'update:quotaNotifyTotalThreshold': [value: number | null]
|
||||||
'update:quotaNotifyTotalThresholdType': [value: string | null]
|
'update:quotaNotifyTotalThresholdType': [value: QuotaThresholdType | null]
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const enabled = computed(() =>
|
const enabled = computed(() =>
|
||||||
@@ -89,11 +90,6 @@ watch(localEnabled, (val) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
// Whether any fixed mode is active (to show timezone selector)
|
|
||||||
const hasFixedMode = computed(() =>
|
|
||||||
props.dailyResetMode === 'fixed' || props.weeklyResetMode === 'fixed'
|
|
||||||
)
|
|
||||||
|
|
||||||
// Common timezone options
|
// Common timezone options
|
||||||
const timezoneOptions = [
|
const timezoneOptions = [
|
||||||
'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata',
|
'UTC', 'Asia/Shanghai', 'Asia/Tokyo', 'Asia/Seoul', 'Asia/Singapore', 'Asia/Kolkata',
|
||||||
@@ -102,18 +98,6 @@ const timezoneOptions = [
|
|||||||
'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland',
|
'America/Sao_Paulo', 'Australia/Sydney', 'Pacific/Auckland',
|
||||||
]
|
]
|
||||||
|
|
||||||
// Compute GMT offset label (e.g. "GMT+8", "GMT-5") for a given IANA timezone.
|
|
||||||
function getTimezoneOffsetLabel(tz: string): string {
|
|
||||||
try {
|
|
||||||
const dtf = new Intl.DateTimeFormat('en-US', { timeZone: tz, timeZoneName: 'shortOffset' })
|
|
||||||
const parts = dtf.formatToParts(new Date())
|
|
||||||
const tzPart = parts.find(p => p.type === 'timeZoneName')
|
|
||||||
return tzPart ? (tzPart.value === 'GMT' ? 'GMT+0' : tzPart.value) : ''
|
|
||||||
} catch {
|
|
||||||
return ''
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hours for dropdown (0-23)
|
// Hours for dropdown (0-23)
|
||||||
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
|
const hourOptions = Array.from({ length: 24 }, (_, i) => i)
|
||||||
|
|
||||||
@@ -197,6 +181,7 @@ const dailyFixedHint = computed(() =>
|
|||||||
:hint-fixed="dailyFixedHint"
|
:hint-fixed="dailyFixedHint"
|
||||||
:hour-options="hourOptions"
|
:hour-options="hourOptions"
|
||||||
:day-options="dayOptions"
|
:day-options="dayOptions"
|
||||||
|
:timezone-options="timezoneOptions"
|
||||||
@update:limit="emit('update:dailyLimit', $event)"
|
@update:limit="emit('update:dailyLimit', $event)"
|
||||||
@update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)"
|
@update:notify-enabled="emit('update:quotaNotifyDailyEnabled', $event)"
|
||||||
@update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)"
|
@update:notify-threshold="emit('update:quotaNotifyDailyThreshold', $event)"
|
||||||
@@ -223,6 +208,7 @@ const dailyFixedHint = computed(() =>
|
|||||||
:hint-fixed="weeklyFixedHint"
|
:hint-fixed="weeklyFixedHint"
|
||||||
:hour-options="hourOptions"
|
:hour-options="hourOptions"
|
||||||
:day-options="dayOptions"
|
:day-options="dayOptions"
|
||||||
|
:timezone-options="timezoneOptions"
|
||||||
@update:limit="emit('update:weeklyLimit', $event)"
|
@update:limit="emit('update:weeklyLimit', $event)"
|
||||||
@update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)"
|
@update:notify-enabled="emit('update:quotaNotifyWeeklyEnabled', $event)"
|
||||||
@update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)"
|
@update:notify-threshold="emit('update:quotaNotifyWeeklyThreshold', $event)"
|
||||||
@@ -233,14 +219,6 @@ const dailyFixedHint = computed(() =>
|
|||||||
@update:reset-timezone="emit('update:resetTimezone', $event)"
|
@update:reset-timezone="emit('update:resetTimezone', $event)"
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<!-- Timezone selector (shared by daily/weekly when fixed mode is active) -->
|
|
||||||
<div v-if="hasFixedMode">
|
|
||||||
<label class="input-label">{{ t('admin.accounts.quotaResetTimezone') }}</label>
|
|
||||||
<select :value="resetTimezone || 'UTC'" @change="emit('update:resetTimezone', ($event.target as HTMLSelectElement).value)" class="input text-sm">
|
|
||||||
<option v-for="tz in timezoneOptions" :key="tz" :value="tz">{{ tz }} ({{ getTimezoneOffsetLabel(tz) }})</option>
|
|
||||||
</select>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- Total quota -->
|
<!-- Total quota -->
|
||||||
<QuotaDimensionRow
|
<QuotaDimensionRow
|
||||||
dim="total"
|
dim="total"
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE } from '@/constants/account'
|
import { QUOTA_THRESHOLD_TYPE_FIXED, QUOTA_THRESHOLD_TYPE_PERCENTAGE, type QuotaThresholdType } from '@/constants/account'
|
||||||
|
|
||||||
defineProps<{
|
defineProps<{
|
||||||
enabled: boolean | null
|
enabled: boolean | null
|
||||||
threshold: number | null
|
threshold: number | null
|
||||||
thresholdType: string | null // "fixed" (default) or "percentage"
|
thresholdType: QuotaThresholdType | null
|
||||||
}>()
|
}>()
|
||||||
|
|
||||||
const emit = defineEmits<{
|
const emit = defineEmits<{
|
||||||
'update:enabled': [value: boolean | null]
|
'update:enabled': [value: boolean | null]
|
||||||
'update:threshold': [value: number | null]
|
'update:threshold': [value: number | null]
|
||||||
'update:thresholdType': [value: string | null]
|
'update:thresholdType': [value: QuotaThresholdType | null]
|
||||||
}>()
|
}>()
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
@@ -43,7 +43,7 @@ const emit = defineEmits<{
|
|||||||
/>
|
/>
|
||||||
<select
|
<select
|
||||||
:value="thresholdType || QUOTA_THRESHOLD_TYPE_FIXED"
|
:value="thresholdType || QUOTA_THRESHOLD_TYPE_FIXED"
|
||||||
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value)"
|
@change="emit('update:thresholdType', ($event.target as HTMLSelectElement).value as QuotaThresholdType)"
|
||||||
class="input py-1 text-xs w-[4.5rem] flex-shrink-0 text-center"
|
class="input py-1 text-xs w-[4.5rem] flex-shrink-0 text-center"
|
||||||
>
|
>
|
||||||
<option :value="QUOTA_THRESHOLD_TYPE_FIXED">$</option>
|
<option :value="QUOTA_THRESHOLD_TYPE_FIXED">$</option>
|
||||||
|
|||||||
@@ -313,10 +313,6 @@
|
|||||||
<span class="text-gray-400">{{ t('usage.rate') }}</span>
|
<span class="text-gray-400">{{ t('usage.rate') }}</span>
|
||||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
|
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.rate_multiplier || 1) }}x</span>
|
||||||
</div>
|
</div>
|
||||||
<div class="flex items-center justify-between gap-6">
|
|
||||||
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
|
|
||||||
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
|
|
||||||
</div>
|
|
||||||
<div class="flex items-center justify-between gap-6">
|
<div class="flex items-center justify-between gap-6">
|
||||||
<span class="text-gray-400">{{ t('usage.original') }}</span>
|
<span class="text-gray-400">{{ t('usage.original') }}</span>
|
||||||
<span class="font-medium text-white">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
|
<span class="font-medium text-white">${{ tooltipData?.total_cost?.toFixed(6) || '0.000000' }}</span>
|
||||||
@@ -325,7 +321,12 @@
|
|||||||
<span class="text-gray-400">{{ t('usage.userBilled') }}</span>
|
<span class="text-gray-400">{{ t('usage.userBilled') }}</span>
|
||||||
<span class="font-semibold text-green-400">${{ tooltipData?.actual_cost?.toFixed(6) || '0.000000' }}</span>
|
<span class="font-semibold text-green-400">${{ tooltipData?.actual_cost?.toFixed(6) || '0.000000' }}</span>
|
||||||
</div>
|
</div>
|
||||||
|
<!-- Account billing (separated from user billing) -->
|
||||||
<div class="flex items-center justify-between gap-6 border-t border-gray-700 pt-1.5">
|
<div class="flex items-center justify-between gap-6 border-t border-gray-700 pt-1.5">
|
||||||
|
<span class="text-gray-400">{{ t('usage.accountMultiplier') }}</span>
|
||||||
|
<span class="font-semibold text-blue-400">{{ formatMultiplier(tooltipData?.account_rate_multiplier ?? 1) }}x</span>
|
||||||
|
</div>
|
||||||
|
<div class="flex items-center justify-between gap-6">
|
||||||
<span class="text-gray-400">{{ t('usage.accountBilled') }}</span>
|
<span class="text-gray-400">{{ t('usage.accountBilled') }}</span>
|
||||||
<span class="font-semibold text-green-400">
|
<span class="font-semibold text-green-400">
|
||||||
${{ accountBilled({
|
${{ accountBilled({
|
||||||
@@ -355,7 +356,8 @@ import { getBillingModeLabel, getBillingModeBadgeClass, BILLING_MODE_TOKEN, BILL
|
|||||||
/** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */
|
/** Compute the account-billed cost for display: (account_stats_cost ?? total_cost) * rate_multiplier */
|
||||||
function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number {
|
function accountBilled(row: { total_cost?: number | null; account_stats_cost?: number | null; account_rate_multiplier?: number | null }): number {
|
||||||
const base = row.account_stats_cost != null ? row.account_stats_cost : (row.total_cost ?? 0)
|
const base = row.account_stats_cost != null ? row.account_stats_cost : (row.total_cost ?? 0)
|
||||||
return base * (row.account_rate_multiplier ?? 1)
|
const result = base * (row.account_rate_multiplier ?? 1)
|
||||||
|
return Number.isNaN(result) ? 0 : result
|
||||||
}
|
}
|
||||||
|
|
||||||
import DataTable from '@/components/common/DataTable.vue'
|
import DataTable from '@/components/common/DataTable.vue'
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { reactive, ref } from 'vue'
|
import { reactive, ref } from 'vue'
|
||||||
import { adminAPI } from '@/api/admin'
|
import { adminAPI } from '@/api/admin'
|
||||||
import { QUOTA_THRESHOLD_TYPE_FIXED } from '@/constants/account'
|
import { QUOTA_THRESHOLD_TYPE_FIXED, type QuotaThresholdType } from '@/constants/account'
|
||||||
|
|
||||||
export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const
|
export const QUOTA_NOTIFY_DIMS = ['daily', 'weekly', 'total'] as const
|
||||||
export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
|
export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
|
||||||
@@ -8,7 +8,7 @@ export type QuotaNotifyDim = (typeof QUOTA_NOTIFY_DIMS)[number]
|
|||||||
interface DimState {
|
interface DimState {
|
||||||
enabled: boolean | null
|
enabled: boolean | null
|
||||||
threshold: number | null
|
threshold: number | null
|
||||||
thresholdType: string | null
|
thresholdType: QuotaThresholdType | null
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useQuotaNotifyState() {
|
export function useQuotaNotifyState() {
|
||||||
@@ -34,7 +34,7 @@ export function useQuotaNotifyState() {
|
|||||||
for (const d of QUOTA_NOTIFY_DIMS) {
|
for (const d of QUOTA_NOTIFY_DIMS) {
|
||||||
state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null
|
state[d].enabled = (extra?.[`quota_notify_${d}_enabled`] as boolean) ?? null
|
||||||
state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null
|
state[d].threshold = (extra?.[`quota_notify_${d}_threshold`] as number) ?? null
|
||||||
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as string) ?? null
|
state[d].thresholdType = (extra?.[`quota_notify_${d}_threshold_type`] as QuotaThresholdType) ?? null
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,3 +8,8 @@ export type WebSearchMode = typeof WEB_SEARCH_MODE_DEFAULT | typeof WEB_SEARCH_M
|
|||||||
export const QUOTA_THRESHOLD_TYPE_FIXED = 'fixed' as const
|
export const QUOTA_THRESHOLD_TYPE_FIXED = 'fixed' as const
|
||||||
export const QUOTA_THRESHOLD_TYPE_PERCENTAGE = 'percentage' as const
|
export const QUOTA_THRESHOLD_TYPE_PERCENTAGE = 'percentage' as const
|
||||||
export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOTA_THRESHOLD_TYPE_PERCENTAGE
|
export type QuotaThresholdType = typeof QUOTA_THRESHOLD_TYPE_FIXED | typeof QUOTA_THRESHOLD_TYPE_PERCENTAGE
|
||||||
|
|
||||||
|
/** Quota reset mode values */
|
||||||
|
export const QUOTA_RESET_MODE_ROLLING = 'rolling' as const
|
||||||
|
export const QUOTA_RESET_MODE_FIXED = 'fixed' as const
|
||||||
|
export type QuotaResetMode = typeof QUOTA_RESET_MODE_ROLLING | typeof QUOTA_RESET_MODE_FIXED
|
||||||
|
|||||||
@@ -166,8 +166,8 @@
|
|||||||
class="channel-tab group"
|
class="channel-tab group"
|
||||||
:class="activeTab === section.platform ? 'channel-tab-active' : 'channel-tab-inactive'"
|
:class="activeTab === section.platform ? 'channel-tab-active' : 'channel-tab-inactive'"
|
||||||
>
|
>
|
||||||
<PlatformIcon :platform="section.platform" size="xs" :class="getPlatformTextColor(section.platform)" />
|
<PlatformIcon :platform="section.platform" size="xs" :class="platformTextClass(section.platform)" />
|
||||||
<span :class="getPlatformTextColor(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
|
<span :class="platformTextClass(section.platform)">{{ t('admin.groups.platforms.' + section.platform, section.platform) }}</span>
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
@@ -246,8 +246,8 @@
|
|||||||
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
class="h-3.5 w-3.5 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||||
@change="togglePlatform(p)"
|
@change="togglePlatform(p)"
|
||||||
/>
|
/>
|
||||||
<PlatformIcon :platform="p" size="xs" :class="getPlatformTextColor(p)" />
|
<PlatformIcon :platform="p" size="xs" :class="platformTextClass(p)" />
|
||||||
<span :class="getPlatformTextColor(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
|
<span :class="platformTextClass(p)">{{ t('admin.groups.platforms.' + p, p) }}</span>
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -310,9 +310,9 @@
|
|||||||
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500"
|
||||||
@change="toggleGroupInSection(sIdx, group.id)"
|
@change="toggleGroupInSection(sIdx, group.id)"
|
||||||
/>
|
/>
|
||||||
<span :class="['font-medium', getPlatformTextColor(group.platform)]">{{ group.name }}</span>
|
<span :class="['font-medium', platformTextClass(group.platform)]">{{ group.name }}</span>
|
||||||
<span
|
<span
|
||||||
:class="['rounded-full px-1 py-0 text-[10px]', getRateBadgeClass(group.platform)]"
|
:class="['rounded-full px-1 py-0 text-[10px]', platformBadgeLightClass(group.platform)]"
|
||||||
>{{ group.rate_multiplier }}x</span>
|
>{{ group.rate_multiplier }}x</span>
|
||||||
<span class="text-[10px] text-gray-400">{{ group.account_count || 0 }}</span>
|
<span class="text-[10px] text-gray-400">{{ group.account_count || 0 }}</span>
|
||||||
<span
|
<span
|
||||||
@@ -363,7 +363,7 @@
|
|||||||
:value="srcModel"
|
:value="srcModel"
|
||||||
type="text"
|
type="text"
|
||||||
class="input flex-1 text-xs"
|
class="input flex-1 text-xs"
|
||||||
:class="getPlatformTextColor(section.platform)"
|
:class="platformTextClass(section.platform)"
|
||||||
:placeholder="t('admin.channels.form.mappingSource', 'Source model')"
|
:placeholder="t('admin.channels.form.mappingSource', 'Source model')"
|
||||||
@change="renameMappingKey(sIdx, srcModel, ($event.target as HTMLInputElement).value)"
|
@change="renameMappingKey(sIdx, srcModel, ($event.target as HTMLInputElement).value)"
|
||||||
/>
|
/>
|
||||||
@@ -372,7 +372,7 @@
|
|||||||
:value="section.model_mapping[srcModel]"
|
:value="section.model_mapping[srcModel]"
|
||||||
type="text"
|
type="text"
|
||||||
class="input flex-1 text-xs"
|
class="input flex-1 text-xs"
|
||||||
:class="getPlatformTextColor(section.platform)"
|
:class="platformTextClass(section.platform)"
|
||||||
:placeholder="t('admin.channels.form.mappingTarget', 'Target model')"
|
:placeholder="t('admin.channels.form.mappingTarget', 'Target model')"
|
||||||
@input="section.model_mapping[srcModel] = ($event.target as HTMLInputElement).value"
|
@input="section.model_mapping[srcModel] = ($event.target as HTMLInputElement).value"
|
||||||
/>
|
/>
|
||||||
@@ -464,7 +464,7 @@
|
|||||||
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
|
: 'border-gray-200 hover:bg-gray-50 dark:border-dark-600 dark:hover:bg-dark-700'"
|
||||||
>
|
>
|
||||||
<input type="checkbox" :checked="rule.group_ids.includes(gid)" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" @change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)" />
|
<input type="checkbox" :checked="rule.group_ids.includes(gid)" class="h-3 w-3 rounded border-gray-300 text-primary-600 focus:ring-primary-500" @change="rule.group_ids.includes(gid) ? rule.group_ids.splice(rule.group_ids.indexOf(gid), 1) : rule.group_ids.push(gid)" />
|
||||||
<span>{{ getGroupNameById(gid) }}</span>
|
<span :class="['font-medium', platformTextClass(section.platform)]">{{ getGroupNameById(gid) }}</span>
|
||||||
</label>
|
</label>
|
||||||
</div>
|
</div>
|
||||||
<p v-if="section.group_ids.length === 0" class="mt-1 text-xs text-gray-400">
|
<p v-if="section.group_ids.length === 0" class="mt-1 text-xs text-gray-400">
|
||||||
@@ -481,7 +481,7 @@
|
|||||||
:key="accountId"
|
:key="accountId"
|
||||||
class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20"
|
class="inline-flex items-center gap-1 rounded-md border border-primary-300 bg-primary-50 px-2 py-0.5 text-xs dark:border-primary-700 dark:bg-primary-900/20"
|
||||||
>
|
>
|
||||||
<span>{{ getRuleAccountLabel(accountId) }}</span>
|
<span :class="['font-medium', platformTextClass(section.platform)]">{{ getRuleAccountLabel(accountId) }}</span>
|
||||||
<button type="button" @click="removeRuleAccount(rule, accountId)" class="text-gray-400 hover:text-red-500">
|
<button type="button" @click="removeRuleAccount(rule, accountId)" class="text-gray-400 hover:text-red-500">
|
||||||
<Icon name="x" size="xs" />
|
<Icon name="x" size="xs" />
|
||||||
</button>
|
</button>
|
||||||
@@ -595,7 +595,7 @@ import type { PricingFormEntry } from '@/components/admin/channel/types'
|
|||||||
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types'
|
||||||
import type { AdminGroup, GroupPlatform } from '@/types'
|
import type { AdminGroup, GroupPlatform } from '@/types'
|
||||||
import type { Column } from '@/components/common/types'
|
import type { Column } from '@/components/common/types'
|
||||||
import { platformTextClass } from '@/utils/platformColors'
|
import { platformTextClass, platformBadgeLightClass } from '@/utils/platformColors'
|
||||||
import AppLayout from '@/components/layout/AppLayout.vue'
|
import AppLayout from '@/components/layout/AppLayout.vue'
|
||||||
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
import TablePageLayout from '@/components/layout/TablePageLayout.vue'
|
||||||
import DataTable from '@/components/common/DataTable.vue'
|
import DataTable from '@/components/common/DataTable.vue'
|
||||||
@@ -720,26 +720,6 @@ let abortController: AbortController | null = null
|
|||||||
// ── Platform config ──
|
// ── Platform config ──
|
||||||
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
|
const platformOrder: GroupPlatform[] = ['anthropic', 'openai', 'gemini', 'antigravity']
|
||||||
|
|
||||||
function getPlatformTextColor(platform: string): string {
|
|
||||||
switch (platform) {
|
|
||||||
case 'anthropic': return 'text-orange-600 dark:text-orange-400'
|
|
||||||
case 'openai': return 'text-emerald-600 dark:text-emerald-400'
|
|
||||||
case 'gemini': return 'text-blue-600 dark:text-blue-400'
|
|
||||||
case 'antigravity': return 'text-purple-600 dark:text-purple-400'
|
|
||||||
default: return 'text-gray-600 dark:text-gray-400'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
function getRateBadgeClass(platform: string): string {
|
|
||||||
switch (platform) {
|
|
||||||
case 'anthropic': return 'bg-orange-100 text-orange-700 dark:bg-orange-900/30 dark:text-orange-400'
|
|
||||||
case 'openai': return 'bg-emerald-100 text-emerald-700 dark:bg-emerald-900/30 dark:text-emerald-400'
|
|
||||||
case 'gemini': return 'bg-blue-100 text-blue-700 dark:bg-blue-900/30 dark:text-blue-400'
|
|
||||||
case 'antigravity': return 'bg-purple-100 text-purple-700 dark:bg-purple-900/30 dark:text-purple-400'
|
|
||||||
default: return 'bg-gray-100 text-gray-700 dark:bg-gray-900/30 dark:text-gray-400'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Helpers ──
|
// ── Helpers ──
|
||||||
function formatDate(value: string): string {
|
function formatDate(value: string): string {
|
||||||
if (!value) return '-'
|
if (!value) return '-'
|
||||||
|
|||||||
Reference in New Issue
Block a user