fix: audit fixes for websearch, notifications, and channel pricing
P0: fix wildcard matching test assertion (config order, not longest prefix) P0: add TotalRecharged to auth cache snapshot (v5) for percentage threshold P1: move pricing rules into per-platform sections in ChannelsView P1: populate account name cache when editing existing channel rules P1: sanitize email subject headers to prevent SMTP injection P1: make Redis INCR+EXPIRE idempotent for rate limiting P1: deep copy FeaturesConfig in Channel.Clone() P2: clean up stale email="" placeholder comments P2: replace log.Printf with slog in email_service.go
This commit is contained in:
@@ -1 +1 @@
|
|||||||
0.1.110.11
|
0.1.110.20
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package dto
|
|||||||
import "github.com/Wei-Shaw/sub2api/internal/service"
|
import "github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
|
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
|
||||||
// Email="" is a placeholder for the "primary email" (user's registration email or first admin email).
|
// All emails are user-managed; maximum 3 entries per user.
|
||||||
type NotifyEmailEntry struct {
|
type NotifyEmailEntry struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Disabled bool `json:"disabled"`
|
Disabled bool `json:"disabled"`
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
|
|||||||
|
|
||||||
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
|
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
|
||||||
type ToggleNotifyEmailRequest struct {
|
type ToggleNotifyEmailRequest struct {
|
||||||
Email string `json:"email"` // empty string for primary email placeholder
|
Email string `json:"email" binding:"required,email"`
|
||||||
Disabled bool `json:"disabled"`
|
Disabled bool `json:"disabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package repository
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
@@ -10,10 +11,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
verifyCodeKeyPrefix = "verify_code:"
|
verifyCodeKeyPrefix = "verify_code:"
|
||||||
notifyVerifyKeyPrefix = "notify_verify:"
|
notifyVerifyKeyPrefix = "notify_verify:"
|
||||||
passwordResetKeyPrefix = "password_reset:"
|
passwordResetKeyPrefix = "password_reset:"
|
||||||
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
passwordResetSentAtKeyPrefix = "password_reset_sent:"
|
||||||
|
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
|
||||||
)
|
)
|
||||||
|
|
||||||
// verifyCodeKey generates the Redis key for email verification code.
|
// verifyCodeKey generates the Redis key for email verification code.
|
||||||
@@ -141,3 +143,31 @@ func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) e
|
|||||||
key := notifyVerifyKey(email)
|
key := notifyVerifyKey(email)
|
||||||
return c.rdb.Del(ctx, key).Err()
|
return c.rdb.Del(ctx, key).Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User-level rate limiting for notify email verification codes
|
||||||
|
|
||||||
|
func notifyCodeUserRateKey(userID int64) string {
|
||||||
|
return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
|
||||||
|
key := notifyCodeUserRateKey(userID)
|
||||||
|
count, err := c.rdb.Incr(ctx, key).Result()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
// Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE.
|
||||||
|
if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
|
||||||
|
return count, fmt.Errorf("expire notify code rate key: %w", err)
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
|
||||||
|
key := notifyCodeUserRateKey(userID)
|
||||||
|
count, err := c.rdb.Get(ctx, key).Int64()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return count, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -145,14 +145,14 @@ func TestFindPricingForModel(t *testing.T) {
|
|||||||
wantNil: true,
|
wantNil: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "longer wildcard prefix wins over shorter",
|
name: "wildcard matches by config order (first match wins)",
|
||||||
list: []ChannelModelPricing{
|
list: []ChannelModelPricing{
|
||||||
{ID: 10, Models: []string{"claude-*"}},
|
{ID: 10, Models: []string{"claude-*"}},
|
||||||
{ID: 11, Models: []string{"claude-opus-*"}},
|
{ID: 11, Models: []string{"claude-opus-*"}},
|
||||||
},
|
},
|
||||||
platform: "",
|
platform: "",
|
||||||
model: "claude-opus-4",
|
model: "claude-opus-4",
|
||||||
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars)
|
wantID: 10, // config order: "claude-*" is first and matches, so it wins
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "shorter wildcard used when longer does not match",
|
name: "shorter wildcard used when longer does not match",
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ type APIKeyAuthUserSnapshot struct {
|
|||||||
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
|
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
|
||||||
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
|
||||||
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
|
||||||
|
TotalRecharged float64 `json:"total_recharged"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIKeyAuthGroupSnapshot 分组快照
|
// APIKeyAuthGroupSnapshot 分组快照
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
"github.com/dgraph-io/ristretto"
|
"github.com/dgraph-io/ristretto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const apiKeyAuthSnapshotVersion = 4 // v4: added balance notification fields to UserSnapshot
|
const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
|
||||||
|
|
||||||
type apiKeyAuthCacheConfig struct {
|
type apiKeyAuthCacheConfig struct {
|
||||||
l1Size int
|
l1Size int
|
||||||
@@ -230,6 +230,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
|||||||
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
|
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
|
||||||
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
|
||||||
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
|
||||||
|
TotalRecharged: apiKey.User.TotalRecharged,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if apiKey.Group != nil {
|
if apiKey.Group != nil {
|
||||||
@@ -291,6 +292,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
|
|||||||
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
|
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
|
||||||
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
|
||||||
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
|
||||||
|
TotalRecharged: snapshot.User.TotalRecharged,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if snapshot.Group != nil {
|
if snapshot.Group != nil {
|
||||||
|
|||||||
@@ -309,7 +309,7 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam
|
|||||||
if displayName == "" {
|
if displayName == "" {
|
||||||
displayName = userEmail
|
displayName = userEmail
|
||||||
}
|
}
|
||||||
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName)
|
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
|
||||||
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName))
|
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName))
|
||||||
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
|
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
|
||||||
}
|
}
|
||||||
@@ -321,11 +321,16 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
|
|||||||
dimLabel = dimension
|
dimLabel = dimension
|
||||||
}
|
}
|
||||||
|
|
||||||
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName)
|
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
|
||||||
body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName))
|
body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName))
|
||||||
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
|
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
|
||||||
|
func sanitizeEmailHeader(s string) string {
|
||||||
|
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
|
||||||
|
}
|
||||||
|
|
||||||
// balanceLowEmailTemplate is the HTML template for balance low notifications.
|
// balanceLowEmailTemplate is the HTML template for balance low notifications.
|
||||||
// Format args: siteName, userName, userName, balance, threshold, threshold.
|
// Format args: siteName, userName, userName, balance, threshold, threshold.
|
||||||
const balanceLowEmailTemplate = `<!DOCTYPE html>
|
const balanceLowEmailTemplate = `<!DOCTYPE html>
|
||||||
|
|||||||
@@ -196,6 +196,9 @@ func (c *Channel) Clone() *Channel {
|
|||||||
cp.ModelMapping[platform] = inner
|
cp.ModelMapping[platform] = inner
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if c.FeaturesConfig != nil {
|
||||||
|
cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig)
|
||||||
|
}
|
||||||
if c.AccountStatsPricingRules != nil {
|
if c.AccountStatsPricingRules != nil {
|
||||||
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
|
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
|
||||||
for i, rule := range c.AccountStatsPricingRules {
|
for i, rule := range c.AccountStatsPricingRules {
|
||||||
@@ -219,6 +222,19 @@ func (c *Channel) Clone() *Channel {
|
|||||||
return &cp
|
return &cp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
|
||||||
|
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
|
||||||
|
dst := make(map[string]any, len(src))
|
||||||
|
for k, v := range src {
|
||||||
|
if inner, ok := v.(map[string]any); ok {
|
||||||
|
dst[k] = deepCopyFeaturesConfig(inner)
|
||||||
|
} else {
|
||||||
|
dst[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return dst
|
||||||
|
}
|
||||||
|
|
||||||
// ValidateIntervals 校验区间列表的合法性。
|
// ValidateIntervals 校验区间列表的合法性。
|
||||||
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
|
// 规则:MinTokens >= 0;MaxTokens 若非 nil 则 > 0 且 > MinTokens;
|
||||||
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
|
// 所有价格字段 >= 0;区间按 MinTokens 排序后无重叠((min, max] 语义);
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"math/big"
|
"math/big"
|
||||||
"net/smtp"
|
"net/smtp"
|
||||||
"net/url"
|
"net/url"
|
||||||
@@ -292,7 +292,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
|||||||
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
|
||||||
data.Attempts++
|
data.Attempts++
|
||||||
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||||
log.Printf("[Email] Failed to update verification attempt count: %v", err)
|
slog.Error("failed to update verification attempt count", "email", email, "error", err)
|
||||||
}
|
}
|
||||||
if data.Attempts >= maxVerifyCodeAttempts {
|
if data.Attempts >= maxVerifyCodeAttempts {
|
||||||
return ErrVerifyCodeMaxAttempts
|
return ErrVerifyCodeMaxAttempts
|
||||||
@@ -302,7 +302,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
|
|||||||
|
|
||||||
// 验证成功,删除验证码
|
// 验证成功,删除验证码
|
||||||
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
|
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
|
||||||
log.Printf("[Email] Failed to delete verification code after success: %v", err)
|
slog.Error("failed to delete verification code after success", "email", email, "error", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -452,7 +452,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
|
|||||||
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
|
||||||
// Check email cooldown to prevent email bombing
|
// Check email cooldown to prevent email bombing
|
||||||
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
|
||||||
log.Printf("[Email] Password reset email skipped (cooldown): %s", email)
|
slog.Info("password reset email skipped due to cooldown", "email", email)
|
||||||
return nil // Silent success to prevent revealing cooldown to attackers
|
return nil // Silent success to prevent revealing cooldown to attackers
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -463,7 +463,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
|
|||||||
|
|
||||||
// Set cooldown marker (Redis TTL handles expiration)
|
// Set cooldown marker (Redis TTL handles expiration)
|
||||||
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
|
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
|
||||||
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err)
|
slog.Error("failed to set password reset cooldown", "email", email, "error", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -493,7 +493,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
|
|||||||
|
|
||||||
// Delete after verification (one-time use)
|
// Delete after verification (one-time use)
|
||||||
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
|
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
|
||||||
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err)
|
slog.Error("failed to delete password reset token after consumption", "email", email, "error", err)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
|
// NotifyEmailEntry represents a notification email with enable/disable and verification state.
|
||||||
// Email="" is a placeholder for the "primary email" (user's registration email or first admin email).
|
// All emails are user-managed; maximum 3 entries per user.
|
||||||
type NotifyEmailEntry struct {
|
type NotifyEmailEntry struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Disabled bool `json:"disabled"`
|
Disabled bool `json:"disabled"`
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/subtle"
|
"crypto/subtle"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log"
|
"log/slog"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -13,12 +13,19 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
|
||||||
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
|
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
|
||||||
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
|
||||||
|
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxNotifyEmails = 3 // Total limit: primary (email="") + up to 2 extra
|
const (
|
||||||
|
maxNotifyEmails = 3 // Maximum number of notification emails per user
|
||||||
|
|
||||||
|
// User-level rate limiting for notify email verification codes
|
||||||
|
notifyCodeUserRateLimit = 5
|
||||||
|
notifyCodeUserRateWindow = 10 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
// UserListFilters contains all filter options for listing users
|
// UserListFilters contains all filter options for listing users
|
||||||
type UserListFilters struct {
|
type UserListFilters struct {
|
||||||
@@ -220,7 +227,7 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
|
|||||||
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
|
||||||
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err)
|
slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
@@ -270,21 +277,44 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
|
|||||||
|
|
||||||
// SendNotifyEmailCode sends a verification code to the extra notification email.
|
// SendNotifyEmailCode sends a verification code to the extra notification email.
|
||||||
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
|
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
|
||||||
// Check cooldown
|
if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
code, err := emailService.GenerateVerifyCode()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate code: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment user-level counter after successful save
|
||||||
|
if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil {
|
||||||
|
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s.sendNotifyVerifyEmail(ctx, emailService, email, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
|
||||||
|
func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error {
|
||||||
existing, err := cache.GetNotifyVerifyCode(ctx, email)
|
existing, err := cache.GetNotifyVerifyCode(ctx, email)
|
||||||
if err == nil && existing != nil {
|
if err == nil && existing != nil {
|
||||||
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
if time.Since(existing.CreatedAt) < verifyCodeCooldown {
|
||||||
return ErrVerifyCodeTooFrequent
|
return ErrVerifyCodeTooFrequent
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
count, err := cache.GetNotifyCodeUserRate(ctx, userID)
|
||||||
// Generate code
|
if err == nil && count >= notifyCodeUserRateLimit {
|
||||||
code, err := emailService.GenerateVerifyCode()
|
return ErrNotifyCodeUserRateLimit
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("generate code: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Save to cache
|
// saveNotifyVerifyCode saves the verification code to cache.
|
||||||
|
func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error {
|
||||||
data := &VerificationCodeData{
|
data := &VerificationCodeData{
|
||||||
Code: code,
|
Code: code,
|
||||||
Attempts: 0,
|
Attempts: 0,
|
||||||
@@ -293,16 +323,17 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
|
|||||||
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
|
||||||
return fmt.Errorf("save verify code: %w", err)
|
return fmt.Errorf("save verify code: %w", err)
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Get site name
|
// sendNotifyVerifyEmail builds and sends the verification email.
|
||||||
|
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
|
||||||
siteName := "Sub2API"
|
siteName := "Sub2API"
|
||||||
if s.settingRepo != nil {
|
if s.settingRepo != nil {
|
||||||
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
|
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
|
||||||
siteName = name
|
siteName = name
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build and send email
|
|
||||||
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
|
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
|
||||||
body := buildNotifyVerifyEmailBody(code, siteName)
|
body := buildNotifyVerifyEmailBody(code, siteName)
|
||||||
return emailService.SendEmail(ctx, email, subject, body)
|
return emailService.SendEmail(ctx, email, subject, body)
|
||||||
@@ -310,7 +341,15 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
|
|||||||
|
|
||||||
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
|
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
|
||||||
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
|
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
|
||||||
// Verify code
|
if err := verifyNotifyCode(ctx, cache, email, code); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_ = cache.DeleteNotifyVerifyCode(ctx, email)
|
||||||
|
return s.addOrVerifyNotifyEmail(ctx, userID, email)
|
||||||
|
}
|
||||||
|
|
||||||
|
// verifyNotifyCode validates the verification code against the cached data.
|
||||||
|
func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error {
|
||||||
data, err := cache.GetNotifyVerifyCode(ctx, email)
|
data, err := cache.GetNotifyVerifyCode(ctx, email)
|
||||||
if err != nil || data == nil {
|
if err != nil || data == nil {
|
||||||
return ErrInvalidVerifyCode
|
return ErrInvalidVerifyCode
|
||||||
@@ -326,17 +365,18 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
|
|||||||
}
|
}
|
||||||
return ErrInvalidVerifyCode
|
return ErrInvalidVerifyCode
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete code after verification
|
// addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
|
||||||
_ = cache.DeleteNotifyVerifyCode(ctx, email)
|
// Note: concurrent calls for the same user could race on the read-modify-write of
|
||||||
|
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
|
||||||
// Add to user's extra emails
|
// simultaneously), and the worst case is a duplicate entry which is harmless.
|
||||||
|
func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error {
|
||||||
user, err := s.userRepo.GetByID(ctx, userID)
|
user, err := s.userRepo.GetByID(ctx, userID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if already exists — if unverified, mark as verified
|
|
||||||
for i, e := range user.BalanceNotifyExtraEmails {
|
for i, e := range user.BalanceNotifyExtraEmails {
|
||||||
if strings.EqualFold(e.Email, email) {
|
if strings.EqualFold(e.Email, email) {
|
||||||
if !e.Verified {
|
if !e.Verified {
|
||||||
@@ -346,12 +386,9 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
|
|||||||
return nil // Already verified
|
return nil // Already verified
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check limit
|
|
||||||
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
|
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
|
||||||
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
|
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
|
||||||
}
|
}
|
||||||
|
|
||||||
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
|
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
|
||||||
Email: email,
|
Email: email,
|
||||||
Disabled: false,
|
Disabled: false,
|
||||||
@@ -399,10 +436,9 @@ func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email
|
|||||||
return s.userRepo.Update(ctx, user)
|
return s.userRepo.Update(ctx, user)
|
||||||
}
|
}
|
||||||
|
|
||||||
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
|
// notifyVerifyEmailTemplate is the HTML template for notify email verification.
|
||||||
func buildNotifyVerifyEmailBody(code, siteName string) string {
|
// Format args: siteName, code.
|
||||||
return fmt.Sprintf(`
|
const notifyVerifyEmailTemplate = `<!DOCTYPE html>
|
||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
<html>
|
||||||
<head>
|
<head>
|
||||||
<meta charset="UTF-8">
|
<meta charset="UTF-8">
|
||||||
@@ -439,6 +475,9 @@ func buildNotifyVerifyEmailBody(code, siteName string) string {
|
|||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</body>
|
</body>
|
||||||
</html>
|
</html>`
|
||||||
`, siteName, code)
|
|
||||||
|
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
|
||||||
|
func buildNotifyVerifyEmailBody(code, siteName string) string {
|
||||||
|
return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -421,7 +421,7 @@
|
|||||||
</h4>
|
</h4>
|
||||||
<button
|
<button
|
||||||
type="button"
|
type="button"
|
||||||
@click="addAccountStatsRule()"
|
@click="addAccountStatsRule(sIdx)"
|
||||||
class="rounded-lg border border-primary-300 px-3 py-1 text-xs font-medium text-primary-600 hover:bg-primary-50 dark:border-primary-600 dark:text-primary-400 dark:hover:bg-primary-900/20"
|
class="rounded-lg border border-primary-300 px-3 py-1 text-xs font-medium text-primary-600 hover:bg-primary-50 dark:border-primary-600 dark:text-primary-400 dark:hover:bg-primary-900/20"
|
||||||
>
|
>
|
||||||
+ {{ t('admin.channels.form.addRule') }}
|
+ {{ t('admin.channels.form.addRule') }}
|
||||||
@@ -430,14 +430,14 @@
|
|||||||
|
|
||||||
<!-- Filter rules for this platform's groups -->
|
<!-- Filter rules for this platform's groups -->
|
||||||
<p
|
<p
|
||||||
v-if="form.account_stats_pricing_rules.length === 0"
|
v-if="section.account_stats_pricing_rules.length === 0"
|
||||||
class="text-xs italic text-gray-400 dark:text-gray-500"
|
class="text-xs italic text-gray-400 dark:text-gray-500"
|
||||||
>
|
>
|
||||||
{{ t('admin.channels.form.noRulesConfigured') }}
|
{{ t('admin.channels.form.noRulesConfigured') }}
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<div
|
<div
|
||||||
v-for="(rule, ruleIndex) in form.account_stats_pricing_rules"
|
v-for="(rule, ruleIndex) in section.account_stats_pricing_rules"
|
||||||
:key="ruleIndex"
|
:key="ruleIndex"
|
||||||
class="space-y-3 rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
class="space-y-3 rounded-lg border border-gray-200 p-4 dark:border-dark-600"
|
||||||
>
|
>
|
||||||
@@ -447,7 +447,7 @@
|
|||||||
:placeholder="t('admin.channels.form.ruleName')"
|
:placeholder="t('admin.channels.form.ruleName')"
|
||||||
class="bg-transparent text-sm font-medium text-gray-700 placeholder-gray-400 outline-none dark:text-gray-300"
|
class="bg-transparent text-sm font-medium text-gray-700 placeholder-gray-400 outline-none dark:text-gray-300"
|
||||||
/>
|
/>
|
||||||
<button type="button" @click="removeAccountStatsRule(ruleIndex)" class="text-xs text-red-500 hover:text-red-700">
|
<button type="button" @click="removeAccountStatsRule(sIdx, ruleIndex)" class="text-xs text-red-500 hover:text-red-700">
|
||||||
{{ t('common.delete') }}
|
{{ t('common.delete') }}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
@@ -524,7 +524,7 @@
|
|||||||
<div>
|
<div>
|
||||||
<div class="mb-1 flex items-center justify-between">
|
<div class="mb-1 flex items-center justify-between">
|
||||||
<label class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.channels.form.ruleModelPricing') }}</label>
|
<label class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.channels.form.ruleModelPricing') }}</label>
|
||||||
<button type="button" @click="addRulePricingEntry(ruleIndex)" class="text-xs text-primary-600 hover:text-primary-700">
|
<button type="button" @click="addRulePricingEntry(sIdx, ruleIndex)" class="text-xs text-primary-600 hover:text-primary-700">
|
||||||
+ {{ t('common.add') }}
|
+ {{ t('common.add') }}
|
||||||
</button>
|
</button>
|
||||||
</div>
|
</div>
|
||||||
@@ -538,7 +538,7 @@
|
|||||||
:entry="entry"
|
:entry="entry"
|
||||||
:platform="section.platform"
|
:platform="section.platform"
|
||||||
@update="rule.pricing.splice(pIdx, 1, $event)"
|
@update="rule.pricing.splice(pIdx, 1, $event)"
|
||||||
@remove="removeRulePricingEntry(ruleIndex, pIdx)"
|
@remove="removeRulePricingEntry(sIdx, ruleIndex, pIdx)"
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
@@ -625,6 +625,14 @@ async function loadWebSearchGlobalState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Form-level pricing rule type (per-platform) ──
|
||||||
|
interface FormPricingRule {
|
||||||
|
name: string
|
||||||
|
group_ids: number[]
|
||||||
|
account_ids: number[]
|
||||||
|
pricing: PricingFormEntry[]
|
||||||
|
}
|
||||||
|
|
||||||
// ── Platform Section type ──
|
// ── Platform Section type ──
|
||||||
interface PlatformSection {
|
interface PlatformSection {
|
||||||
platform: GroupPlatform
|
platform: GroupPlatform
|
||||||
@@ -634,6 +642,7 @@ interface PlatformSection {
|
|||||||
model_mapping: Record<string, string>
|
model_mapping: Record<string, string>
|
||||||
model_pricing: PricingFormEntry[]
|
model_pricing: PricingFormEntry[]
|
||||||
web_search_emulation: boolean
|
web_search_emulation: boolean
|
||||||
|
account_stats_pricing_rules: FormPricingRule[]
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Table columns ──
|
// ── Table columns ──
|
||||||
@@ -703,12 +712,6 @@ const form = reactive({
|
|||||||
billing_model_source: 'channel_mapped' as string,
|
billing_model_source: 'channel_mapped' as string,
|
||||||
platforms: [] as PlatformSection[],
|
platforms: [] as PlatformSection[],
|
||||||
apply_pricing_to_account_stats: false,
|
apply_pricing_to_account_stats: false,
|
||||||
account_stats_pricing_rules: [] as Array<{
|
|
||||||
name: string
|
|
||||||
group_ids: number[]
|
|
||||||
account_ids: number[]
|
|
||||||
pricing: PricingFormEntry[]
|
|
||||||
}>
|
|
||||||
})
|
})
|
||||||
|
|
||||||
let abortController: AbortController | null = null
|
let abortController: AbortController | null = null
|
||||||
@@ -754,6 +757,7 @@ function addPlatformSection(platform: GroupPlatform) {
|
|||||||
model_mapping: {},
|
model_mapping: {},
|
||||||
model_pricing: [],
|
model_pricing: [],
|
||||||
web_search_emulation: false,
|
web_search_emulation: false,
|
||||||
|
account_stats_pricing_rules: [],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -867,8 +871,8 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ── Account Stats Pricing helpers ──
|
// ── Account Stats Pricing helpers ──
|
||||||
function addAccountStatsRule() {
|
function addAccountStatsRule(sectionIdx: number) {
|
||||||
form.account_stats_pricing_rules.push({
|
form.platforms[sectionIdx].account_stats_pricing_rules.push({
|
||||||
name: '',
|
name: '',
|
||||||
group_ids: [],
|
group_ids: [],
|
||||||
account_ids: [],
|
account_ids: [],
|
||||||
@@ -876,8 +880,8 @@ function addAccountStatsRule() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
function addRulePricingEntry(ruleIndex: number) {
|
function addRulePricingEntry(sectionIdx: number, ruleIndex: number) {
|
||||||
form.account_stats_pricing_rules[ruleIndex].pricing.push({
|
form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.push({
|
||||||
models: [],
|
models: [],
|
||||||
billing_mode: 'token',
|
billing_mode: 'token',
|
||||||
input_price: null,
|
input_price: null,
|
||||||
@@ -890,15 +894,15 @@ function addRulePricingEntry(ruleIndex: number) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
function removeAccountStatsRule(ruleIndex: number) {
|
function removeAccountStatsRule(sectionIdx: number, ruleIndex: number) {
|
||||||
form.account_stats_pricing_rules.splice(ruleIndex, 1)
|
form.platforms[sectionIdx].account_stats_pricing_rules.splice(ruleIndex, 1)
|
||||||
// Clear all search state since indices shift after removal
|
// Clear all search state since indices shift after removal
|
||||||
ruleAccountSearchRunner.clearAll()
|
ruleAccountSearchRunner.clearAll()
|
||||||
clearAllRuleAccountSearchState()
|
clearAllRuleAccountSearchState()
|
||||||
}
|
}
|
||||||
|
|
||||||
function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) {
|
function removeRulePricingEntry(sectionIdx: number, ruleIndex: number, pricingIndex: number) {
|
||||||
form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1)
|
form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
function getGroupNameById(groupId: number): string {
|
function getGroupNameById(groupId: number): string {
|
||||||
@@ -980,38 +984,33 @@ function clearAllRuleAccountSearchState() {
|
|||||||
showRuleAccountDropdown.value = {}
|
showRuleAccountDropdown.value = {}
|
||||||
}
|
}
|
||||||
|
|
||||||
function inferRulePlatform(groupIds: number[]): string {
|
|
||||||
const platforms = new Set<string>()
|
|
||||||
for (const gid of groupIds) {
|
|
||||||
const group = allGroups.value.find(g => g.id === gid)
|
|
||||||
if (group) platforms.add(group.platform)
|
|
||||||
}
|
|
||||||
return platforms.size === 1 ? [...platforms][0] : ''
|
|
||||||
}
|
|
||||||
|
|
||||||
function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
|
function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
|
||||||
return form.account_stats_pricing_rules.map(rule => {
|
const rules: AccountStatsPricingRule[] = []
|
||||||
const platform = inferRulePlatform(rule.group_ids)
|
for (const section of form.platforms) {
|
||||||
return {
|
if (!section.enabled) continue
|
||||||
name: rule.name,
|
for (const rule of section.account_stats_pricing_rules) {
|
||||||
group_ids: rule.group_ids,
|
rules.push({
|
||||||
account_ids: rule.account_ids,
|
name: rule.name,
|
||||||
pricing: rule.pricing
|
group_ids: rule.group_ids,
|
||||||
.filter(p => p.models.length > 0)
|
account_ids: rule.account_ids,
|
||||||
.map(p => ({
|
pricing: rule.pricing
|
||||||
platform,
|
.filter(p => p.models.length > 0)
|
||||||
models: p.models,
|
.map(p => ({
|
||||||
billing_mode: p.billing_mode,
|
platform: section.platform,
|
||||||
input_price: mTokToPerToken(p.input_price),
|
models: p.models,
|
||||||
output_price: mTokToPerToken(p.output_price),
|
billing_mode: p.billing_mode,
|
||||||
cache_write_price: mTokToPerToken(p.cache_write_price),
|
input_price: mTokToPerToken(p.input_price),
|
||||||
cache_read_price: mTokToPerToken(p.cache_read_price),
|
output_price: mTokToPerToken(p.output_price),
|
||||||
image_output_price: mTokToPerToken(p.image_output_price),
|
cache_write_price: mTokToPerToken(p.cache_write_price),
|
||||||
per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null,
|
cache_read_price: mTokToPerToken(p.cache_read_price),
|
||||||
intervals: formIntervalsToAPI(p.intervals || [])
|
image_output_price: mTokToPerToken(p.image_output_price),
|
||||||
}))
|
per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null,
|
||||||
|
intervals: formIntervalsToAPI(p.intervals || [])
|
||||||
|
}))
|
||||||
|
})
|
||||||
}
|
}
|
||||||
})
|
}
|
||||||
|
return rules
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Form ↔ API conversion ──
|
// ── Form ↔ API conversion ──
|
||||||
@@ -1120,6 +1119,7 @@ function apiToForm(channel: Channel): PlatformSection[] {
|
|||||||
model_mapping: { ...mapping },
|
model_mapping: { ...mapping },
|
||||||
model_pricing: pricing,
|
model_pricing: pricing,
|
||||||
web_search_emulation: webSearchEnabled,
|
web_search_emulation: webSearchEnabled,
|
||||||
|
account_stats_pricing_rules: [],
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1213,7 +1213,6 @@ function resetForm() {
|
|||||||
form.billing_model_source = 'channel_mapped'
|
form.billing_model_source = 'channel_mapped'
|
||||||
form.platforms = []
|
form.platforms = []
|
||||||
form.apply_pricing_to_account_stats = false
|
form.apply_pricing_to_account_stats = false
|
||||||
form.account_stats_pricing_rules = []
|
|
||||||
activeTab.value = 'basic'
|
activeTab.value = 'basic'
|
||||||
ruleAccountSearchRunner.clearAll()
|
ruleAccountSearchRunner.clearAll()
|
||||||
clearAllRuleAccountSearchState()
|
clearAllRuleAccountSearchState()
|
||||||
@@ -1235,28 +1234,91 @@ async function openEditDialog(channel: Channel) {
|
|||||||
form.restrict_models = channel.restrict_models || false
|
form.restrict_models = channel.restrict_models || false
|
||||||
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
|
form.billing_model_source = channel.billing_model_source || 'channel_mapped'
|
||||||
form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false
|
form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false
|
||||||
form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({
|
|
||||||
name: rule.name || '',
|
|
||||||
group_ids: [...(rule.group_ids || [])],
|
|
||||||
account_ids: [...(rule.account_ids || [])],
|
|
||||||
pricing: (rule.pricing || []).map(p => ({
|
|
||||||
models: [...(p.models || [])],
|
|
||||||
billing_mode: p.billing_mode,
|
|
||||||
input_price: perTokenToMTok(p.input_price),
|
|
||||||
output_price: perTokenToMTok(p.output_price),
|
|
||||||
cache_write_price: perTokenToMTok(p.cache_write_price),
|
|
||||||
cache_read_price: perTokenToMTok(p.cache_read_price),
|
|
||||||
image_output_price: perTokenToMTok(p.image_output_price),
|
|
||||||
per_request_price: p.per_request_price,
|
|
||||||
intervals: apiIntervalsToForm(p.intervals || [])
|
|
||||||
} as PricingFormEntry))
|
|
||||||
}))
|
|
||||||
// Must load groups first so apiToForm can map groupID → platform
|
// Must load groups first so apiToForm can map groupID → platform
|
||||||
await Promise.all([loadGroups(), loadAllChannelsForConflict()])
|
await Promise.all([loadGroups(), loadAllChannelsForConflict()])
|
||||||
form.platforms = apiToForm(channel)
|
form.platforms = apiToForm(channel)
|
||||||
|
|
||||||
|
// Distribute channel-level rules into per-platform sections
|
||||||
|
distributeRulesToPlatforms(channel.account_stats_pricing_rules || [])
|
||||||
|
|
||||||
|
// Populate ruleAccountNameCache for existing rule accounts
|
||||||
|
await populateRuleAccountNameCache()
|
||||||
|
|
||||||
showDialog.value = true
|
showDialog.value = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Distribute flat channel-level rules into the matching platform section based on group_ids */
|
||||||
|
function distributeRulesToPlatforms(apiRules: AccountStatsPricingRule[]) {
|
||||||
|
// Build groupID → platform lookup
|
||||||
|
const groupPlatformMap = new Map<number, GroupPlatform>()
|
||||||
|
for (const g of allGroups.value) {
|
||||||
|
groupPlatformMap.set(g.id, g.platform)
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const apiRule of apiRules) {
|
||||||
|
// Infer platform from group_ids
|
||||||
|
const platforms = new Set<GroupPlatform>()
|
||||||
|
for (const gid of apiRule.group_ids || []) {
|
||||||
|
const p = groupPlatformMap.get(gid)
|
||||||
|
if (p) platforms.add(p)
|
||||||
|
}
|
||||||
|
// If pricing has a platform field, use that as fallback
|
||||||
|
if (platforms.size === 0 && apiRule.pricing?.length > 0) {
|
||||||
|
const p = apiRule.pricing[0].platform as GroupPlatform | undefined
|
||||||
|
if (p) platforms.add(p)
|
||||||
|
}
|
||||||
|
const targetPlatform = platforms.size >= 1 ? [...platforms][0] : null
|
||||||
|
if (!targetPlatform) continue
|
||||||
|
|
||||||
|
const section = form.platforms.find(s => s.platform === targetPlatform)
|
||||||
|
if (!section) continue
|
||||||
|
|
||||||
|
const formRule: FormPricingRule = {
|
||||||
|
name: apiRule.name || '',
|
||||||
|
group_ids: [...(apiRule.group_ids || [])],
|
||||||
|
account_ids: [...(apiRule.account_ids || [])],
|
||||||
|
pricing: (apiRule.pricing || []).map(p => ({
|
||||||
|
models: [...(p.models || [])],
|
||||||
|
billing_mode: p.billing_mode,
|
||||||
|
input_price: perTokenToMTok(p.input_price),
|
||||||
|
output_price: perTokenToMTok(p.output_price),
|
||||||
|
cache_write_price: perTokenToMTok(p.cache_write_price),
|
||||||
|
cache_read_price: perTokenToMTok(p.cache_read_price),
|
||||||
|
image_output_price: perTokenToMTok(p.image_output_price),
|
||||||
|
per_request_price: p.per_request_price,
|
||||||
|
intervals: apiIntervalsToForm(p.intervals || [])
|
||||||
|
} as PricingFormEntry))
|
||||||
|
}
|
||||||
|
section.account_stats_pricing_rules.push(formRule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Populate ruleAccountNameCache by fetching account details for all account_ids in rules */
|
||||||
|
async function populateRuleAccountNameCache() {
|
||||||
|
const allAccountIds = new Set<number>()
|
||||||
|
for (const section of form.platforms) {
|
||||||
|
for (const rule of section.account_stats_pricing_rules) {
|
||||||
|
for (const id of rule.account_ids) {
|
||||||
|
allAccountIds.add(id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (allAccountIds.size === 0) return
|
||||||
|
|
||||||
|
// Fetch account details in parallel (batch of individual getById calls)
|
||||||
|
const ids = [...allAccountIds]
|
||||||
|
const results = await Promise.allSettled(
|
||||||
|
ids.map(id => adminAPI.accounts.getById(id))
|
||||||
|
)
|
||||||
|
for (let i = 0; i < ids.length; i++) {
|
||||||
|
const result = results[i]
|
||||||
|
if (result.status === 'fulfilled') {
|
||||||
|
ruleAccountNameCache.value[ids[i]] = result.value.name
|
||||||
|
}
|
||||||
|
// If rejected, the cache won't have the name, so it'll show "#ID" which is acceptable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
function closeDialog() {
|
function closeDialog() {
|
||||||
showDialog.value = false
|
showDialog.value = false
|
||||||
editingChannel.value = null
|
editingChannel.value = null
|
||||||
|
|||||||
Reference in New Issue
Block a user