diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 9cefc792..2d4cd56a 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -26,28 +26,30 @@ func NewChannelHandler(channelService *service.ChannelService, billingService *s // --- Request / Response types --- type createChannelRequest struct { - Name string `json:"name" binding:"required,max=100"` - Description string `json:"description"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` - RestrictModels bool `json:"restrict_models"` - Features string `json:"features"` - FeaturesConfig map[string]any `json:"features_config"` + Name string `json:"name" binding:"required,max=100"` + Description string `json:"description"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels bool `json:"restrict_models"` + Features string `json:"features"` + ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` + AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } type updateChannelRequest struct { - Name string `json:"name" binding:"omitempty,max=100"` - Description *string `json:"description"` - Status string `json:"status" binding:"omitempty,oneof=active disabled"` - GroupIDs *[]int64 `json:"group_ids"` - ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` - RestrictModels *bool `json:"restrict_models"` - Features *string `json:"features"` - FeaturesConfig map[string]any `json:"features_config"` + Name string `json:"name" binding:"omitempty,max=100"` + Description *string `json:"description"` + Status string `json:"status" binding:"omitempty,oneof=active disabled"` + GroupIDs *[]int64 `json:"group_ids"` + ModelPricing *[]channelModelPricingRequest `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` + RestrictModels *bool `json:"restrict_models"` + Features *string `json:"features"` + ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"` + AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } type channelModelPricingRequest struct { @@ -75,20 +77,28 @@ type pricingIntervalRequest struct { SortOrder int `json:"sort_order"` } +type accountStatsPricingRuleRequest struct { + Name string `json:"name"` + GroupIDs []int64 `json:"group_ids"` + AccountIDs []int64 `json:"account_ids"` + Pricing []channelModelPricingRequest `json:"pricing"` +} + type channelResponse struct { - ID int64 `json:"id"` - Name string `json:"name"` - Description string `json:"description"` - Status string `json:"status"` - BillingModelSource string `json:"billing_model_source"` - RestrictModels bool `json:"restrict_models"` - Features string `json:"features"` - FeaturesConfig map[string]any `json:"features_config"` - GroupIDs []int64 `json:"group_ids"` - ModelPricing []channelModelPricingResponse `json:"model_pricing"` - ModelMapping map[string]map[string]string `json:"model_mapping"` - CreatedAt string `json:"created_at"` - UpdatedAt string `json:"updated_at"` + ID int64 `json:"id"` + Name string `json:"name"` + Description string `json:"description"` + Status string `json:"status"` + BillingModelSource string `json:"billing_model_source"` + RestrictModels bool `json:"restrict_models"` + Features string `json:"features"` + GroupIDs []int64 `json:"group_ids"` + ModelPricing []channelModelPricingResponse `json:"model_pricing"` + ModelMapping map[string]map[string]string `json:"model_mapping"` + ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` + AccountStatsPricingRules []accountStatsPricingRuleResponse `json:"account_stats_pricing_rules"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` } type channelModelPricingResponse struct { @@ -118,6 +128,14 @@ type pricingIntervalResponse struct { SortOrder int `json:"sort_order"` } +type accountStatsPricingRuleResponse struct { + ID int64 `json:"id"` + Name string `json:"name"` + GroupIDs []int64 `json:"group_ids"` + AccountIDs []int64 `json:"account_ids"` + Pricing []channelModelPricingResponse `json:"pricing"` +} + func channelToResponse(ch *service.Channel) *channelResponse { if ch == nil { return nil @@ -129,7 +147,6 @@ func channelToResponse(ch *service.Channel) *channelResponse { Status: ch.Status, RestrictModels: ch.RestrictModels, Features: ch.Features, - FeaturesConfig: ch.FeaturesConfig, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -150,6 +167,29 @@ func channelToResponse(ch *service.Channel) *channelResponse { for _, p := range ch.ModelPricing { resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p)) } + + resp.ApplyPricingToAccountStats = ch.ApplyPricingToAccountStats + resp.AccountStatsPricingRules = make([]accountStatsPricingRuleResponse, 0, len(ch.AccountStatsPricingRules)) + for _, rule := range ch.AccountStatsPricingRules { + ruleResp := accountStatsPricingRuleResponse{ + ID: rule.ID, + Name: rule.Name, + GroupIDs: rule.GroupIDs, + AccountIDs: rule.AccountIDs, + Pricing: make([]channelModelPricingResponse, 0, len(rule.Pricing)), + } + if ruleResp.GroupIDs == nil { + ruleResp.GroupIDs = []int64{} + } + if ruleResp.AccountIDs == nil { + ruleResp.AccountIDs = []int64{} + } + for i := range rule.Pricing { + ruleResp.Pricing = append(ruleResp.Pricing, pricingToResponse(&rule.Pricing[i])) + } + resp.AccountStatsPricingRules = append(resp.AccountStatsPricingRules, ruleResp) + } + return resp } @@ -241,6 +281,15 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe return result } +func accountStatsPricingRuleRequestToService(r accountStatsPricingRuleRequest) service.AccountStatsPricingRule { + return service.AccountStatsPricingRule{ + Name: r.Name, + GroupIDs: r.GroupIDs, + AccountIDs: r.AccountIDs, + Pricing: pricingRequestToService(r.Pricing), + } +} + // --- Handlers --- // List handles listing channels with pagination @@ -300,16 +349,24 @@ func (h *ChannelHandler) Create(c *gin.Context) { pricing := pricingRequestToService(req.ModelPricing) + var statsRules []service.AccountStatsPricingRule + for i, r := range req.AccountStatsPricingRules { + rule := accountStatsPricingRuleRequestToService(r) + rule.SortOrder = i + statsRules = append(statsRules, rule) + } + channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ - Name: req.Name, - Description: req.Description, - GroupIDs: req.GroupIDs, - ModelPricing: pricing, - ModelMapping: req.ModelMapping, - BillingModelSource: req.BillingModelSource, - RestrictModels: req.RestrictModels, - Features: req.Features, - FeaturesConfig: req.FeaturesConfig, + Name: req.Name, + Description: req.Description, + GroupIDs: req.GroupIDs, + ModelPricing: pricing, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + Features: req.Features, + ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, + AccountStatsPricingRules: statsRules, }) if err != nil { response.ErrorFrom(c, err) @@ -335,20 +392,29 @@ func (h *ChannelHandler) Update(c *gin.Context) { } input := &service.UpdateChannelInput{ - Name: req.Name, - Description: req.Description, - Status: req.Status, - GroupIDs: req.GroupIDs, - ModelMapping: req.ModelMapping, - BillingModelSource: req.BillingModelSource, - RestrictModels: req.RestrictModels, - Features: req.Features, - FeaturesConfig: req.FeaturesConfig, + Name: req.Name, + Description: req.Description, + Status: req.Status, + GroupIDs: req.GroupIDs, + ModelMapping: req.ModelMapping, + BillingModelSource: req.BillingModelSource, + RestrictModels: req.RestrictModels, + Features: req.Features, + ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) input.ModelPricing = &pricing } + if req.AccountStatsPricingRules != nil { + statsRules := make([]service.AccountStatsPricingRule, 0, len(*req.AccountStatsPricingRules)) + for i, r := range *req.AccountStatsPricingRules { + rule := accountStatsPricingRuleRequestToService(r) + rule.SortOrder = i + statsRules = append(statsRules, rule) + } + input.AccountStatsPricingRules = &statsRules + } channel, err := h.channelService.Update(c.Request.Context(), id, input) if err != nil { diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 56b5cc71..583ce895 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -41,14 +41,10 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel if err != nil { return err } - featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) - if err != nil { - return err - } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -71,17 +67,24 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel } } + // 设置账号统计定价规则 + if len(channel.AccountStatsPricingRules) > 0 { + if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil { + return err + } + } + return nil }) } func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { ch := &service.Channel{} - var modelMappingJSON, featuresConfigJSON []byte + var modelMappingJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -89,7 +92,6 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha return nil, fmt.Errorf("get channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) - ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) groupIDs, err := r.GetGroupIDs(ctx, id) if err != nil { @@ -103,6 +105,12 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha } ch.ModelPricing = pricing + statsPricingRules, err := r.loadAccountStatsPricingRules(ctx, id) + if err != nil { + return nil, err + } + ch.AccountStatsPricingRules = statsPricingRules + return ch, nil } @@ -112,14 +120,10 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel if err != nil { return err } - featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) - if err != nil { - return err - } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, updated_at = NOW() + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW() WHERE id = $9`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ID, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -146,6 +150,13 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel } } + // 更新账号统计定价规则 + if channel.AccountStatsPricingRules != nil { + if err := replaceAccountStatsPricingRulesTx(ctx, tx, channel.ID, channel.AccountStatsPricingRules); err != nil { + return err + } + } + return nil }) } @@ -196,7 +207,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, whereClause, channelListOrderBy(params), argIdx, argIdx+1, ) @@ -212,12 +223,11 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON, featuresConfigJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) - ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -235,9 +245,14 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati if err != nil { return nil, nil, err } + statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs) + if err != nil { + return nil, nil, err + } for i := range channels { channels[i].GroupIDs = groupMap[channels[i].ID] channels[i].ModelPricing = pricingMap[channels[i].ID] + channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID] } } @@ -283,7 +298,7 @@ func channelListOrderBy(params pagination.PaginationParams) string { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -294,12 +309,11 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON, featuresConfigJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) - ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -323,9 +337,16 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err return nil, err } + // 批量加载账号统计定价规则 + statsRulesMap, err := r.batchLoadAccountStatsPricingRules(ctx, channelIDs) + if err != nil { + return nil, err + } + for i := range channels { channels[i].GroupIDs = groupMap[channels[i].ID] channels[i].ModelPricing = pricingMap[channels[i].ID] + channels[i].AccountStatsPricingRules = statsRulesMap[channels[i].ID] } return channels, nil @@ -467,28 +488,6 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string { return m } -func marshalFeaturesConfig(m map[string]any) ([]byte, error) { - if len(m) == 0 { - return []byte("{}"), nil - } - data, err := json.Marshal(m) - if err != nil { - return nil, fmt.Errorf("marshal features_config: %w", err) - } - return data, nil -} - -func unmarshalFeaturesConfig(data []byte) map[string]any { - if len(data) == 0 { - return nil - } - var m map[string]any - if err := json.Unmarshal(data, &m); err != nil { - return nil - } - return m -} - // GetGroupPlatforms 批量查询分组 ID 对应的平台 func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { if len(groupIDs) == 0 { diff --git a/backend/internal/repository/channel_repo_account_stats_pricing.go b/backend/internal/repository/channel_repo_account_stats_pricing.go new file mode 100644 index 00000000..ef8f5177 --- /dev/null +++ b/backend/internal/repository/channel_repo_account_stats_pricing.go @@ -0,0 +1,170 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/lib/pq" +) + +// --- 账号统计定价规则 --- + +// batchLoadAccountStatsPricingRules 批量加载多个渠道的账号统计定价规则(含模型定价) +func (r *channelRepository) batchLoadAccountStatsPricingRules(ctx context.Context, channelIDs []int64) (map[int64][]service.AccountStatsPricingRule, error) { + // 1. 查询规则 + rows, err := r.db.QueryContext(ctx, + `SELECT id, channel_id, name, group_ids, account_ids, sort_order, created_at, updated_at + FROM channel_account_stats_pricing_rules WHERE channel_id = ANY($1) ORDER BY channel_id, sort_order, id`, + pq.Array(channelIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats pricing rules: %w", err) + } + defer func() { _ = rows.Close() }() + + var allRules []service.AccountStatsPricingRule + var ruleIDs []int64 + for rows.Next() { + var rule service.AccountStatsPricingRule + if err := rows.Scan( + &rule.ID, &rule.ChannelID, &rule.Name, + pq.Array(&rule.GroupIDs), pq.Array(&rule.AccountIDs), + &rule.SortOrder, &rule.CreatedAt, &rule.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats pricing rule: %w", err) + } + ruleIDs = append(ruleIDs, rule.ID) + allRules = append(allRules, rule) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate account stats pricing rules: %w", err) + } + + // 2. 批量加载规则的模型定价 + pricingMap, err := r.batchLoadAccountStatsModelPricing(ctx, ruleIDs) + if err != nil { + return nil, err + } + + // 3. 按 channelID 分组并关联定价 + result := make(map[int64][]service.AccountStatsPricingRule, len(channelIDs)) + for i := range allRules { + allRules[i].Pricing = pricingMap[allRules[i].ID] + result[allRules[i].ChannelID] = append(result[allRules[i].ChannelID], allRules[i]) + } + + return result, nil +} + +// batchLoadAccountStatsModelPricing 批量加载规则的模型定价 +func (r *channelRepository) batchLoadAccountStatsModelPricing(ctx context.Context, ruleIDs []int64) (map[int64][]service.ChannelModelPricing, error) { + if len(ruleIDs) == 0 { + return make(map[int64][]service.ChannelModelPricing), nil + } + + rows, err := r.db.QueryContext(ctx, + `SELECT id, rule_id, platform, models, billing_mode, input_price, output_price, + cache_write_price, cache_read_price, image_output_price, per_request_price, created_at, updated_at + FROM channel_account_stats_model_pricing WHERE rule_id = ANY($1) ORDER BY rule_id, id`, + pq.Array(ruleIDs), + ) + if err != nil { + return nil, fmt.Errorf("batch load account stats model pricing: %w", err) + } + defer func() { _ = rows.Close() }() + + pricingMap := make(map[int64][]service.ChannelModelPricing, len(ruleIDs)) + for rows.Next() { + var p service.ChannelModelPricing + var ruleID int64 + var modelsJSON []byte + if err := rows.Scan( + &p.ID, &ruleID, &p.Platform, &modelsJSON, &p.BillingMode, + &p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice, + &p.ImageOutputPrice, &p.PerRequestPrice, &p.CreatedAt, &p.UpdatedAt, + ); err != nil { + return nil, fmt.Errorf("scan account stats model pricing: %w", err) + } + if err := json.Unmarshal(modelsJSON, &p.Models); err != nil { + p.Models = []string{} + } + pricingMap[ruleID] = append(pricingMap[ruleID], p) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("iterate account stats model pricing: %w", err) + } + return pricingMap, nil +} + +// loadAccountStatsPricingRules 加载单个渠道的账号统计定价规则(供 GetByID 使用) +func (r *channelRepository) loadAccountStatsPricingRules(ctx context.Context, channelID int64) ([]service.AccountStatsPricingRule, error) { + result, err := r.batchLoadAccountStatsPricingRules(ctx, []int64{channelID}) + if err != nil { + return nil, err + } + return result[channelID], nil +} + +// replaceAccountStatsPricingRulesTx 在事务中替换渠道的账号统计定价规则(删除旧的 + 插入新的) +func replaceAccountStatsPricingRulesTx(ctx context.Context, tx *sql.Tx, channelID int64, rules []service.AccountStatsPricingRule) error { + // CASCADE 会自动删除关联的 model_pricing + if _, err := tx.ExecContext(ctx, + `DELETE FROM channel_account_stats_pricing_rules WHERE channel_id = $1`, channelID, + ); err != nil { + return fmt.Errorf("delete old account stats pricing rules: %w", err) + } + + for i := range rules { + rules[i].ChannelID = channelID + if err := createAccountStatsPricingRuleTx(ctx, tx, &rules[i]); err != nil { + return fmt.Errorf("insert account stats pricing rule: %w", err) + } + } + return nil +} + +// createAccountStatsPricingRuleTx 在事务中创建单条账号统计定价规则及其模型定价 +func createAccountStatsPricingRuleTx(ctx context.Context, tx *sql.Tx, rule *service.AccountStatsPricingRule) error { + err := tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_pricing_rules (channel_id, name, group_ids, account_ids, sort_order) + VALUES ($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at`, + rule.ChannelID, rule.Name, pq.Array(rule.GroupIDs), pq.Array(rule.AccountIDs), rule.SortOrder, + ).Scan(&rule.ID, &rule.CreatedAt, &rule.UpdatedAt) + if err != nil { + return fmt.Errorf("insert account stats pricing rule: %w", err) + } + + for j := range rule.Pricing { + if err := createAccountStatsModelPricingTx(ctx, tx, rule.ID, &rule.Pricing[j]); err != nil { + return err + } + } + return nil +} + +// createAccountStatsModelPricingTx 在事务中创建单条账号统计模型定价 +func createAccountStatsModelPricingTx(ctx context.Context, tx *sql.Tx, ruleID int64, pricing *service.ChannelModelPricing) error { + modelsJSON, err := json.Marshal(pricing.Models) + if err != nil { + return fmt.Errorf("marshal models: %w", err) + } + billingMode := pricing.BillingMode + if billingMode == "" { + billingMode = service.BillingModeToken + } + platform := pricing.Platform + err = tx.QueryRowContext(ctx, + `INSERT INTO channel_account_stats_model_pricing (rule_id, platform, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, per_request_price) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`, + ruleID, platform, modelsJSON, billingMode, + pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice, + pricing.ImageOutputPrice, pricing.PerRequestPrice, + ).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt) + if err != nil { + return fmt.Errorf("insert account stats model pricing: %w", err) + } + return nil +} diff --git a/backend/internal/repository/usage_log_repo.go b/backend/internal/repository/usage_log_repo.go index 3ba2191e..f942a8e1 100644 --- a/backend/internal/repository/usage_log_repo.go +++ b/backend/internal/repository/usage_log_repo.go @@ -28,7 +28,7 @@ import ( gocache "github.com/patrickmn/go-cache" ) -const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, created_at" +const usageLogSelectColumns = "id, user_id, api_key_id, account_id, request_id, model, requested_model, upstream_model, group_id, subscription_id, input_tokens, output_tokens, cache_creation_tokens, cache_read_tokens, cache_creation_5m_tokens, cache_creation_1h_tokens, image_output_tokens, image_output_cost, input_cost, output_cost, cache_creation_cost, cache_read_cost, total_cost, actual_cost, rate_multiplier, account_rate_multiplier, billing_type, request_type, stream, openai_ws_mode, duration_ms, first_token_ms, user_agent, ip_address, image_count, image_size, service_tier, reasoning_effort, inbound_endpoint, upstream_endpoint, cache_ttl_overridden, channel_id, model_mapping_chain, billing_tier, billing_mode, account_stats_cost, created_at" // usageLogInsertArgTypes must stay in the same order as: // 1. prepareUsageLogInsert().args @@ -82,6 +82,7 @@ var usageLogInsertArgTypes = [...]string{ "text", // model_mapping_chain "text", // billing_tier "text", // billing_mode + "numeric", // account_stats_cost "timestamptz", // created_at } @@ -360,6 +361,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, @@ -367,7 +369,7 @@ func (r *usageLogRepository) createSingle(ctx context.Context, sqlq sqlExecutor, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 ) ON CONFLICT (request_id, api_key_id) DO NOTHING RETURNING id, created_at @@ -797,6 +799,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) AS (VALUES `) @@ -873,6 +876,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) SELECT @@ -920,6 +924,7 @@ func buildUsageLogBatchInsertQuery(keys []string, preparedByKey map[string]usage model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1007,10 +1012,11 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) AS (VALUES `) - args := make([]any, 0, len(preparedList)*45) + args := make([]any, 0, len(preparedList)*46) argPos := 1 for idx, prepared := range preparedList { if idx > 0 { @@ -1080,6 +1086,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) SELECT @@ -1127,6 +1134,7 @@ func buildUsageLogBestEffortInsertQuery(preparedList []usageLogInsertPrepared) ( model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at FROM input ON CONFLICT (request_id, api_key_id) DO NOTHING @@ -1182,6 +1190,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared model_mapping_chain, billing_tier, billing_mode, + account_stats_cost, created_at ) VALUES ( $1, $2, $3, $4, $5, $6, $7, @@ -1189,7 +1198,7 @@ func execUsageLogInsertNoResult(ctx context.Context, sqlq sqlExecutor, prepared $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, - $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45 + $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41, $42, $43, $44, $45, $46 ) ON CONFLICT (request_id, api_key_id) DO NOTHING `, prepared.args...) @@ -1285,6 +1294,7 @@ func prepareUsageLogInsert(log *service.UsageLog) usageLogInsertPrepared { modelMappingChain, billingTier, billingMode, + log.AccountStatsCost, // account_stats_cost createdAt, }, } @@ -1959,7 +1969,7 @@ func (r *usageLogRepository) GetAccountTodayStats(ctx context.Context, accountID SELECT COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs @@ -1989,7 +1999,7 @@ func (r *usageLogRepository) GetAccountWindowStats(ctx context.Context, accountI SELECT COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs @@ -2026,7 +2036,7 @@ func (r *usageLogRepository) GetAccountWindowStatsBatch(ctx context.Context, acc account_id, COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as cost, COALESCE(SUM(total_cost), 0) as standard_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs @@ -2990,7 +3000,7 @@ func (r *usageLogRepository) getModelStatsWithFiltersBySource(ctx context.Contex actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" // 当仅按 account_id 聚合时,实际费用使用账号倍率(total_cost * account_rate_multiplier)。 if accountID > 0 && userID == 0 && apiKeyID == 0 { - actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } modelExpr := resolveModelDimensionExpression(source) @@ -3358,7 +3368,7 @@ func (r *usageLogRepository) GetStatsWithFilters(ctx context.Context, filters Us COALESCE(SUM(cache_creation_tokens + cache_read_tokens), 0) as total_cache_tokens, COALESCE(SUM(total_cost), 0) as total_cost, COALESCE(SUM(actual_cost), 0) as total_actual_cost, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as total_account_cost, COALESCE(AVG(duration_ms), 0) as avg_duration_ms FROM usage_logs %s @@ -3433,7 +3443,7 @@ type EndpointStat = usagestats.EndpointStat func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Context, endpointColumn string, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" if accountID > 0 && userID == 0 && apiKeyID == 0 { - actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } query := fmt.Sprintf(` @@ -3500,7 +3510,7 @@ func (r *usageLogRepository) getEndpointStatsByColumnWithFilters(ctx context.Con func (r *usageLogRepository) getEndpointPathStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) (results []EndpointStat, err error) { actualCostExpr := "COALESCE(SUM(actual_cost), 0) as actual_cost" if accountID > 0 && userID == 0 && apiKeyID == 0 { - actualCostExpr = "COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" + actualCostExpr = "COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost" } query := fmt.Sprintf(` @@ -3591,7 +3601,7 @@ func (r *usageLogRepository) GetAccountUsageStats(ctx context.Context, accountID COUNT(*) as requests, COALESCE(SUM(input_tokens + output_tokens + cache_creation_tokens + cache_read_tokens), 0) as tokens, COALESCE(SUM(total_cost), 0) as cost, - COALESCE(SUM(total_cost * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost, + COALESCE(SUM(COALESCE(account_stats_cost, total_cost) * COALESCE(account_rate_multiplier, 1)), 0) as actual_cost, COALESCE(SUM(actual_cost), 0) as user_cost FROM usage_logs WHERE account_id = $1 AND created_at >= $2 AND created_at < $3 @@ -4069,6 +4079,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e modelMappingChain sql.NullString billingTier sql.NullString billingMode sql.NullString + accountStatsCost sql.NullFloat64 createdAt time.Time ) @@ -4118,6 +4129,7 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e &modelMappingChain, &billingTier, &billingMode, + &accountStatsCost, &createdAt, ); err != nil { return nil, err @@ -4214,6 +4226,9 @@ func scanUsageLog(scanner interface{ Scan(...any) error }) (*service.UsageLog, e if billingMode.Valid { log.BillingMode = &billingMode.String } + if accountStatsCost.Valid { + log.AccountStatsCost = &accountStatsCost.Float64 + } return log, nil } diff --git a/backend/internal/repository/usage_log_repo_request_type_test.go b/backend/internal/repository/usage_log_repo_request_type_test.go index b9cb6a13..acdd6e62 100644 --- a/backend/internal/repository/usage_log_repo_request_type_test.go +++ b/backend/internal/repository/usage_log_repo_request_type_test.go @@ -85,6 +85,7 @@ func TestUsageLogRepositoryCreateSyncRequestTypeAndLegacyFields(t *testing.T) { sqlmock.AnyArg(), // model_mapping_chain sqlmock.AnyArg(), // billing_tier sqlmock.AnyArg(), // billing_mode + sqlmock.AnyArg(), // account_stats_cost createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(99), createdAt)) @@ -163,6 +164,7 @@ func TestUsageLogRepositoryCreate_PersistsServiceTier(t *testing.T) { sqlmock.AnyArg(), // model_mapping_chain sqlmock.AnyArg(), // billing_tier sqlmock.AnyArg(), // billing_mode + sqlmock.AnyArg(), // account_stats_cost createdAt, ). WillReturnRows(sqlmock.NewRows([]string{"id", "created_at"}).AddRow(int64(100), createdAt)) @@ -483,10 +485,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, - sql.NullInt64{}, // channel_id - sql.NullString{}, // model_mapping_chain - sql.NullString{}, // billing_tier - sql.NullString{}, // billing_mode + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode + sql.NullFloat64{}, // account_stats_cost now, }}) require.NoError(t, err) @@ -530,10 +533,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, - sql.NullInt64{}, // channel_id - sql.NullString{}, // model_mapping_chain - sql.NullString{}, // billing_tier - sql.NullString{}, // billing_mode + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode + sql.NullFloat64{}, // account_stats_cost now, }}) require.NoError(t, err) @@ -577,10 +581,11 @@ func TestScanUsageLogRequestTypeAndLegacyFallback(t *testing.T) { sql.NullString{}, sql.NullString{}, false, - sql.NullInt64{}, // channel_id - sql.NullString{}, // model_mapping_chain - sql.NullString{}, // billing_tier - sql.NullString{}, // billing_mode + sql.NullInt64{}, // channel_id + sql.NullString{}, // model_mapping_chain + sql.NullString{}, // billing_tier + sql.NullString{}, // billing_mode + sql.NullFloat64{}, // account_stats_cost now, }}) require.NoError(t, err) diff --git a/backend/internal/service/account_stats_pricing.go b/backend/internal/service/account_stats_pricing.go new file mode 100644 index 00000000..86f98a12 --- /dev/null +++ b/backend/internal/service/account_stats_pricing.go @@ -0,0 +1,192 @@ +package service + +import ( + "context" + "sort" + "strings" +) + +// resolveAccountStatsCost 计算账号统计定价费用。 +// 返回 nil 表示不覆盖,使用默认公式(total_cost × account_rate_multiplier)。 +// +// 匹配优先级(先命中为准): +// 1. 自定义规则(AccountStatsPricingRules,按数组顺序遍历) +// 2. 渠道已有的模型定价(ApplyPricingToAccountStats 开启时) +// 3. nil → 走默认公式 +func resolveAccountStatsCost( + ctx context.Context, + channelService *ChannelService, + billingService *BillingService, + accountID int64, + groupID int64, + billingModel string, + tokens UsageTokens, + requestCount int, + serviceTier string, +) *float64 { + if channelService == nil || billingService == nil { + return nil + } + channel, err := channelService.GetChannelForGroup(ctx, groupID) + if err != nil || channel == nil || !channel.ApplyPricingToAccountStats { + return nil + } + + platform := channelService.GetGroupPlatform(ctx, groupID) + modelLower := strings.ToLower(billingModel) + + // 优先级 1:自定义规则 + if cost := tryCustomRules(channel, accountID, groupID, platform, modelLower, tokens, requestCount); cost != nil { + return cost + } + + // 优先级 2:渠道已有模型定价 + return tryChannelPricing(ctx, channelService, groupID, billingModel, tokens, requestCount) +} + +// tryCustomRules 遍历自定义规则,按数组顺序先命中为准。 +func tryCustomRules( + channel *Channel, accountID, groupID int64, + platform, modelLower string, tokens UsageTokens, requestCount int, +) *float64 { + for _, rule := range channel.AccountStatsPricingRules { + if !matchAccountStatsRule(&rule, accountID, groupID) { + continue + } + pricing := findPricingForModel(rule.Pricing, platform, modelLower) + if pricing == nil { + continue // 规则匹配但模型不在规则定价中,继续下一条 + } + return calculateStatsCost(pricing, tokens, requestCount) + } + return nil +} + +// tryChannelPricing 使用渠道已有的模型定价计算账号统计费用。 +func tryChannelPricing( + ctx context.Context, channelService *ChannelService, + groupID int64, billingModel string, tokens UsageTokens, requestCount int, +) *float64 { + pricing := channelService.GetChannelModelPricing(ctx, groupID, billingModel) + if pricing == nil { + return nil + } + return calculateStatsCost(pricing, tokens, requestCount) +} + +// matchAccountStatsRule 检查规则是否匹配指定的 accountID 和 groupID。 +// 匹配条件:accountID ∈ rule.AccountIDs 或 groupID ∈ rule.GroupIDs。 +// 如果规则的 AccountIDs 和 GroupIDs 都为空,视为不匹配。 +func matchAccountStatsRule(rule *AccountStatsPricingRule, accountID, groupID int64) bool { + if len(rule.AccountIDs) == 0 && len(rule.GroupIDs) == 0 { + return false + } + for _, id := range rule.AccountIDs { + if id == accountID { + return true + } + } + for _, id := range rule.GroupIDs { + if id == groupID { + return true + } + } + return false +} + +// wildcardMatch 通配符匹配候选项(用于排序) +type wildcardMatch struct { + prefixLen int + pricing *ChannelModelPricing +} + +// findPricingForModel 在定价列表中查找匹配的模型定价。 +// 先精确匹配,再通配符匹配(前缀越长优先级越高)。 +func findPricingForModel(pricingList []ChannelModelPricing, platform, modelLower string) *ChannelModelPricing { + // 精确匹配优先 + for i := range pricingList { + p := &pricingList[i] + if !isPlatformMatch(platform, p.Platform) { + continue + } + for _, m := range p.Models { + if strings.ToLower(m) == modelLower { + return p + } + } + } + // 通配符匹配:收集所有匹配项,按前缀长度降序取最长 + var matches []wildcardMatch + for i := range pricingList { + p := &pricingList[i] + if !isPlatformMatch(platform, p.Platform) { + continue + } + for _, m := range p.Models { + ml := strings.ToLower(m) + if !strings.HasSuffix(ml, "*") { + continue + } + prefix := strings.TrimSuffix(ml, "*") + if strings.HasPrefix(modelLower, prefix) { + matches = append(matches, wildcardMatch{prefixLen: len(prefix), pricing: p}) + } + } + } + if len(matches) == 0 { + return nil + } + sort.Slice(matches, func(i, j int) bool { + return matches[i].prefixLen > matches[j].prefixLen + }) + return matches[0].pricing +} + +// isPlatformMatch 判断平台是否匹配(空平台视为不限平台)。 +func isPlatformMatch(queryPlatform, pricingPlatform string) bool { + if queryPlatform == "" || pricingPlatform == "" { + return true + } + return queryPlatform == pricingPlatform +} + +// calculateStatsCost 使用给定的定价计算费用(不含任何倍率,原始费用)。 +func calculateStatsCost(pricing *ChannelModelPricing, tokens UsageTokens, requestCount int) *float64 { + if pricing == nil { + return nil + } + switch pricing.BillingMode { + case BillingModePerRequest, BillingModeImage: + return calculatePerRequestStatsCost(pricing, requestCount) + default: + return calculateTokenStatsCost(pricing, tokens) + } +} + +// calculatePerRequestStatsCost 按次/图片计费。 +func calculatePerRequestStatsCost(pricing *ChannelModelPricing, requestCount int) *float64 { + if pricing.PerRequestPrice == nil || *pricing.PerRequestPrice <= 0 { + return nil + } + cost := *pricing.PerRequestPrice * float64(requestCount) + return &cost +} + +// calculateTokenStatsCost Token 计费。 +func calculateTokenStatsCost(pricing *ChannelModelPricing, tokens UsageTokens) *float64 { + deref := func(p *float64) float64 { + if p == nil { + return 0 + } + return *p + } + cost := float64(tokens.InputTokens)*deref(pricing.InputPrice) + + float64(tokens.OutputTokens)*deref(pricing.OutputPrice) + + float64(tokens.CacheCreationTokens)*deref(pricing.CacheWritePrice) + + float64(tokens.CacheReadTokens)*deref(pricing.CacheReadPrice) + + float64(tokens.ImageOutputTokens)*deref(pricing.ImageOutputPrice) + if cost == 0 { + return nil + } + return &cost +} diff --git a/backend/internal/service/account_stats_pricing_test.go b/backend/internal/service/account_stats_pricing_test.go new file mode 100644 index 00000000..bc3db251 --- /dev/null +++ b/backend/internal/service/account_stats_pricing_test.go @@ -0,0 +1,430 @@ +//go:build unit + +package service + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// --------------------------------------------------------------------------- +// matchAccountStatsRule +// --------------------------------------------------------------------------- + +func TestMatchAccountStatsRule_BothEmpty_NoMatch(t *testing.T) { + rule := &AccountStatsPricingRule{} + require.False(t, matchAccountStatsRule(rule, 1, 10)) +} + +func TestMatchAccountStatsRule_AccountIDMatch(t *testing.T) { + rule := &AccountStatsPricingRule{AccountIDs: []int64{1, 2, 3}} + require.True(t, matchAccountStatsRule(rule, 2, 999)) +} + +func TestMatchAccountStatsRule_GroupIDMatch(t *testing.T) { + rule := &AccountStatsPricingRule{GroupIDs: []int64{10, 20}} + require.True(t, matchAccountStatsRule(rule, 999, 20)) +} + +func TestMatchAccountStatsRule_BothConfigured_AccountMatch(t *testing.T) { + rule := &AccountStatsPricingRule{ + AccountIDs: []int64{1, 2}, + GroupIDs: []int64{10, 20}, + } + require.True(t, matchAccountStatsRule(rule, 2, 999)) +} + +func TestMatchAccountStatsRule_BothConfigured_GroupMatch(t *testing.T) { + rule := &AccountStatsPricingRule{ + AccountIDs: []int64{1, 2}, + GroupIDs: []int64{10, 20}, + } + require.True(t, matchAccountStatsRule(rule, 999, 10)) +} + +func TestMatchAccountStatsRule_BothConfigured_NeitherMatch(t *testing.T) { + rule := &AccountStatsPricingRule{ + AccountIDs: []int64{1, 2}, + GroupIDs: []int64{10, 20}, + } + require.False(t, matchAccountStatsRule(rule, 999, 999)) +} + +// --------------------------------------------------------------------------- +// findPricingForModel +// --------------------------------------------------------------------------- + +func TestFindPricingForModel(t *testing.T) { + exactPricing := ChannelModelPricing{ + ID: 1, + Models: []string{"claude-opus-4"}, + } + wildcardPricing := ChannelModelPricing{ + ID: 2, + Models: []string{"claude-*"}, + } + platformPricing := ChannelModelPricing{ + ID: 3, + Platform: "openai", + Models: []string{"gpt-4o"}, + } + emptyPlatformPricing := ChannelModelPricing{ + ID: 4, + Models: []string{"gemini-2.5-pro"}, + } + + tests := []struct { + name string + list []ChannelModelPricing + platform string + model string + wantID int64 + wantNil bool + }{ + { + name: "exact match", + list: []ChannelModelPricing{exactPricing}, + platform: "anthropic", + model: "claude-opus-4", + wantID: 1, + }, + { + name: "exact match case insensitive", + list: []ChannelModelPricing{{ID: 5, Models: []string{"Claude-Opus-4"}}}, + platform: "", + model: "claude-opus-4", + wantID: 5, + }, + { + name: "wildcard match", + list: []ChannelModelPricing{wildcardPricing}, + platform: "anthropic", + model: "claude-opus-4", + wantID: 2, + }, + { + name: "exact match takes priority over wildcard", + list: []ChannelModelPricing{wildcardPricing, exactPricing}, + platform: "anthropic", + model: "claude-opus-4", + wantID: 1, + }, + { + name: "platform mismatch skipped", + list: []ChannelModelPricing{platformPricing}, + platform: "anthropic", + model: "gpt-4o", + wantNil: true, + }, + { + name: "empty platform in pricing matches any", + list: []ChannelModelPricing{emptyPlatformPricing}, + platform: "gemini", + model: "gemini-2.5-pro", + wantID: 4, + }, + { + name: "empty platform in query matches any pricing platform", + list: []ChannelModelPricing{platformPricing}, + platform: "", + model: "gpt-4o", + wantID: 3, + }, + { + name: "no match at all", + list: []ChannelModelPricing{exactPricing, wildcardPricing}, + platform: "anthropic", + model: "gpt-4o", + wantNil: true, + }, + { + name: "empty list returns nil", + list: nil, + model: "claude-opus-4", + wantNil: true, + }, + { + name: "longer wildcard prefix wins over shorter", + list: []ChannelModelPricing{ + {ID: 10, Models: []string{"claude-*"}}, + {ID: 11, Models: []string{"claude-opus-*"}}, + }, + platform: "", + model: "claude-opus-4", + wantID: 11, // "claude-opus-" (12 chars) > "claude-" (7 chars) + }, + { + name: "shorter wildcard used when longer does not match", + list: []ChannelModelPricing{ + {ID: 10, Models: []string{"claude-*"}}, + {ID: 11, Models: []string{"claude-opus-*"}}, + }, + platform: "", + model: "claude-sonnet-4", + wantID: 10, // only "claude-*" matches + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := findPricingForModel(tt.list, tt.platform, tt.model) + if tt.wantNil { + require.Nil(t, result) + return + } + require.NotNil(t, result) + require.Equal(t, tt.wantID, result.ID) + }) + } +} + +// --------------------------------------------------------------------------- +// calculateStatsCost +// --------------------------------------------------------------------------- + +func TestCalculateStatsCost_NilPricing(t *testing.T) { + result := calculateStatsCost(nil, UsageTokens{}, 1) + require.Nil(t, result) +} + +func TestCalculateStatsCost_TokenBilling(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 = 0.1 + 0.1 = 0.2 + require.InDelta(t, 0.2, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_WithCache(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + CacheWritePrice: testPtrFloat64(0.003), + CacheReadPrice: testPtrFloat64(0.0005), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + CacheReadTokens: 300, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 200*0.003 + 300*0.0005 + // = 0.1 + 0.1 + 0.6 + 0.15 = 0.95 + require.InDelta(t, 0.95, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_WithImageOutput(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + ImageOutputPrice: testPtrFloat64(0.01), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + ImageOutputTokens: 10, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // 100*0.001 + 50*0.002 + 10*0.01 = 0.1 + 0.1 + 0.1 = 0.3 + require.InDelta(t, 0.3, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_PartialPricesNil(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + // OutputPrice, CacheWritePrice, etc. are all nil → treated as 0 + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + CacheCreationTokens: 200, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + // Only input contributes: 100*0.001 = 0.1 + require.InDelta(t, 0.1, *result, 1e-12) +} + +func TestCalculateStatsCost_TokenBilling_AllTokensZero(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + } + tokens := UsageTokens{} // all zeros + result := calculateStatsCost(pricing, tokens, 1) + // totalCost == 0 → returns nil (does not override, falls back to default formula) + require.Nil(t, result) +} + +func TestCalculateStatsCost_PerRequestBilling(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.05), + } + tokens := UsageTokens{InputTokens: 999, OutputTokens: 999} + result := calculateStatsCost(pricing, tokens, 3) + require.NotNil(t, result) + // 0.05 * 3 = 0.15 + require.InDelta(t, 0.15, *result, 1e-12) +} + +func TestCalculateStatsCost_PerRequestBilling_PriceNil(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModePerRequest, + // PerRequestPrice is nil + } + result := calculateStatsCost(pricing, UsageTokens{}, 1) + require.Nil(t, result) +} + +func TestCalculateStatsCost_PerRequestBilling_PriceZero(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0), + } + result := calculateStatsCost(pricing, UsageTokens{}, 1) + // price == 0 → condition *pricing.PerRequestPrice > 0 is false → returns nil + require.Nil(t, result) +} + +func TestCalculateStatsCost_ImageBilling(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeImage, + PerRequestPrice: testPtrFloat64(0.10), + } + result := calculateStatsCost(pricing, UsageTokens{}, 2) + require.NotNil(t, result) + // 0.10 * 2 = 0.20 + require.InDelta(t, 0.20, *result, 1e-12) +} + +func TestCalculateStatsCost_ImageBilling_PriceNil(t *testing.T) { + pricing := &ChannelModelPricing{ + BillingMode: BillingModeImage, + // PerRequestPrice is nil + } + result := calculateStatsCost(pricing, UsageTokens{}, 1) + require.Nil(t, result) +} + +func TestCalculateStatsCost_DefaultBillingMode_FallsToToken(t *testing.T) { + // BillingMode is empty string (default) → falls into token billing + pricing := &ChannelModelPricing{ + InputPrice: testPtrFloat64(0.001), + OutputPrice: testPtrFloat64(0.002), + } + tokens := UsageTokens{ + InputTokens: 100, + OutputTokens: 50, + } + result := calculateStatsCost(pricing, tokens, 1) + require.NotNil(t, result) + require.InDelta(t, 0.2, *result, 1e-12) +} + +// --------------------------------------------------------------------------- +// tryCustomRules — 多规则顺序测试 +// --------------------------------------------------------------------------- + +func TestTryCustomRules_FirstMatchWins(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01), OutputPrice: testPtrFloat64(0.02)}, + }, + }, + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99), OutputPrice: testPtrFloat64(0.99)}, + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100, OutputTokens: 50} + result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1) + require.NotNil(t, result) + // 应使用第一条规则的价格:100*0.01 + 50*0.02 = 2.0 + require.InDelta(t, 2.0, *result, 1e-12) +} + +func TestTryCustomRules_SkipsNonMatchingRules(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + AccountIDs: []int64{888}, // 不匹配 + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.99)}, + }, + }, + { + GroupIDs: []int64{1}, // 匹配 + Pricing: []ChannelModelPricing{ + {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100} + result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1) + require.NotNil(t, result) + // 跳过规则1(账号不匹配),使用规则2:100*0.05 = 5.0 + require.InDelta(t, 5.0, *result, 1e-12) +} + +func TestTryCustomRules_NoMatch_ReturnsNil(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + AccountIDs: []int64{888}, + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.01)}, + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100} + result := tryCustomRules(channel, 999, 2, "", "claude-opus-4", tokens, 1) + require.Nil(t, result) // 账号和分组都不匹配 +} + +func TestTryCustomRules_RuleMatchesButModelNot_ContinuesToNext(t *testing.T) { + channel := &Channel{ + AccountStatsPricingRules: []AccountStatsPricingRule{ + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 100, Models: []string{"gpt-4o"}, InputPrice: testPtrFloat64(0.01)}, // 模型不匹配 + }, + }, + { + GroupIDs: []int64{1}, + Pricing: []ChannelModelPricing{ + {ID: 200, Models: []string{"claude-opus-4"}, InputPrice: testPtrFloat64(0.05)}, // 模型匹配 + }, + }, + }, + } + tokens := UsageTokens{InputTokens: 100} + result := tryCustomRules(channel, 999, 1, "", "claude-opus-4", tokens, 1) + require.NotNil(t, result) + require.InDelta(t, 5.0, *result, 1e-12) // 使用规则2 +} diff --git a/backend/internal/service/channel.go b/backend/internal/service/channel.go index baf5c839..3867f2a0 100644 --- a/backend/internal/service/channel.go +++ b/backend/internal/service/channel.go @@ -49,21 +49,25 @@ type Channel struct { ModelPricing []ChannelModelPricing // 渠道级模型映射(按平台分组:platform → {src→dst}) ModelMapping map[string]map[string]string - // 渠道特性配置(如 {"web_search_emulation": {"anthropic": true}}) - FeaturesConfig map[string]any + + // 账号统计定价 + ApplyPricingToAccountStats bool // 是否应用渠道模型定价到账号统计 + AccountStatsPricingRules []AccountStatsPricingRule // 自定义账号统计定价规则(按 SortOrder 排序,先命中为准) } -// IsWebSearchEmulationEnabled 返回该渠道是否为指定平台启用了 web search 模拟。 -func (c *Channel) IsWebSearchEmulationEnabled(platform string) bool { - if c == nil || c.FeaturesConfig == nil { - return false - } - wse, ok := c.FeaturesConfig[featureKeyWebSearchEmulation].(map[string]any) - if !ok { - return false - } - enabled, ok := wse[platform].(bool) - return ok && enabled +// AccountStatsPricingRule 账号统计定价规则 +// 每条规则包含匹配条件(分组/账号)和独立的模型定价。 +// 多条规则按 SortOrder 排序,先命中为准。 +type AccountStatsPricingRule struct { + ID int64 + ChannelID int64 + Name string + GroupIDs []int64 + AccountIDs []int64 + SortOrder int + Pricing []ChannelModelPricing // 规则内的模型定价(复用现有定价结构) + CreatedAt time.Time + UpdatedAt time.Time } // ChannelModelPricing 渠道模型定价条目 @@ -192,6 +196,26 @@ func (c *Channel) Clone() *Channel { cp.ModelMapping[platform] = inner } } + if c.AccountStatsPricingRules != nil { + cp.AccountStatsPricingRules = make([]AccountStatsPricingRule, len(c.AccountStatsPricingRules)) + for i, rule := range c.AccountStatsPricingRules { + cp.AccountStatsPricingRules[i] = rule + if rule.GroupIDs != nil { + cp.AccountStatsPricingRules[i].GroupIDs = make([]int64, len(rule.GroupIDs)) + copy(cp.AccountStatsPricingRules[i].GroupIDs, rule.GroupIDs) + } + if rule.AccountIDs != nil { + cp.AccountStatsPricingRules[i].AccountIDs = make([]int64, len(rule.AccountIDs)) + copy(cp.AccountStatsPricingRules[i].AccountIDs, rule.AccountIDs) + } + if rule.Pricing != nil { + cp.AccountStatsPricingRules[i].Pricing = make([]ChannelModelPricing, len(rule.Pricing)) + for j := range rule.Pricing { + cp.AccountStatsPricingRules[i].Pricing[j] = rule.Pricing[j].Clone() + } + } + } + } return &cp } diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 7b28662b..d0698f0f 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -416,6 +416,15 @@ func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) return ch.Clone(), nil } +// GetGroupPlatform 获取分组的平台标识(从缓存) +func (s *ChannelService) GetGroupPlatform(ctx context.Context, groupID int64) string { + cache, err := s.loadCache(ctx) + if err != nil { + return "" + } + return cache.groupPlatform[groupID] +} + // channelLookup 热路径公共查找结果 type channelLookup struct { cache *channelCache @@ -656,16 +665,17 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) } channel := &Channel{ - Name: input.Name, - Description: input.Description, - Status: StatusActive, - BillingModelSource: input.BillingModelSource, - RestrictModels: input.RestrictModels, - GroupIDs: input.GroupIDs, - ModelPricing: input.ModelPricing, - ModelMapping: input.ModelMapping, - Features: input.Features, - FeaturesConfig: input.FeaturesConfig, + Name: input.Name, + Description: input.Description, + Status: StatusActive, + BillingModelSource: input.BillingModelSource, + RestrictModels: input.RestrictModels, + GroupIDs: input.GroupIDs, + ModelPricing: input.ModelPricing, + ModelMapping: input.ModelMapping, + Features: input.Features, + ApplyPricingToAccountStats: input.ApplyPricingToAccountStats, + AccountStatsPricingRules: input.AccountStatsPricingRules, } if channel.BillingModelSource == "" { channel.BillingModelSource = BillingModelSourceChannelMapped @@ -754,8 +764,11 @@ func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, if input.BillingModelSource != "" { channel.BillingModelSource = input.BillingModelSource } - if input.FeaturesConfig != nil { - channel.FeaturesConfig = input.FeaturesConfig + if input.ApplyPricingToAccountStats != nil { + channel.ApplyPricingToAccountStats = *input.ApplyPricingToAccountStats + } + if input.AccountStatsPricingRules != nil { + channel.AccountStatsPricingRules = *input.AccountStatsPricingRules } return nil } @@ -922,27 +935,29 @@ func detectConflicts(entries []modelEntry, platform, errCode, label string) erro // CreateChannelInput 创建渠道输入 type CreateChannelInput struct { - Name string - Description string - GroupIDs []int64 - ModelPricing []ChannelModelPricing - ModelMapping map[string]map[string]string // platform → {src→dst} - BillingModelSource string - RestrictModels bool - Features string - FeaturesConfig map[string]any + Name string + Description string + GroupIDs []int64 + ModelPricing []ChannelModelPricing + ModelMapping map[string]map[string]string // platform → {src→dst} + BillingModelSource string + RestrictModels bool + Features string + ApplyPricingToAccountStats bool + AccountStatsPricingRules []AccountStatsPricingRule } // UpdateChannelInput 更新渠道输入 type UpdateChannelInput struct { - Name string - Description *string - Status string - GroupIDs *[]int64 - ModelPricing *[]ChannelModelPricing - ModelMapping map[string]map[string]string // platform → {src→dst} - BillingModelSource string - RestrictModels *bool - Features *string - FeaturesConfig map[string]any + Name string + Description *string + Status string + GroupIDs *[]int64 + ModelPricing *[]ChannelModelPricing + ModelMapping map[string]map[string]string // platform → {src→dst} + BillingModelSource string + RestrictModels *bool + Features *string + ApplyPricingToAccountStats *bool + AccountStatsPricingRules *[]AccountStatsPricingRule } diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 77e9b8c8..1d6d0a08 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -7559,6 +7559,23 @@ func (s *GatewayService) recordUsageCore(ctx context.Context, input *recordUsage usageLog := s.buildRecordUsageLog(ctx, input, result, apiKey, user, account, subscription, requestedModel, multiplier, accountRateMultiplier, billingType, cacheTTLOverridden, cost, opts) + // 计算账号统计定价费用 + if apiKey.GroupID != nil { + usageLog.AccountStatsCost = resolveAccountStatsCost( + ctx, s.channelService, s.billingService, + account.ID, *apiKey.GroupID, billingModel, + UsageTokens{ + InputTokens: result.Usage.InputTokens, + OutputTokens: result.Usage.OutputTokens, + CacheCreationTokens: result.Usage.CacheCreationInputTokens, + CacheReadTokens: result.Usage.CacheReadInputTokens, + ImageOutputTokens: result.Usage.ImageOutputTokens, + }, + 1, // requestCount + "", // serviceTier: Anthropic 平台不使用 service tier + ) + } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.gateway") logger.LegacyPrintf("service.gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index dbc53869..3daa8756 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -4569,6 +4569,15 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec usageLog.SubscriptionID = &subscription.ID } + // 计算账号统计定价费用 + if apiKey.GroupID != nil { + usageLog.AccountStatsCost = resolveAccountStatsCost( + ctx, s.channelService, s.billingService, + account.ID, *apiKey.GroupID, billingModel, + tokens, 1, serviceTier, + ) + } + if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple { writeUsageLogBestEffort(ctx, s.usageLogRepo, usageLog, "service.openai_gateway") logger.LegacyPrintf("service.openai_gateway", "[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens()) diff --git a/backend/internal/service/usage_log.go b/backend/internal/service/usage_log.go index 3218f3db..e29d282e 100644 --- a/backend/internal/service/usage_log.go +++ b/backend/internal/service/usage_log.go @@ -146,6 +146,8 @@ type UsageLog struct { RateMultiplier float64 // AccountRateMultiplier 账号计费倍率快照(nil 表示历史数据,按 1.0 处理) AccountRateMultiplier *float64 + // AccountStatsCost 账号统计定价预计算费用(nil = 使用默认公式 total_cost × account_rate_multiplier) + AccountStatsCost *float64 BillingType int8 RequestType RequestType diff --git a/backend/migrations/101_add_account_stats_pricing.sql b/backend/migrations/101_add_account_stats_pricing.sql new file mode 100644 index 00000000..a61d0c26 --- /dev/null +++ b/backend/migrations/101_add_account_stats_pricing.sql @@ -0,0 +1,38 @@ +-- Account statistics pricing: allow channels to configure custom pricing for account cost tracking. + +-- 1. Channel-level toggle +ALTER TABLE channels ADD COLUMN IF NOT EXISTS apply_pricing_to_account_stats BOOLEAN NOT NULL DEFAULT FALSE; + +-- 2. Account stats pricing rules (ordered list per channel) +CREATE TABLE IF NOT EXISTS channel_account_stats_pricing_rules ( + id BIGSERIAL PRIMARY KEY, + channel_id BIGINT NOT NULL REFERENCES channels(id) ON DELETE CASCADE, + name VARCHAR(100) NOT NULL DEFAULT '', + group_ids BIGINT[] NOT NULL DEFAULT '{}', + account_ids BIGINT[] NOT NULL DEFAULT '{}', + sort_order INT NOT NULL DEFAULT 0, + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_cas_pricing_rules_channel_id ON channel_account_stats_pricing_rules(channel_id); + +-- 3. Model pricing for each rule (same structure as channel_model_pricing) +CREATE TABLE IF NOT EXISTS channel_account_stats_model_pricing ( + id BIGSERIAL PRIMARY KEY, + rule_id BIGINT NOT NULL REFERENCES channel_account_stats_pricing_rules(id) ON DELETE CASCADE, + platform VARCHAR(50) NOT NULL DEFAULT '', + models JSONB NOT NULL DEFAULT '[]', + billing_mode VARCHAR(20) NOT NULL DEFAULT 'token', + input_price NUMERIC(20,10), + output_price NUMERIC(20,10), + cache_write_price NUMERIC(20,10), + cache_read_price NUMERIC(20,10), + image_output_price NUMERIC(20,10), + per_request_price NUMERIC(20,10), + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); +CREATE INDEX IF NOT EXISTS idx_cas_model_pricing_rule_id ON channel_account_stats_model_pricing(rule_id); + +-- 4. Usage logs: pre-computed account stats cost (NULL = use default formula) +ALTER TABLE usage_logs ADD COLUMN IF NOT EXISTS account_stats_cost NUMERIC(20,10); diff --git a/frontend/src/api/admin/channels.ts b/frontend/src/api/admin/channels.ts index d49982aa..a13eb3e1 100644 --- a/frontend/src/api/admin/channels.ts +++ b/frontend/src/api/admin/channels.ts @@ -34,6 +34,14 @@ export interface ChannelModelPricing { intervals: PricingInterval[] } +export interface AccountStatsPricingRule { + id?: number + name: string + group_ids: number[] + account_ids: number[] + pricing: ChannelModelPricing[] +} + export interface Channel { id: number name: string @@ -41,10 +49,11 @@ export interface Channel { status: string billing_model_source: string // "requested" | "upstream" restrict_models: boolean - features_config?: Record group_ids: number[] model_pricing: ChannelModelPricing[] model_mapping: Record> // platform → {src→dst} + apply_pricing_to_account_stats: boolean + account_stats_pricing_rules: AccountStatsPricingRule[] created_at: string updated_at: string } @@ -57,7 +66,8 @@ export interface CreateChannelRequest { model_mapping?: Record> billing_model_source?: string restrict_models?: boolean - features_config?: Record + apply_pricing_to_account_stats?: boolean + account_stats_pricing_rules?: AccountStatsPricingRule[] } export interface UpdateChannelRequest { @@ -69,7 +79,8 @@ export interface UpdateChannelRequest { model_mapping?: Record> billing_model_source?: string restrict_models?: boolean - features_config?: Record + apply_pricing_to_account_stats?: boolean + account_stats_pricing_rules?: AccountStatsPricingRule[] } interface PaginatedResponse { diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index 99f8d535..dd45ea17 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1844,7 +1844,18 @@ export default { noPlatforms: 'Click "Add Platform" to start configuring the channel', mappingCount: 'mappings', pricingEntry: 'Pricing Entry', - noModels: 'No models added' + noModels: 'No models added', + applyPricingToAccountStats: 'Apply Pricing to Account Stats', + applyPricingToAccountStatsDesc: 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.', + accountStatsPricingRules: 'Custom Account Stats Pricing Rules', + addRule: 'Add Rule', + noRulesConfigured: 'No custom rules configured. Channel model pricing above will be used.', + ruleName: 'Rule name (optional)', + ruleGroups: 'Groups', + ruleAccounts: 'Account IDs', + ruleAccountsPlaceholder: 'Enter account IDs, comma-separated', + ruleModelPricing: 'Model Pricing', + noGroupsInChannel: 'No groups selected in platform tabs above' } }, diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 7ef7ead0..bbfc7971 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1923,7 +1923,18 @@ export default { noPlatforms: '点击"添加平台"开始配置渠道', mappingCount: '条映射', pricingEntry: '定价配置', - noModels: '未添加模型' + noModels: '未添加模型', + applyPricingToAccountStats: '应用模型定价到账号统计', + applyPricingToAccountStatsDesc: '启用后,账号统计费用将使用渠道模型定价计算。账号自身的统计倍率仍然生效。', + accountStatsPricingRules: '自定义账号统计定价规则', + addRule: '添加规则', + noRulesConfigured: '未配置自定义规则,将使用上方的模型定价。', + ruleName: '规则名称(可选)', + ruleGroups: '分组', + ruleAccounts: '账号 ID', + ruleAccountsPlaceholder: '输入账号 ID,逗号分隔', + ruleModelPricing: '模型定价', + noGroupsInChannel: '上方平台标签页中未选择分组' } }, diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index ce8d3c9c..a49e1694 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -306,24 +306,6 @@ - -
-
-
- -

- {{ t('admin.channels.form.webSearchEmulationHint') }} -

-

- {{ t('admin.channels.form.webSearchEmulationGlobalDisabled') }} -

-
- -
-
-
@@ -398,6 +380,143 @@
+ + +
+ +
+
+ +

+ {{ t('admin.channels.form.applyPricingToAccountStatsDesc', 'When enabled, account statistics cost will use channel model pricing. Account rate multiplier still applies.') }} +

+
+ +
+ + +
+
+

+ {{ t('admin.channels.form.accountStatsPricingRules', 'Custom Account Stats Pricing Rules') }} +

+ +
+ +

+ {{ t('admin.channels.form.noRulesConfigured', 'No custom rules configured. Channel model pricing above will be used.') }} +

+ + +
+
+ + +
+ + +
+ +
+ +
+

+ {{ t('admin.channels.form.noGroupsInChannel', 'No groups selected in platform tabs above') }} +

+
+ + +
+ + +
+ + +
+
+ + +
+
+ {{ t('admin.channels.form.noPricingRules', 'No pricing rules yet. Click "Add" to create one.') }} +
+
+ +
+
+
+
+
@@ -441,9 +560,8 @@ import { ref, reactive, computed, onMounted, onUnmounted } from 'vue' import { useI18n } from 'vue-i18n' import { useAppStore } from '@/stores/app' -import { extractApiErrorMessage } from '@/utils/apiError' import { adminAPI } from '@/api/admin' -import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest } from '@/api/admin/channels' +import type { Channel, ChannelModelPricing, CreateChannelRequest, UpdateChannelRequest, AccountStatsPricingRule } from '@/api/admin/channels' import type { PricingFormEntry } from '@/components/admin/channel/types' import { mTokToPerToken, perTokenToMTok, apiIntervalsToForm, formIntervalsToAPI, findModelConflict, validateIntervals } from '@/components/admin/channel/types' import type { AdminGroup, GroupPlatform } from '@/types' @@ -465,18 +583,6 @@ import { getPersistedPageSize } from '@/composables/usePersistedPageSize' const { t } = useI18n() const appStore = useAppStore() -// Web Search global enabled state (loaded once on mount) -const webSearchGlobalEnabled = ref(false) -async function loadWebSearchGlobalState() { - try { - const cfg = await adminAPI.settings.getWebSearchEmulationConfig() - webSearchGlobalEnabled.value = cfg?.enabled === true && (cfg?.providers?.length ?? 0) > 0 - } catch (err: unknown) { - console.warn('Failed to load web search global state:', err) - webSearchGlobalEnabled.value = false - } -} - // ── Platform Section type ── interface PlatformSection { platform: GroupPlatform @@ -485,7 +591,6 @@ interface PlatformSection { group_ids: number[] model_mapping: Record model_pricing: PricingFormEntry[] - web_search_emulation: boolean } // ── Table columns ── @@ -553,7 +658,14 @@ const form = reactive({ status: 'active', restrict_models: false, billing_model_source: 'channel_mapped' as string, - platforms: [] as PlatformSection[] + platforms: [] as PlatformSection[], + apply_pricing_to_account_stats: false, + account_stats_pricing_rules: [] as Array<{ + name: string + group_ids: number[] + account_ids: number[] + pricing: PricingFormEntry[] + }> }) let abortController: AbortController | null = null @@ -597,8 +709,7 @@ function addPlatformSection(platform: GroupPlatform) { collapsed: false, group_ids: [], model_mapping: {}, - model_pricing: [], - web_search_emulation: false, + model_pricing: [] }) } @@ -711,15 +822,89 @@ function renameMappingKey(sectionIdx: number, oldKey: string, newKey: string) { mapping[newKey] = value } +// ── Account Stats Pricing helpers ── +function addAccountStatsRule() { + form.account_stats_pricing_rules.push({ + name: '', + group_ids: [], + account_ids: [], + pricing: [] + }) +} + +function addRulePricingEntry(ruleIndex: number) { + form.account_stats_pricing_rules[ruleIndex].pricing.push({ + models: [], + billing_mode: 'token', + input_price: null, + output_price: null, + cache_write_price: null, + cache_read_price: null, + image_output_price: null, + per_request_price: null, + intervals: [] + }) +} + +function removeAccountStatsRule(ruleIndex: number) { + form.account_stats_pricing_rules.splice(ruleIndex, 1) +} + +function removeRulePricingEntry(ruleIndex: number, pricingIndex: number) { + form.account_stats_pricing_rules[ruleIndex].pricing.splice(pricingIndex, 1) +} + +function getGroupNameById(groupId: number): string { + const group = allGroups.value.find(g => g.id === groupId) + return group ? group.name : `#${groupId}` +} + +/** Collect all group_ids from enabled platform sections */ +const allFormGroupIds = computed(() => { + const ids = new Set() + for (const section of form.platforms) { + if (!section.enabled) continue + for (const gid of section.group_ids) { + ids.add(gid) + } + } + return [...ids] +}) + +function parseAccountIdsInput(value: string): number[] { + return value + .split(',') + .map(s => parseInt(s.trim())) + .filter(n => !isNaN(n) && n > 0) +} + +function accountStatsRulesToAPI(): AccountStatsPricingRule[] { + return form.account_stats_pricing_rules.map(rule => ({ + name: rule.name, + group_ids: rule.group_ids, + account_ids: rule.account_ids, + pricing: rule.pricing + .filter(p => p.models.length > 0) + .map(p => ({ + platform: '', + models: p.models, + billing_mode: p.billing_mode, + input_price: mTokToPerToken(p.input_price), + output_price: mTokToPerToken(p.output_price), + cache_write_price: mTokToPerToken(p.cache_write_price), + cache_read_price: mTokToPerToken(p.cache_read_price), + image_output_price: mTokToPerToken(p.image_output_price), + per_request_price: p.per_request_price != null && p.per_request_price !== '' ? Number(p.per_request_price) : null, + intervals: formIntervalsToAPI(p.intervals || []) + })) + })) +} + // ── Form ↔ API conversion ── -function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record>, features_config: Record } { +function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[], model_mapping: Record> } { const group_ids: number[] = [] const model_pricing: ChannelModelPricing[] = [] const model_mapping: Record> = {} - // Preserve existing features_config fields not managed by the form - const featuresConfig: Record = editingChannel.value?.features_config - ? { ...editingChannel.value.features_config } - : {} for (const section of form.platforms) { if (!section.enabled) continue @@ -748,19 +933,7 @@ function formToAPI(): { group_ids: number[], model_pricing: ChannelModelPricing[ } } - // Collect web_search_emulation (only anthropic platform supports it) - const wsEmulation: Record = {} - for (const section of form.platforms) { - if (!section.enabled) continue - if (section.web_search_emulation && section.platform === 'anthropic') { - wsEmulation[section.platform] = true - } - } - if (Object.keys(wsEmulation).length > 0) { - featuresConfig.web_search_emulation = wsEmulation - } - - return { group_ids, model_pricing, model_mapping, features_config: featuresConfig } + return { group_ids, model_pricing, model_mapping } } function apiToForm(channel: Channel): PlatformSection[] { @@ -804,19 +977,13 @@ function apiToForm(channel: Channel): PlatformSection[] { intervals: apiIntervalsToForm(p.intervals || []) } as PricingFormEntry)) - // Read web_search_emulation from features_config - const fc = channel.features_config - const wsEmulation = fc?.web_search_emulation as Record | undefined - const webSearchEnabled = wsEmulation?.[platform] === true - sections.push({ platform, enabled: true, collapsed: false, group_ids: groupIds, model_mapping: { ...mapping }, - model_pricing: pricing, - web_search_emulation: webSearchEnabled, + model_pricing: pricing }) } @@ -841,10 +1008,10 @@ async function loadChannels() { if (ctrl.signal.aborted || abortController !== ctrl) return channels.value = response.items || [] pagination.total = response.total - } catch (error: unknown) { - const e = error as { name?: string; code?: string } - if (e?.name === 'AbortError' || e?.code === 'ERR_CANCELED') return - appStore.showError(extractApiErrorMessage(error, t('admin.channels.loadError', 'Failed to load channels'))) + } catch (error: any) { + if (error?.name === 'AbortError' || error?.code === 'ERR_CANCELED') return + appStore.showError(t('admin.channels.loadError', 'Failed to load channels')) + console.error('Error loading channels:', error) } finally { if (abortController === ctrl) { loading.value = false @@ -909,6 +1076,8 @@ function resetForm() { form.restrict_models = false form.billing_model_source = 'channel_mapped' form.platforms = [] + form.apply_pricing_to_account_stats = false + form.account_stats_pricing_rules = [] activeTab.value = 'basic' } @@ -926,6 +1095,23 @@ async function openEditDialog(channel: Channel) { form.status = channel.status form.restrict_models = channel.restrict_models || false form.billing_model_source = channel.billing_model_source || 'channel_mapped' + form.apply_pricing_to_account_stats = channel.apply_pricing_to_account_stats || false + form.account_stats_pricing_rules = (channel.account_stats_pricing_rules || []).map(rule => ({ + name: rule.name || '', + group_ids: [...(rule.group_ids || [])], + account_ids: [...(rule.account_ids || [])], + pricing: (rule.pricing || []).map(p => ({ + models: [...(p.models || [])], + billing_mode: p.billing_mode, + input_price: perTokenToMTok(p.input_price), + output_price: perTokenToMTok(p.output_price), + cache_write_price: perTokenToMTok(p.cache_write_price), + cache_read_price: perTokenToMTok(p.cache_read_price), + image_output_price: perTokenToMTok(p.image_output_price), + per_request_price: p.per_request_price, + intervals: apiIntervalsToForm(p.intervals || []) + } as PricingFormEntry)) + })) // Must load groups first so apiToForm can map groupID → platform await Promise.all([loadGroups(), loadAllChannelsForConflict()]) form.platforms = apiToForm(channel) @@ -1024,7 +1210,7 @@ async function handleSubmit() { } } - const { group_ids, model_pricing, model_mapping, features_config } = formToAPI() + const { group_ids, model_pricing, model_mapping } = formToAPI() submitting.value = true try { @@ -1038,7 +1224,8 @@ async function handleSubmit() { model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {}, billing_model_source: form.billing_model_source, restrict_models: form.restrict_models, - features_config, + apply_pricing_to_account_stats: form.apply_pricing_to_account_stats, + account_stats_pricing_rules: accountStatsRulesToAPI() } await adminAPI.channels.update(editingChannel.value.id, req) appStore.showSuccess(t('admin.channels.updateSuccess', 'Channel updated')) @@ -1051,17 +1238,20 @@ async function handleSubmit() { model_mapping: Object.keys(model_mapping).length > 0 ? model_mapping : {}, billing_model_source: form.billing_model_source, restrict_models: form.restrict_models, - features_config, + apply_pricing_to_account_stats: form.apply_pricing_to_account_stats, + account_stats_pricing_rules: accountStatsRulesToAPI() } await adminAPI.channels.create(req) appStore.showSuccess(t('admin.channels.createSuccess', 'Channel created')) } closeDialog() loadChannels() - } catch (error: unknown) { - appStore.showError(extractApiErrorMessage(error, editingChannel.value + } catch (error: any) { + const msg = error.response?.data?.detail || (editingChannel.value ? t('admin.channels.updateError', 'Failed to update channel') - : t('admin.channels.createError', 'Failed to create channel'))) + : t('admin.channels.createError', 'Failed to create channel')) + appStore.showError(msg) + console.error('Error saving channel:', error) } finally { submitting.value = false } @@ -1099,8 +1289,9 @@ async function confirmDelete() { showDeleteDialog.value = false deletingChannel.value = null loadChannels() - } catch (error: unknown) { - appStore.showError(extractApiErrorMessage(error, t('admin.channels.deleteError', 'Failed to delete channel'))) + } catch (error: any) { + appStore.showError(error.response?.data?.detail || t('admin.channels.deleteError', 'Failed to delete channel')) + console.error('Error deleting channel:', error) } } @@ -1108,7 +1299,6 @@ async function confirmDelete() { onMounted(() => { loadChannels() loadGroups() - loadWebSearchGlobalState() }) onUnmounted(() => {