fix: audit fixes for websearch, notifications, and channel pricing

P0: fix wildcard matching test assertion (config order, not longest prefix)
P0: add TotalRecharged to auth cache snapshot (v5) for percentage threshold
P1: move pricing rules into per-platform sections in ChannelsView
P1: populate account name cache when editing existing channel rules
P1: sanitize email subject headers to prevent SMTP injection
P1: make Redis INCR+EXPIRE idempotent for rate limiting
P1: deep copy FeaturesConfig in Channel.Clone()
P2: clean up stale email="" placeholder comments
P2: replace log.Printf with slog in email_service.go
This commit is contained in:
erio
2026-04-13 13:59:35 +08:00
parent a68df457d8
commit b7fb2e4387
13 changed files with 273 additions and 118 deletions

View File

@@ -1 +1 @@
0.1.110.11 0.1.110.20

View File

@@ -3,7 +3,7 @@ package dto
import "github.com/Wei-Shaw/sub2api/internal/service" import "github.com/Wei-Shaw/sub2api/internal/service"
// NotifyEmailEntry represents a notification email with enable/disable and verification state. // NotifyEmailEntry represents a notification email with enable/disable and verification state.
// Email="" is a placeholder for the "primary email" (user's registration email or first admin email). // All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct { type NotifyEmailEntry struct {
Email string `json:"email"` Email string `json:"email"`
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`

View File

@@ -217,7 +217,7 @@ func (h *UserHandler) RemoveNotifyEmail(c *gin.Context) {
// ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state // ToggleNotifyEmailRequest represents the request to toggle a notify email's disabled state
type ToggleNotifyEmailRequest struct { type ToggleNotifyEmailRequest struct {
Email string `json:"email"` // empty string for primary email placeholder Email string `json:"email" binding:"required,email"`
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`
} }

View File

@@ -3,6 +3,7 @@ package repository
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"fmt"
"time" "time"
"github.com/Wei-Shaw/sub2api/internal/service" "github.com/Wei-Shaw/sub2api/internal/service"
@@ -10,10 +11,11 @@ import (
) )
const ( const (
verifyCodeKeyPrefix = "verify_code:" verifyCodeKeyPrefix = "verify_code:"
notifyVerifyKeyPrefix = "notify_verify:" notifyVerifyKeyPrefix = "notify_verify:"
passwordResetKeyPrefix = "password_reset:" passwordResetKeyPrefix = "password_reset:"
passwordResetSentAtKeyPrefix = "password_reset_sent:" passwordResetSentAtKeyPrefix = "password_reset_sent:"
notifyCodeUserRateKeyPrefix = "notify_code_user_rate:"
) )
// verifyCodeKey generates the Redis key for email verification code. // verifyCodeKey generates the Redis key for email verification code.
@@ -141,3 +143,31 @@ func (c *emailCache) DeleteNotifyVerifyCode(ctx context.Context, email string) e
key := notifyVerifyKey(email) key := notifyVerifyKey(email)
return c.rdb.Del(ctx, key).Err() return c.rdb.Del(ctx, key).Err()
} }
// User-level rate limiting for notify email verification codes
func notifyCodeUserRateKey(userID int64) string {
return notifyCodeUserRateKeyPrefix + fmt.Sprintf("%d", userID)
}
func (c *emailCache) IncrNotifyCodeUserRate(ctx context.Context, userID int64, window time.Duration) (int64, error) {
key := notifyCodeUserRateKey(userID)
count, err := c.rdb.Incr(ctx, key).Result()
if err != nil {
return 0, err
}
// Always set TTL (idempotent) to avoid orphan keys if process crashes between INCR and EXPIRE.
if err := c.rdb.Expire(ctx, key, window).Err(); err != nil {
return count, fmt.Errorf("expire notify code rate key: %w", err)
}
return count, nil
}
func (c *emailCache) GetNotifyCodeUserRate(ctx context.Context, userID int64) (int64, error) {
key := notifyCodeUserRateKey(userID)
count, err := c.rdb.Get(ctx, key).Int64()
if err != nil {
return 0, err
}
return count, nil
}

View File

@@ -145,14 +145,14 @@ func TestFindPricingForModel(t *testing.T) {
wantNil: true, wantNil: true,
}, },
{ {
name: "longer wildcard prefix wins over shorter", name: "wildcard matches by config order (first match wins)",
list: []ChannelModelPricing{ list: []ChannelModelPricing{
{ID: 10, Models: []string{"claude-*"}}, {ID: 10, Models: []string{"claude-*"}},
{ID: 11, Models: []string{"claude-opus-*"}}, {ID: 11, Models: []string{"claude-opus-*"}},
}, },
platform: "", platform: "",
model: "claude-opus-4", model: "claude-opus-4",
wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars) wantID: 10, // config order: "claude-*" is first and matches, so it wins
}, },
{ {
name: "shorter wildcard used when longer does not match", name: "shorter wildcard used when longer does not match",

View File

@@ -42,6 +42,7 @@ type APIKeyAuthUserSnapshot struct {
BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"` BalanceNotifyThresholdType string `json:"balance_notify_threshold_type"`
BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"` BalanceNotifyThreshold *float64 `json:"balance_notify_threshold,omitempty"`
BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"` BalanceNotifyExtraEmails []NotifyEmailEntry `json:"balance_notify_extra_emails,omitempty"`
TotalRecharged float64 `json:"total_recharged"`
} }
// APIKeyAuthGroupSnapshot 分组快照 // APIKeyAuthGroupSnapshot 分组快照

View File

@@ -13,7 +13,7 @@ import (
"github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto"
) )
const apiKeyAuthSnapshotVersion = 4 // v4: added balance notification fields to UserSnapshot const apiKeyAuthSnapshotVersion = 5 // v5: added TotalRecharged for percentage threshold
type apiKeyAuthCacheConfig struct { type apiKeyAuthCacheConfig struct {
l1Size int l1Size int
@@ -230,6 +230,7 @@ func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType, BalanceNotifyThresholdType: apiKey.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold, BalanceNotifyThreshold: apiKey.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails, BalanceNotifyExtraEmails: apiKey.User.BalanceNotifyExtraEmails,
TotalRecharged: apiKey.User.TotalRecharged,
}, },
} }
if apiKey.Group != nil { if apiKey.Group != nil {
@@ -291,6 +292,7 @@ func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapsho
BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType, BalanceNotifyThresholdType: snapshot.User.BalanceNotifyThresholdType,
BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold, BalanceNotifyThreshold: snapshot.User.BalanceNotifyThreshold,
BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails, BalanceNotifyExtraEmails: snapshot.User.BalanceNotifyExtraEmails,
TotalRecharged: snapshot.User.TotalRecharged,
}, },
} }
if snapshot.Group != nil { if snapshot.Group != nil {

View File

@@ -309,7 +309,7 @@ func (s *BalanceNotifyService) sendBalanceLowEmails(recipients []string, userNam
if displayName == "" { if displayName == "" {
displayName = userEmail displayName = userEmail
} }
subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", siteName) subject := fmt.Sprintf("[%s] 余额不足提醒 / Balance Low Alert", sanitizeEmailHeader(siteName))
body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName)) body := s.buildBalanceLowEmailBody(html.EscapeString(displayName), balance, threshold, html.EscapeString(siteName))
s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance) s.sendEmails(recipients, subject, body, "user_email", userEmail, "balance", balance)
} }
@@ -321,11 +321,16 @@ func (s *BalanceNotifyService) sendQuotaAlertEmails(adminEmails []string, accoun
dimLabel = dimension dimLabel = dimension
} }
subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", siteName, accountName) subject := fmt.Sprintf("[%s] 账号限额告警 / Account Quota Alert - %s", sanitizeEmailHeader(siteName), sanitizeEmailHeader(accountName))
body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName)) body := s.buildQuotaAlertEmailBody(html.EscapeString(accountName), html.EscapeString(dimLabel), used, limit, threshold, html.EscapeString(siteName))
s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension) s.sendEmails(adminEmails, subject, body, "account", accountName, "dimension", dimension)
} }
// sanitizeEmailHeader removes CR/LF characters to prevent SMTP header injection.
func sanitizeEmailHeader(s string) string {
return strings.NewReplacer("\r", "", "\n", "").Replace(s)
}
// balanceLowEmailTemplate is the HTML template for balance low notifications. // balanceLowEmailTemplate is the HTML template for balance low notifications.
// Format args: siteName, userName, userName, balance, threshold, threshold. // Format args: siteName, userName, userName, balance, threshold, threshold.
const balanceLowEmailTemplate = `<!DOCTYPE html> const balanceLowEmailTemplate = `<!DOCTYPE html>

View File

@@ -196,6 +196,9 @@ func (c *Channel) Clone() *Channel {
cp.ModelMapping[platform] = inner cp.ModelMapping[platform] = inner
} }
} }
if c.FeaturesConfig != nil {
cp.FeaturesConfig = deepCopyFeaturesConfig(c.FeaturesConfig)
}
if c.AccountStatsPricingRules != nil { if c.AccountStatsPricingRules != nil {
cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules)) cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules))
for i, rule := range c.AccountStatsPricingRules { for i, rule := range c.AccountStatsPricingRules {
@@ -219,6 +222,19 @@ func (c *Channel) Clone() *Channel {
return &cp return &cp
} }
// deepCopyFeaturesConfig creates a deep copy of FeaturesConfig to prevent cache pollution.
func deepCopyFeaturesConfig(src map[string]any) map[string]any {
dst := make(map[string]any, len(src))
for k, v := range src {
if inner, ok := v.(map[string]any); ok {
dst[k] = deepCopyFeaturesConfig(inner)
} else {
dst[k] = v
}
}
return dst
}
// ValidateIntervals 校验区间列表的合法性。 // ValidateIntervals 校验区间列表的合法性。
// 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens // 规则MinTokens >= 0MaxTokens 若非 nil 则 > 0 且 > MinTokens
// 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义); // 所有价格字段 >= 0区间按 MinTokens 排序后无重叠((min, max] 语义);

View File

@@ -7,7 +7,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"log" "log/slog"
"math/big" "math/big"
"net/smtp" "net/smtp"
"net/url" "net/url"
@@ -292,7 +292,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 { if subtle.ConstantTimeCompare([]byte(data.Code), []byte(code)) != 1 {
data.Attempts++ data.Attempts++
if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil { if err := s.cache.SetVerificationCode(ctx, email, data, verifyCodeTTL); err != nil {
log.Printf("[Email] Failed to update verification attempt count: %v", err) slog.Error("failed to update verification attempt count", "email", email, "error", err)
} }
if data.Attempts >= maxVerifyCodeAttempts { if data.Attempts >= maxVerifyCodeAttempts {
return ErrVerifyCodeMaxAttempts return ErrVerifyCodeMaxAttempts
@@ -302,7 +302,7 @@ func (s *EmailService) VerifyCode(ctx context.Context, email, code string) error
// 验证成功,删除验证码 // 验证成功,删除验证码
if err := s.cache.DeleteVerificationCode(ctx, email); err != nil { if err := s.cache.DeleteVerificationCode(ctx, email); err != nil {
log.Printf("[Email] Failed to delete verification code after success: %v", err) slog.Error("failed to delete verification code after success", "email", email, "error", err)
} }
return nil return nil
} }
@@ -452,7 +452,7 @@ func (s *EmailService) SendPasswordResetEmail(ctx context.Context, email, siteNa
func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error { func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, email, siteName, resetURL string) error {
// Check email cooldown to prevent email bombing // Check email cooldown to prevent email bombing
if s.cache.IsPasswordResetEmailInCooldown(ctx, email) { if s.cache.IsPasswordResetEmailInCooldown(ctx, email) {
log.Printf("[Email] Password reset email skipped (cooldown): %s", email) slog.Info("password reset email skipped due to cooldown", "email", email)
return nil // Silent success to prevent revealing cooldown to attackers return nil // Silent success to prevent revealing cooldown to attackers
} }
@@ -463,7 +463,7 @@ func (s *EmailService) SendPasswordResetEmailWithCooldown(ctx context.Context, e
// Set cooldown marker (Redis TTL handles expiration) // Set cooldown marker (Redis TTL handles expiration)
if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil { if err := s.cache.SetPasswordResetEmailCooldown(ctx, email, passwordResetEmailCooldown); err != nil {
log.Printf("[Email] Failed to set password reset cooldown for %s: %v", email, err) slog.Error("failed to set password reset cooldown", "email", email, "error", err)
} }
return nil return nil
@@ -493,7 +493,7 @@ func (s *EmailService) ConsumePasswordResetToken(ctx context.Context, email, tok
// Delete after verification (one-time use) // Delete after verification (one-time use)
if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil { if err := s.cache.DeletePasswordResetToken(ctx, email); err != nil {
log.Printf("[Email] Failed to delete password reset token after consumption: %v", err) slog.Error("failed to delete password reset token after consumption", "email", email, "error", err)
} }
return nil return nil
} }

View File

@@ -6,7 +6,7 @@ import (
) )
// NotifyEmailEntry represents a notification email with enable/disable and verification state. // NotifyEmailEntry represents a notification email with enable/disable and verification state.
// Email="" is a placeholder for the "primary email" (user's registration email or first admin email). // All emails are user-managed; maximum 3 entries per user.
type NotifyEmailEntry struct { type NotifyEmailEntry struct {
Email string `json:"email"` Email string `json:"email"`
Disabled bool `json:"disabled"` Disabled bool `json:"disabled"`

View File

@@ -4,7 +4,7 @@ import (
"context" "context"
"crypto/subtle" "crypto/subtle"
"fmt" "fmt"
"log" "log/slog"
"strings" "strings"
"time" "time"
@@ -13,12 +13,19 @@ import (
) )
var ( var (
ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found") ErrUserNotFound = infraerrors.NotFound("USER_NOT_FOUND", "user not found")
ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect") ErrPasswordIncorrect = infraerrors.BadRequest("PASSWORD_INCORRECT", "current password is incorrect")
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions") ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
ErrNotifyCodeUserRateLimit = infraerrors.TooManyRequests("NOTIFY_CODE_USER_RATE_LIMIT", "too many verification codes requested, please try again later")
) )
const maxNotifyEmails = 3 // Total limit: primary (email="") + up to 2 extra const (
maxNotifyEmails = 3 // Maximum number of notification emails per user
// User-level rate limiting for notify email verification codes
notifyCodeUserRateLimit = 5
notifyCodeUserRateWindow = 10 * time.Minute
)
// UserListFilters contains all filter options for listing users // UserListFilters contains all filter options for listing users
type UserListFilters struct { type UserListFilters struct {
@@ -220,7 +227,7 @@ func (s *UserService) UpdateBalance(ctx context.Context, userID int64, amount fl
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil { if err := s.billingCache.InvalidateUserBalance(cacheCtx, userID); err != nil {
log.Printf("invalidate user balance cache failed: user_id=%d err=%v", userID, err) slog.Error("invalidate user balance cache failed", "user_id", userID, "error", err)
} }
}() }()
} }
@@ -270,21 +277,44 @@ func (s *UserService) Delete(ctx context.Context, userID int64) error {
// SendNotifyEmailCode sends a verification code to the extra notification email. // SendNotifyEmailCode sends a verification code to the extra notification email.
func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error { func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, email string, emailService *EmailService, cache EmailCache) error {
// Check cooldown if err := checkNotifyCodeRateLimit(ctx, cache, userID, email); err != nil {
return err
}
code, err := emailService.GenerateVerifyCode()
if err != nil {
return fmt.Errorf("generate code: %w", err)
}
if err := saveNotifyVerifyCode(ctx, cache, email, code); err != nil {
return err
}
// Increment user-level counter after successful save
if _, err := cache.IncrNotifyCodeUserRate(ctx, userID, notifyCodeUserRateWindow); err != nil {
slog.Error("failed to increment notify code user rate", "user_id", userID, "error", err)
}
return s.sendNotifyVerifyEmail(ctx, emailService, email, code)
}
// checkNotifyCodeRateLimit checks both email cooldown and user-level rate limit.
func checkNotifyCodeRateLimit(ctx context.Context, cache EmailCache, userID int64, email string) error {
existing, err := cache.GetNotifyVerifyCode(ctx, email) existing, err := cache.GetNotifyVerifyCode(ctx, email)
if err == nil && existing != nil { if err == nil && existing != nil {
if time.Since(existing.CreatedAt) < verifyCodeCooldown { if time.Since(existing.CreatedAt) < verifyCodeCooldown {
return ErrVerifyCodeTooFrequent return ErrVerifyCodeTooFrequent
} }
} }
count, err := cache.GetNotifyCodeUserRate(ctx, userID)
// Generate code if err == nil && count >= notifyCodeUserRateLimit {
code, err := emailService.GenerateVerifyCode() return ErrNotifyCodeUserRateLimit
if err != nil {
return fmt.Errorf("generate code: %w", err)
} }
return nil
}
// Save to cache // saveNotifyVerifyCode saves the verification code to cache.
func saveNotifyVerifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data := &VerificationCodeData{ data := &VerificationCodeData{
Code: code, Code: code,
Attempts: 0, Attempts: 0,
@@ -293,16 +323,17 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil { if err := cache.SetNotifyVerifyCode(ctx, email, data, verifyCodeTTL); err != nil {
return fmt.Errorf("save verify code: %w", err) return fmt.Errorf("save verify code: %w", err)
} }
return nil
}
// Get site name // sendNotifyVerifyEmail builds and sends the verification email.
func (s *UserService) sendNotifyVerifyEmail(ctx context.Context, emailService *EmailService, email, code string) error {
siteName := "Sub2API" siteName := "Sub2API"
if s.settingRepo != nil { if s.settingRepo != nil {
if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" { if name, err := s.settingRepo.GetValue(ctx, SettingKeySiteName); err == nil && name != "" {
siteName = name siteName = name
} }
} }
// Build and send email
subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName) subject := fmt.Sprintf("[%s] 通知邮箱验证码 / Notification Email Verification", siteName)
body := buildNotifyVerifyEmailBody(code, siteName) body := buildNotifyVerifyEmailBody(code, siteName)
return emailService.SendEmail(ctx, email, subject, body) return emailService.SendEmail(ctx, email, subject, body)
@@ -310,7 +341,15 @@ func (s *UserService) SendNotifyEmailCode(ctx context.Context, userID int64, ema
// VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails. // VerifyAndAddNotifyEmail verifies the code and adds the email to user's extra emails.
func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error { func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64, email, code string, cache EmailCache) error {
// Verify code if err := verifyNotifyCode(ctx, cache, email, code); err != nil {
return err
}
_ = cache.DeleteNotifyVerifyCode(ctx, email)
return s.addOrVerifyNotifyEmail(ctx, userID, email)
}
// verifyNotifyCode validates the verification code against the cached data.
func verifyNotifyCode(ctx context.Context, cache EmailCache, email, code string) error {
data, err := cache.GetNotifyVerifyCode(ctx, email) data, err := cache.GetNotifyVerifyCode(ctx, email)
if err != nil || data == nil { if err != nil || data == nil {
return ErrInvalidVerifyCode return ErrInvalidVerifyCode
@@ -326,17 +365,18 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
} }
return ErrInvalidVerifyCode return ErrInvalidVerifyCode
} }
return nil
}
// Delete code after verification // addOrVerifyNotifyEmail adds the email to user's extra notification emails or marks it as verified.
_ = cache.DeleteNotifyVerifyCode(ctx, email) // Note: concurrent calls for the same user could race on the read-modify-write of
// BalanceNotifyExtraEmails. The window is small (requires two verify flows completing
// Add to user's extra emails // simultaneously), and the worst case is a duplicate entry which is harmless.
func (s *UserService) addOrVerifyNotifyEmail(ctx context.Context, userID int64, email string) error {
user, err := s.userRepo.GetByID(ctx, userID) user, err := s.userRepo.GetByID(ctx, userID)
if err != nil { if err != nil {
return err return err
} }
// Check if already exists — if unverified, mark as verified
for i, e := range user.BalanceNotifyExtraEmails { for i, e := range user.BalanceNotifyExtraEmails {
if strings.EqualFold(e.Email, email) { if strings.EqualFold(e.Email, email) {
if !e.Verified { if !e.Verified {
@@ -346,12 +386,9 @@ func (s *UserService) VerifyAndAddNotifyEmail(ctx context.Context, userID int64,
return nil // Already verified return nil // Already verified
} }
} }
// Check limit
if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails { if len(user.BalanceNotifyExtraEmails) >= maxNotifyEmails {
return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails)) return infraerrors.BadRequest("TOO_MANY_NOTIFY_EMAILS", fmt.Sprintf("maximum %d notification emails allowed", maxNotifyEmails))
} }
user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{ user.BalanceNotifyExtraEmails = append(user.BalanceNotifyExtraEmails, NotifyEmailEntry{
Email: email, Email: email,
Disabled: false, Disabled: false,
@@ -399,10 +436,9 @@ func (s *UserService) ToggleNotifyEmail(ctx context.Context, userID int64, email
return s.userRepo.Update(ctx, user) return s.userRepo.Update(ctx, user)
} }
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification. // notifyVerifyEmailTemplate is the HTML template for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string { // Format args: siteName, code.
return fmt.Sprintf(` const notifyVerifyEmailTemplate = `<!DOCTYPE html>
<!DOCTYPE html>
<html> <html>
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
@@ -439,6 +475,9 @@ func buildNotifyVerifyEmailBody(code, siteName string) string {
</div> </div>
</div> </div>
</body> </body>
</html> </html>`
`, siteName, code)
// buildNotifyVerifyEmailBody builds the HTML email body for notify email verification.
func buildNotifyVerifyEmailBody(code, siteName string) string {
return fmt.Sprintf(notifyVerifyEmailTemplate, siteName, code)
} }

View File

@@ -421,7 +421,7 @@
</h4> </h4>
<button <button
type="button" type="button"
@click="addAccountStatsRule()" @click="addAccountStatsRule(sIdx)"
class="rounded-lg border border-primary-300 px-3 py-1 text-xs font-medium text-primary-600 hover:bg-primary-50 dark:border-primary-600 dark:text-primary-400 dark:hover:bg-primary-900/20" class="rounded-lg border border-primary-300 px-3 py-1 text-xs font-medium text-primary-600 hover:bg-primary-50 dark:border-primary-600 dark:text-primary-400 dark:hover:bg-primary-900/20"
> >
+ {{ t('admin.channels.form.addRule') }} + {{ t('admin.channels.form.addRule') }}
@@ -430,14 +430,14 @@
<!-- Filter rules for this platform's groups --> <!-- Filter rules for this platform's groups -->
<p <p
v-if="form.account_stats_pricing_rules.length === 0" v-if="section.account_stats_pricing_rules.length === 0"
class="text-xs italic text-gray-400 dark:text-gray-500" class="text-xs italic text-gray-400 dark:text-gray-500"
> >
{{ t('admin.channels.form.noRulesConfigured') }} {{ t('admin.channels.form.noRulesConfigured') }}
</p> </p>
<div <div
v-for="(rule, ruleIndex) in form.account_stats_pricing_rules" v-for="(rule, ruleIndex) in section.account_stats_pricing_rules"
:key="ruleIndex" :key="ruleIndex"
class="space-y-3 rounded-lg border border-gray-200 p-4 dark:border-dark-600" class="space-y-3 rounded-lg border border-gray-200 p-4 dark:border-dark-600"
> >
@@ -447,7 +447,7 @@
:placeholder="t('admin.channels.form.ruleName')" :placeholder="t('admin.channels.form.ruleName')"
class="bg-transparent text-sm font-medium text-gray-700 placeholder-gray-400 outline-none dark:text-gray-300" class="bg-transparent text-sm font-medium text-gray-700 placeholder-gray-400 outline-none dark:text-gray-300"
/> />
<button type="button" @click="removeAccountStatsRule(ruleIndex)" class="text-xs text-red-500 hover:text-red-700"> <button type="button" @click="removeAccountStatsRule(sIdx, ruleIndex)" class="text-xs text-red-500 hover:text-red-700">
{{ t('common.delete') }} {{ t('common.delete') }}
</button> </button>
</div> </div>
@@ -524,7 +524,7 @@
<div> <div>
<div class="mb-1 flex items-center justify-between"> <div class="mb-1 flex items-center justify-between">
<label class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.channels.form.ruleModelPricing') }}</label> <label class="text-xs text-gray-500 dark:text-gray-400">{{ t('admin.channels.form.ruleModelPricing') }}</label>
<button type="button" @click="addRulePricingEntry(ruleIndex)" class="text-xs text-primary-600 hover:text-primary-700"> <button type="button" @click="addRulePricingEntry(sIdx, ruleIndex)" class="text-xs text-primary-600 hover:text-primary-700">
+ {{ t('common.add') }} + {{ t('common.add') }}
</button> </button>
</div> </div>
@@ -538,7 +538,7 @@
:entry="entry" :entry="entry"
:platform="section.platform" :platform="section.platform"
@update="rule.pricing.splice(pIdx, 1, $event)" @update="rule.pricing.splice(pIdx, 1, $event)"
@remove="removeRulePricingEntry(ruleIndex, pIdx)" @remove="removeRulePricingEntry(sIdx, ruleIndex, pIdx)"
/> />
</div> </div>
</div> </div>
@@ -625,6 +625,14 @@ async function loadWebSearchGlobalState() {
} }
} }
// ── Form-level pricing rule type (per-platform) ──
interface FormPricingRule {
name: string
group_ids: number[]
account_ids: number[]
pricing: PricingFormEntry[]
}
// ── Platform Section type ── // ── Platform Section type ──
interface PlatformSection { interface PlatformSection {
platform: GroupPlatform platform: GroupPlatform
@@ -634,6 +642,7 @@ interface PlatformSection {
model_mapping: Record<string, string> model_mapping: Record<string, string>
model_pricing: PricingFormEntry[] model_pricing: PricingFormEntry[]
web_search_emulation: boolean web_search_emulation: boolean
account_stats_pricing_rules: FormPricingRule[]
} }
// ── Table columns ── // ── Table columns ──
@@ -703,12 +712,6 @@ const form = reactive({
billing_model_source: 'channel_mapped' as string, billing_model_source: 'channel_mapped' as string,
platforms: [] as PlatformSection[], platforms: [] as PlatformSection[],
apply_pricing_to_account_stats: false, apply_pricing_to_account_stats: false,
account_stats_pricing_rules: [] as Array<{
name: string
group_ids: number[]
account_ids: number[]
pricing: PricingFormEntry[]
}>
}) })
let abortController: AbortController | null = null let abortController: AbortController | null = null
@@ -754,6 +757,7 @@ function addPlatformSection(platform: GroupPlatform) {
model_mapping: {}, model_mapping: {},
model_pricing: [], model_pricing: [],
web_search_emulation: false, web_search_emulation: false,
account_stats_pricing_rules: [],
}) })
} }
@@ -867,8 +871,8 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) {
} }
// ── Account Stats Pricing helpers ── // ── Account Stats Pricing helpers ──
function addAccountStatsRule() { function addAccountStatsRule(sectionIdx: number) {
form.account_stats_pricing_rules.push({ form.platforms[sectionIdx].account_stats_pricing_rules.push({
name: '', name: '',
group_ids: [], group_ids: [],
account_ids: [], account_ids: [],
@@ -876,8 +880,8 @@ function addAccountStatsRule() {
}) })
} }
function addRulePricingEntry(ruleIndex: number) { function addRulePricingEntry(sectionIdx: number, ruleIndex: number) {
form.account_stats_pricing_rules[ruleIndex].pricing.push({ form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.push({
models: [], models: [],
billing_mode: 'token', billing_mode: 'token',
input_price: null, input_price: null,
@@ -890,15 +894,15 @@ function addRulePricingEntry(ruleIndex: number) {
}) })
} }
function removeAccountStatsRule(ruleIndex: number) { function removeAccountStatsRule(sectionIdx: number, ruleIndex: number) {
form.account_stats_pricing_rules.splice(ruleIndex, 1) form.platforms[sectionIdx].account_stats_pricing_rules.splice(ruleIndex, 1)
// Clear all search state since indices shift after removal // Clear all search state since indices shift after removal
ruleAccountSearchRunner.clearAll() ruleAccountSearchRunner.clearAll()
clearAllRuleAccountSearchState() clearAllRuleAccountSearchState()
} }
function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) { function removeRulePricingEntry(sectionIdx: number, ruleIndex: number, pricingIndex: number) {
form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1) form.platforms[sectionIdx].account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1)
} }
function getGroupNameById(groupId: number): string { function getGroupNameById(groupId: number): string {
@@ -980,38 +984,33 @@ function clearAllRuleAccountSearchState() {
showRuleAccountDropdown.value = {} showRuleAccountDropdown.value = {}
} }
function inferRulePlatform(groupIds: number[]): string {
const platforms = new Set<string>()
for (const gid of groupIds) {
const group = allGroups.value.find(g => g.id === gid)
if (group) platforms.add(group.platform)
}
return platforms.size === 1 ? [...platforms][0] : ''
}
function accountStatsRulesToAPI(): AccountStatsPricingRule[] { function accountStatsRulesToAPI(): AccountStatsPricingRule[] {
return form.account_stats_pricing_rules.map(rule => { const rules: AccountStatsPricingRule[] = []
const platform = inferRulePlatform(rule.group_ids) for (const section of form.platforms) {
return { if (!section.enabled) continue
name: rule.name, for (const rule of section.account_stats_pricing_rules) {
group_ids: rule.group_ids, rules.push({
account_ids: rule.account_ids, name: rule.name,
pricing: rule.pricing group_ids: rule.group_ids,
.filter(p => p.models.length > 0) account_ids: rule.account_ids,
.map(p => ({ pricing: rule.pricing
platform, .filter(p => p.models.length > 0)
models: p.models, .map(p => ({
billing_mode: p.billing_mode, platform: section.platform,
input_price: mTokToPerToken(p.input_price), models: p.models,
output_price: mTokToPerToken(p.output_price), billing_mode: p.billing_mode,
cache_write_price: mTokToPerToken(p.cache_write_price), input_price: mTokToPerToken(p.input_price),
cache_read_price: mTokToPerToken(p.cache_read_price), output_price: mTokToPerToken(p.output_price),
image_output_price: mTokToPerToken(p.image_output_price), cache_write_price: mTokToPerToken(p.cache_write_price),
per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, cache_read_price: mTokToPerToken(p.cache_read_price),
intervals: formIntervalsToAPI(p.intervals || []) image_output_price: mTokToPerToken(p.image_output_price),
})) per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null,
intervals: formIntervalsToAPI(p.intervals || [])
}))
})
} }
}) }
return rules
} }
// ── Form ↔ API conversion ── // ── Form ↔ API conversion ──
@@ -1120,6 +1119,7 @@ function apiToForm(channel: Channel): PlatformSection[] {
model_mapping: { ...mapping }, model_mapping: { ...mapping },
model_pricing: pricing, model_pricing: pricing,
web_search_emulation: webSearchEnabled, web_search_emulation: webSearchEnabled,
account_stats_pricing_rules: [],
}) })
} }
@@ -1213,7 +1213,6 @@ function resetForm() {
form.billing_model_source = 'channel_mapped' form.billing_model_source = 'channel_mapped'
form.platforms = [] form.platforms = []
form.apply_pricing_to_account_stats = false form.apply_pricing_to_account_stats = false
form.account_stats_pricing_rules = []
activeTab.value = 'basic' activeTab.value = 'basic'
ruleAccountSearchRunner.clearAll() ruleAccountSearchRunner.clearAll()
clearAllRuleAccountSearchState() clearAllRuleAccountSearchState()
@@ -1235,28 +1234,91 @@ async function openEditDialog(channel: Channel) {
form.restrict_models = channel.restrict_models || false form.restrict_models = channel.restrict_models || false
form.billing_model_source = channel.billing_model_source || 'channel_mapped' form.billing_model_source = channel.billing_model_source || 'channel_mapped'
form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false
form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({
name: rule.name || '',
group_ids: [...(rule.group_ids || [])],
account_ids: [...(rule.account_ids || [])],
pricing: (rule.pricing || []).map(p => ({
models: [...(p.models || [])],
billing_mode: p.billing_mode,
input_price: perTokenToMTok(p.input_price),
output_price: perTokenToMTok(p.output_price),
cache_write_price: perTokenToMTok(p.cache_write_price),
cache_read_price: perTokenToMTok(p.cache_read_price),
image_output_price: perTokenToMTok(p.image_output_price),
per_request_price: p.per_request_price,
intervals: apiIntervalsToForm(p.intervals || [])
} as PricingFormEntry))
}))
// Must load groups first so apiToForm can map groupID → platform // Must load groups first so apiToForm can map groupID → platform
await Promise.all([loadGroups(), loadAllChannelsForConflict()]) await Promise.all([loadGroups(), loadAllChannelsForConflict()])
form.platforms = apiToForm(channel) form.platforms = apiToForm(channel)
// Distribute channel-level rules into per-platform sections
distributeRulesToPlatforms(channel.account_stats_pricing_rules || [])
// Populate ruleAccountNameCache for existing rule accounts
await populateRuleAccountNameCache()
showDialog.value = true showDialog.value = true
} }
/** Distribute flat channel-level rules into the matching platform section based on group_ids */
function distributeRulesToPlatforms(apiRules: AccountStatsPricingRule[]) {
// Build groupID → platform lookup
const groupPlatformMap = new Map<number, GroupPlatform>()
for (const g of allGroups.value) {
groupPlatformMap.set(g.id, g.platform)
}
for (const apiRule of apiRules) {
// Infer platform from group_ids
const platforms = new Set<GroupPlatform>()
for (const gid of apiRule.group_ids || []) {
const p = groupPlatformMap.get(gid)
if (p) platforms.add(p)
}
// If pricing has a platform field, use that as fallback
if (platforms.size === 0 && apiRule.pricing?.length > 0) {
const p = apiRule.pricing[0].platform as GroupPlatform | undefined
if (p) platforms.add(p)
}
const targetPlatform = platforms.size >= 1 ? [...platforms][0] : null
if (!targetPlatform) continue
const section = form.platforms.find(s => s.platform === targetPlatform)
if (!section) continue
const formRule: FormPricingRule = {
name: apiRule.name || '',
group_ids: [...(apiRule.group_ids || [])],
account_ids: [...(apiRule.account_ids || [])],
pricing: (apiRule.pricing || []).map(p => ({
models: [...(p.models || [])],
billing_mode: p.billing_mode,
input_price: perTokenToMTok(p.input_price),
output_price: perTokenToMTok(p.output_price),
cache_write_price: perTokenToMTok(p.cache_write_price),
cache_read_price: perTokenToMTok(p.cache_read_price),
image_output_price: perTokenToMTok(p.image_output_price),
per_request_price: p.per_request_price,
intervals: apiIntervalsToForm(p.intervals || [])
} as PricingFormEntry))
}
section.account_stats_pricing_rules.push(formRule)
}
}
/** Populate ruleAccountNameCache by fetching account details for all account_ids in rules */
async function populateRuleAccountNameCache() {
const allAccountIds = new Set<number>()
for (const section of form.platforms) {
for (const rule of section.account_stats_pricing_rules) {
for (const id of rule.account_ids) {
allAccountIds.add(id)
}
}
}
if (allAccountIds.size === 0) return
// Fetch account details in parallel (batch of individual getById calls)
const ids = [...allAccountIds]
const results = await Promise.allSettled(
ids.map(id => adminAPI.accounts.getById(id))
)
for (let i = 0; i < ids.length; i++) {
const result = results[i]
if (result.status === 'fulfilled') {
ruleAccountNameCache.value[ids[i]] = result.value.name
}
// If rejected, the cache won't have the name, so it'll show "#ID" which is acceptable
}
}
function closeDialog() { function closeDialog() {
showDialog.value = false showDialog.value = false
editingChannel.value = null editingChannel.value = null