diff --git a/backend/cmd/server/VERSION b/backend/cmd/server/VERSION index 061bed0b..b2581b19 100644 --- a/backend/cmd/server/VERSION +++ b/backend/cmd/server/VERSION @@ -1 +1 @@ -0.1.110.11 +0.1.110.20 diff --git a/backend/internal/handler/dto/notify_email_entry.go b/backend/internal/handler/dto/notify_email_entry.go index 180f8b25..78641005 100644 --- a/backend/internal/handler/dto/notify_email_entry.go +++ b/backend/internal/handler/dto/notify_email_entry.go @@ -3,7 +3,7 @@ package dto import "github.com/Wei-Shaw/sub2api/internal/service" // NotifyEmailEntry represents a notification email with enable/disable and verification state. -// Email="" is a placeholder for the "primary email" (user's registration email or first admin email). +// All emails are user-managed; maximum 3 entries per user. type NotifyEmailEntry struct { Email string `json:"email"` Disabled bool `json:"disabled"` diff --git a/backend/internal/handler/user_handler.go b/backend/internal/handler/user_handler.go index 9e0a243a..2535ea5e 100644 --- a/backend/internal/handler/user_handler.go +++ b/backend/internal/handler/user_handler.go @@ -217,7 +217,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) { // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state type ToggleNotifyEmailRequest struct { - Email string `json:"email"` // empty string for primary email placeholder + Email string `json:"email" binding:"required,email"` Disabled bool `json:"disabled"` } diff --git a/backend/internal/repository/email_cache.go b/backend/internal/repository/email_cache.go index 63552ab0..ed903e0d 100644 --- a/backend/internal/repository/email_cache.go +++ b/backend/internal/repository/email_cache.go @@ -3,6 +3,7 @@ package repository import ( "context" "encoding/json" + "fmt" "time" "github.com/Wei-Shaw/sub2api/internal/service" @@ -10,10 +11,11 @@ import ( ) const ( - verifyCodeKeyPrefix = "verify_code:" - notifyVerifyKeyPrefix = "notify_verify:" - passwordResetKeyPrefix = "password_reset:" - passwordResetSentAtKeyPrefix = "password_reset_sent:" + verifyCodeKeyPrefix = "verify_code:" + notifyVerifyKeyPrefix = "notify_verify:" + passwordResetKeyPrefix = "password_reset:" + passwordResetSentAtKeyPrefix = "password_reset_sent:" + notifyCodeUserRateKeyPrefix = "notify_code_user_rate:" ) // verifyCodeKey generates the Redis key for email verification code. @@ -141,3 +143,31 @@ func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) e key := notifyVerifyKey(email) return c.rdb.Del(ctx, key).Err() } + +// User-level rate limiting for notify email verification codes + +func notifyCodeUserRateKey(userID int64) string { + return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID) +} + +func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) { + key := notifyCodeUserRateKey(userID) + count, err := c.rdb.Incr(ctx, key).Result() + if err != nil { + return 0, err + } + // Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE. + if err := c.rdb.Expire(ctx, key, window).Err(); err != nil { + return count, fmt.Errorf("expire notify code rate key: %w", err) + } + return count, nil +} + +func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) { + key := notifyCodeUserRateKey(userID) + count, err := c.rdb.Get(ctx, key).Int64() + if err != nil { + return 0, err + } + return count, nil +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go index bf9da978..23409d5e 100644 --- a/backend/internal/service/account_stats_pricing_test.go +++ b/backend/internal/service/account_stats_pricing_test.go @@ -145,14 +145,14 @@ func TestFindPricingForModel(t *testing.T) { wantNil: true, }, { - name: "longer wildcard prefix wins over shorter", + name: "wildcard matches by config order (first match wins)", list: []ChannelModelPricing{ {ID: 10, Models: []string{"claude-*"}}, {ID: 11, Models: []string{"claude-opus-*"}}, }, platform: "", model: "claude-opus-4", - wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars) + wantID: 10, // config order: "claude-*" is first and matches, so it wins }, { name: "shorter wildcard used when longer does not match", diff --git a/backend/internal/service/api_key_auth_cache.go b/backend/internal/service/api_key_auth_cache.go index 60cb6233..b1660ea7 100644 --- a/backend/internal/service/api_key_auth_cache.go +++ b/backend/internal/service/api_key_auth_cache.go @@ -42,6 +42,7 @@ type APIKeyAuthUserSnapshot struct { BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` + TotalRecharged float64 `json:"total_recharged"` } // APIKeyAuthGroupSnapshot 分组快照 diff --git a/backend/internal/service/api_key_auth_cache_impl.go b/backend/internal/service/api_key_auth_cache_impl.go index 711090c2..25c6331a 100644 --- a/backend/internal/service/api_key_auth_cache_impl.go +++ b/backend/internal/service/api_key_auth_cache_impl.go @@ -13,7 +13,7 @@ import ( "github.com/dgraph-io/ristretto" ) -const apiKeyAuthSnapshotVersion = 4 // v4: added balance notification fields to UserSnapshot +const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold type apiKeyAuthCacheConfig struct { l1Size int @@ -230,6 +230,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot { BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType, BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, + TotalRecharged: apiKey.User.TotalRecharged, }, } if apiKey.Group != nil { @@ -291,6 +292,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType, BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, + TotalRecharged: snapshot.User.TotalRecharged, }, } if snapshot.Group != nil { diff --git a/backend/internal/service/balance_notify_service.go b/backend/internal/service/balance_notify_service.go index 1e4d8ff6..14aa6766 100644 --- a/backend/internal/service/balance_notify_service.go +++ b/backend/internal/service/balance_notify_service.go @@ -309,7 +309,7 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam if displayName == "" { displayName = userEmail } - subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName) + subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName)) body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName)) s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance) } @@ -321,11 +321,16 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun dimLabel = dimension } - subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName) + subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName)) body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName)) s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) } +// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection. +func sanitizeEmailHeader(s string) string { + return strings.NewReplacer("\r", "", "\n", "").Replace(s) +} + // balanceLowEmailTemplate is the HTML template for balance low notifications. // Format args: siteName, userName, userName, balance, threshold, threshold. const balanceLowEmailTemplate = ` diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index 3867f2a0..b034fda0 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -196,6 +196,9 @@ func (c *Channel) Clone() *Channel { cp.ModelMapping[platform] = inner } } + if c.FeaturesConfig != nil { + cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig) + } if c.AccountStatsPricingRules != nil { cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules)) for i, rule := range c.AccountStatsPricingRules { @@ -219,6 +222,19 @@ func (c *Channel) Clone() *Channel { return &cp } +// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution. +func deepCopyFeaturesConfig(src map[string]any) map[string]any { + dst := make(map[string]any, len(src)) + for k, v := range src { + if inner, ok := v.(map[string]any); ok { + dst[k] = deepCopyFeaturesConfig(inner) + } else { + dst[k] = v + } + } + return dst +} + // ValidateIntervals 校验区间列表的合法性。 // 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens; // 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义); diff --git a/backend/internal/service/email_service.go b/backend/internal/service/email_service.go index 61090776..3eade83e 100644 --- a/backend/internal/service/email_service.go +++ b/backend/internal/service/email_service.go @@ -7,7 +7,7 @@ import ( "crypto/tls" "encoding/hex" "fmt" - "log" + "log/slog" "math/big" "net/smtp" "net/url" @@ -292,7 +292,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { data.Attempts++ if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { - log.Printf("[Email] Failed to update verification attempt count: %v", err) + slog.Error("failed to update verification attempt count", "email", email, "error", err) } if data.Attempts >= maxVerifyCodeAttempts { return ErrVerifyCodeMaxAttempts @@ -302,7 +302,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error // 验证成功,删除验证码 if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { - log.Printf("[Email] Failed to delete verification code after success: %v", err) + slog.Error("failed to delete verification code after success", "email", email, "error", err) } return nil } @@ -452,7 +452,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { // Check email cooldown to prevent email bombing if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { - log.Printf("[Email] Password reset email skipped (cooldown): %s", email) + slog.Info("password reset email skipped due to cooldown", "email", email) return nil // Silent success to prevent revealing cooldown to attackers } @@ -463,7 +463,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e // Set cooldown marker (Redis TTL handles expiration) if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil { - log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err) + slog.Error("failed to set password reset cooldown", "email", email, "error", err) } return nil @@ -493,7 +493,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok // Delete after verification (one-time use) if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil { - log.Printf("[Email] Failed to delete password reset token after consumption: %v", err) + slog.Error("failed to delete password reset token after consumption", "email", email, "error", err) } return nil } diff --git a/backend/internal/service/notify_email_entry.go b/backend/internal/service/notify_email_entry.go index c0e739f4..d181200b 100644 --- a/backend/internal/service/notify_email_entry.go +++ b/backend/internal/service/notify_email_entry.go @@ -6,7 +6,7 @@ import ( ) // NotifyEmailEntry represents a notification email with enable/disable and verification state. -// Email="" is a placeholder for the "primary email" (user's registration email or first admin email). +// All emails are user-managed; maximum 3 entries per user. type NotifyEmailEntry struct { Email string `json:"email"` Disabled bool `json:"disabled"` diff --git a/backend/internal/service/user_service.go b/backend/internal/service/user_service.go index bcb21c1d..3baee81d 100644 --- a/backend/internal/service/user_service.go +++ b/backend/internal/service/user_service.go @@ -4,7 +4,7 @@ import ( "context" "crypto/subtle" "fmt" - "log" + "log/slog" "strings" "time" @@ -13,12 +13,19 @@ import ( ) var ( - ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") - ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") - ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") + ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") + ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") + ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later") ) -const maxNotifyEmails = 3 // Total limit: primary (email="") + up to 2 extra +const ( + maxNotifyEmails = 3 // Maximum number of notification emails per user + + // User-level rate limiting for notify email verification codes + notifyCodeUserRateLimit = 5 + notifyCodeUserRateWindow = 10 * time.Minute +) // UserListFilters contains all filter options for listing users type UserListFilters struct { @@ -220,7 +227,7 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { - log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) + slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err) } }() } @@ -270,21 +277,44 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error { // SendNotifyEmailCode sends a verification code to the extra notification email. func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error { - // Check cooldown + if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil { + return err + } + + code, err := emailService.GenerateVerifyCode() + if err != nil { + return fmt.Errorf("generate code: %w", err) + } + + if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil { + return err + } + + // Increment user-level counter after successful save + if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil { + slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err) + } + + return s.sendNotifyVerifyEmail(ctx, emailService, email, code) +} + +// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit. +func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error { existing, err := cache.GetNotifyVerifyCode(ctx, email) if err == nil && existing != nil { if time.Since(existing.CreatedAt) < verifyCodeCooldown { return ErrVerifyCodeTooFrequent } } - - // Generate code - code, err := emailService.GenerateVerifyCode() - if err != nil { - return fmt.Errorf("generate code: %w", err) + count, err := cache.GetNotifyCodeUserRate(ctx, userID) + if err == nil && count >= notifyCodeUserRateLimit { + return ErrNotifyCodeUserRateLimit } + return nil +} - // Save to cache +// saveNotifyVerifyCode saves the verification code to cache. +func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error { data := &VerificationCodeData{ Code: code, Attempts: 0, @@ -293,16 +323,17 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { return fmt.Errorf("save verify code: %w", err) } + return nil +} - // Get site name +// sendNotifyVerifyEmail builds and sends the verification email. +func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error { siteName := "Sub2API" if s.settingRepo != nil { if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" { siteName = name } } - - // Build and send email subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName) body := buildNotifyVerifyEmailBody(code, siteName) return emailService.SendEmail(ctx, email, subject, body) @@ -310,7 +341,15 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema // VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails. func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error { - // Verify code + if err := verifyNotifyCode(ctx, cache, email, code); err != nil { + return err + } + _ = cache.DeleteNotifyVerifyCode(ctx, email) + return s.addOrVerifyNotifyEmail(ctx, userID, email) +} + +// verifyNotifyCode validates the verification code against the cached data. +func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error { data, err := cache.GetNotifyVerifyCode(ctx, email) if err != nil || data == nil { return ErrInvalidVerifyCode @@ -326,17 +365,18 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, } return ErrInvalidVerifyCode } + return nil +} - // Delete code after verification - _ = cache.DeleteNotifyVerifyCode(ctx, email) - - // Add to user's extra emails +// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified. +// Note: concurrent calls for the same user could race on the read-modify-write of +// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing +// simultaneously), and the worst case is a duplicate entry which is harmless. +func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error { user, err := s.userRepo.GetByID(ctx, userID) if err != nil { return err } - - // Check if already exists — if unverified, mark as verified for i, e := range user.BalanceNotifyExtraEmails { if strings.EqualFold(e.Email, email) { if !e.Verified { @@ -346,12 +386,9 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, return nil // Already verified } } - - // Check limit if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails { return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails)) } - user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{ Email: email, Disabled: false, @@ -399,10 +436,9 @@ func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email return s.userRepo.Update(ctx, user) } -// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification. -func buildNotifyVerifyEmailBody(code, siteName string) string { - return fmt.Sprintf(` - +// notifyVerifyEmailTemplate is the HTML template for notify email verification. +// Format args: siteName, code. +const notifyVerifyEmailTemplate = ` @@ -439,6 +475,9 @@ func buildNotifyVerifyEmailBody(code, siteName string) string { - -`, siteName, code) +` + +// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification. +func buildNotifyVerifyEmailBody(code, siteName string) string { + return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code) } diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index b714ca30..2ca1141d 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -421,7 +421,7 @@ @@ -524,7 +524,7 @@
-
@@ -538,7 +538,7 @@ :entry="entry" :platform="section.platform" @update="rule.pricing.splice(pIdx, 1, $event)" - @remove="removeRulePricingEntry(ruleIndex, pIdx)" + @remove="removeRulePricingEntry(sIdx, ruleIndex, pIdx)" />
@@ -625,6 +625,14 @@ async function loadWebSearchGlobalState() { } } +// ── Form-level pricing rule type (per-platform) ── +interface FormPricingRule { + name: string + group_ids: number[] + account_ids: number[] + pricing: PricingFormEntry[] +} + // ── Platform Section type ── interface PlatformSection { platform: GroupPlatform @@ -634,6 +642,7 @@ interface PlatformSection { model_mapping: Record model_pricing: PricingFormEntry[] web_search_emulation: boolean + account_stats_pricing_rules: FormPricingRule[] } // ── Table columns ── @@ -703,12 +712,6 @@ const form = reactive({ billing_model_source: 'channel_mapped' as string, platforms: [] as PlatformSection[], apply_pricing_to_account_stats: false, - account_stats_pricing_rules: [] as Array<{ - name: string - group_ids: number[] - account_ids: number[] - pricing: PricingFormEntry[] - }> }) let abortController: AbortController | null = null @@ -754,6 +757,7 @@ function addPlatformSection(platform: GroupPlatform) { model_mapping: {}, model_pricing: [], web_search_emulation: false, + account_stats_pricing_rules: [], }) } @@ -867,8 +871,8 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) { } // ── Account Stats Pricing helpers ── -function addAccountStatsRule() { - form.account_stats_pricing_rules.push({ +function addAccountStatsRule(sectionIdx: number) { + form.platforms[sectionIdx].account_stats_pricing_rules.push({ name: '', group_ids: [], account_ids: [], @@ -876,8 +880,8 @@ function addAccountStatsRule() { }) } -function addRulePricingEntry(ruleIndex: number) { - form.account_stats_pricing_rules[ruleIndex].pricing.push({ +function addRulePricingEntry(sectionIdx: number, ruleIndex: number) { + form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.push({ models: [], billing_mode: 'token', input_price: null, @@ -890,15 +894,15 @@ function addRulePricingEntry(ruleIndex: number) { }) } -function removeAccountStatsRule(ruleIndex: number) { - form.account_stats_pricing_rules.splice(ruleIndex, 1) +function removeAccountStatsRule(sectionIdx: number, ruleIndex: number) { + form.platforms[sectionIdx].account_stats_pricing_rules.splice(ruleIndex, 1) // Clear all search state since indices shift after removal ruleAccountSearchRunner.clearAll() clearAllRuleAccountSearchState() } -function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) { - form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1) +function removeRulePricingEntry(sectionIdx: number, ruleIndex: number, pricingIndex: number) { + form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1) } function getGroupNameById(groupId: number): string { @@ -980,38 +984,33 @@ function clearAllRuleAccountSearchState() { showRuleAccountDropdown.value = {} } -function inferRulePlatform(groupIds: number[]): string { - const platforms = new Set() - for (const gid of groupIds) { - const group = allGroups.value.find(g => g.id === gid) - if (group) platforms.add(group.platform) - } - return platforms.size === 1 ? [...platforms][0] : '' -} - function accountStatsRulesToAPI(): AccountStatsPricingRule[] { - return form.account_stats_pricing_rules.map(rule => { - const platform = inferRulePlatform(rule.group_ids) - return { - name: rule.name, - group_ids: rule.group_ids, - account_ids: rule.account_ids, - pricing: rule.pricing - .filter(p => p.models.length > 0) - .map(p => ({ - platform, - models: p.models, - billing_mode: p.billing_mode, - input_price: mTokToPerToken(p.input_price), - output_price: mTokToPerToken(p.output_price), - cache_write_price: mTokToPerToken(p.cache_write_price), - cache_read_price: mTokToPerToken(p.cache_read_price), - image_output_price: mTokToPerToken(p.image_output_price), - per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, - intervals: formIntervalsToAPI(p.intervals || []) - })) + const rules: AccountStatsPricingRule[] = [] + for (const section of form.platforms) { + if (!section.enabled) continue + for (const rule of section.account_stats_pricing_rules) { + rules.push({ + name: rule.name, + group_ids: rule.group_ids, + account_ids: rule.account_ids, + pricing: rule.pricing + .filter(p => p.models.length > 0) + .map(p => ({ + platform: section.platform, + models: p.models, + billing_mode: p.billing_mode, + input_price: mTokToPerToken(p.input_price), + output_price: mTokToPerToken(p.output_price), + cache_write_price: mTokToPerToken(p.cache_write_price), + cache_read_price: mTokToPerToken(p.cache_read_price), + image_output_price: mTokToPerToken(p.image_output_price), + per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, + intervals: formIntervalsToAPI(p.intervals || []) + })) + }) } - }) + } + return rules } // ── Form ↔ API conversion ── @@ -1120,6 +1119,7 @@ function apiToForm(channel: Channel): PlatformSection[] { model_mapping: { ...mapping }, model_pricing: pricing, web_search_emulation: webSearchEnabled, + account_stats_pricing_rules: [], }) } @@ -1213,7 +1213,6 @@ function resetForm() { form.billing_model_source = 'channel_mapped' form.platforms = [] form.apply_pricing_to_account_stats = false - form.account_stats_pricing_rules = [] activeTab.value = 'basic' ruleAccountSearchRunner.clearAll() clearAllRuleAccountSearchState() @@ -1235,28 +1234,91 @@ async function openEditDialog(channel: Channel) { form.restrict_models = channel.restrict_models || false form.billing_model_source = channel.billing_model_source || 'channel_mapped' form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false - form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({ - name: rule.name || '', - group_ids: [...(rule.group_ids || [])], - account_ids: [...(rule.account_ids || [])], - pricing: (rule.pricing || []).map(p => ({ - models: [...(p.models || [])], - billing_mode: p.billing_mode, - input_price: perTokenToMTok(p.input_price), - output_price: perTokenToMTok(p.output_price), - cache_write_price: perTokenToMTok(p.cache_write_price), - cache_read_price: perTokenToMTok(p.cache_read_price), - image_output_price: perTokenToMTok(p.image_output_price), - per_request_price: p.per_request_price, - intervals: apiIntervalsToForm(p.intervals || []) - } as PricingFormEntry)) - })) // Must load groups first so apiToForm can map groupID → platform await Promise.all([loadGroups(), loadAllChannelsForConflict()]) form.platforms = apiToForm(channel) + + // Distribute channel-level rules into per-platform sections + distributeRulesToPlatforms(channel.account_stats_pricing_rules || []) + + // Populate ruleAccountNameCache for existing rule accounts + await populateRuleAccountNameCache() + showDialog.value = true } +/** Distribute flat channel-level rules into the matching platform section based on group_ids */ +function distributeRulesToPlatforms(apiRules: AccountStatsPricingRule[]) { + // Build groupID → platform lookup + const groupPlatformMap = new Map() + for (const g of allGroups.value) { + groupPlatformMap.set(g.id, g.platform) + } + + for (const apiRule of apiRules) { + // Infer platform from group_ids + const platforms = new Set() + for (const gid of apiRule.group_ids || []) { + const p = groupPlatformMap.get(gid) + if (p) platforms.add(p) + } + // If pricing has a platform field, use that as fallback + if (platforms.size === 0 && apiRule.pricing?.length > 0) { + const p = apiRule.pricing[0].platform as GroupPlatform | undefined + if (p) platforms.add(p) + } + const targetPlatform = platforms.size >= 1 ? [...platforms][0] : null + if (!targetPlatform) continue + + const section = form.platforms.find(s => s.platform === targetPlatform) + if (!section) continue + + const formRule: FormPricingRule = { + name: apiRule.name || '', + group_ids: [...(apiRule.group_ids || [])], + account_ids: [...(apiRule.account_ids || [])], + pricing: (apiRule.pricing || []).map(p => ({ + models: [...(p.models || [])], + billing_mode: p.billing_mode, + input_price: perTokenToMTok(p.input_price), + output_price: perTokenToMTok(p.output_price), + cache_write_price: perTokenToMTok(p.cache_write_price), + cache_read_price: perTokenToMTok(p.cache_read_price), + image_output_price: perTokenToMTok(p.image_output_price), + per_request_price: p.per_request_price, + intervals: apiIntervalsToForm(p.intervals || []) + } as PricingFormEntry)) + } + section.account_stats_pricing_rules.push(formRule) + } +} + +/** Populate ruleAccountNameCache by fetching account details for all account_ids in rules */ +async function populateRuleAccountNameCache() { + const allAccountIds = new Set() + for (const section of form.platforms) { + for (const rule of section.account_stats_pricing_rules) { + for (const id of rule.account_ids) { + allAccountIds.add(id) + } + } + } + if (allAccountIds.size === 0) return + + // Fetch account details in parallel (batch of individual getById calls) + const ids = [...allAccountIds] + const results = await Promise.allSettled( + ids.map(id => adminAPI.accounts.getById(id)) + ) + for (let i = 0; i < ids.length; i++) { + const result = results[i] + if (result.status === 'fulfilled') { + ruleAccountNameCache.value[ids[i]] = result.value.name + } + // If rejected, the cache won't have the name, so it'll show "#ID" which is acceptable + } +} + function closeDialog() { showDialog.value = false editingChannel.value = null