diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 88d27c47..9151d018 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -35,6 +35,7 @@ type createChannelRequest struct { BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels bool `json:"restrict_models"` Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` ApplyPricingToAccountStats bool `json:"apply_pricing_to_account_stats"` AccountStatsPricingRules []accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } @@ -49,6 +50,7 @@ type updateChannelRequest struct { BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"` RestrictModels *bool `json:"restrict_models"` Features *string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` ApplyPricingToAccountStats *bool `json:"apply_pricing_to_account_stats"` AccountStatsPricingRules *[]accountStatsPricingRuleRequest `json:"account_stats_pricing_rules"` } @@ -93,6 +95,7 @@ type channelResponse struct { BillingModelSource string `json:"billing_model_source"` RestrictModels bool `json:"restrict_models"` Features string `json:"features"` + FeaturesConfig map[string]any `json:"features_config"` GroupIDs []int64 `json:"group_ids"` ModelPricing []channelModelPricingResponse `json:"model_pricing"` ModelMapping map[string]map[string]string `json:"model_mapping"` @@ -148,6 +151,7 @@ func channelToResponse(ch *service.Channel) *channelResponse { Status: ch.Status, RestrictModels: ch.RestrictModels, Features: ch.Features, + FeaturesConfig: ch.FeaturesConfig, GroupIDs: ch.GroupIDs, ModelMapping: ch.ModelMapping, CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"), @@ -379,6 +383,7 @@ func (h *ChannelHandler) Create(c *gin.Context) { BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, Features: req.Features, + FeaturesConfig: req.FeaturesConfig, ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, AccountStatsPricingRules: statsRules, }) @@ -414,6 +419,7 @@ func (h *ChannelHandler) Update(c *gin.Context) { BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, Features: req.Features, + FeaturesConfig: req.FeaturesConfig, ApplyPricingToAccountStats: req.ApplyPricingToAccountStats, } if req.ModelPricing != nil { diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 8ec54420..30065463 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -473,6 +473,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: apiKey, User: apiKey.User, Account: account, @@ -675,6 +676,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { } // 转发请求 - 根据账号平台分流 + c.Set("parsed_request", parsedReq) var result *service.ForwardResult requestCtx := c.Request.Context() if fs.SwitchCount > 0 { @@ -813,6 +815,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) { h.submitUsageRecordTask(func(ctx context.Context) { if err := h.gatewayService.RecordUsage(ctx, &service.RecordUsageInput{ Result: result, + ParsedRequest: parsedReq, APIKey: currentAPIKey, User: currentAPIKey.User, Account: account, diff --git a/backend/internal/repository/channel_repo.go b/backend/internal/repository/channel_repo.go index 583ce895..2cb90aab 100644 --- a/backend/internal/repository/channel_repo.go +++ b/backend/internal/repository/channel_repo.go @@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } err = tx.QueryRowContext(ctx, - `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + `INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id, created_at, updated_at`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, ).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt) if err != nil { if isUniqueViolation(err) { @@ -80,11 +84,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) { ch := &service.Channel{} - var modelMappingJSON []byte + var modelMappingJSON, featuresConfigJSON []byte err := r.db.QueryRowContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels WHERE id = $1`, id, - ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) + ).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt) if err == sql.ErrNoRows { return nil, service.ErrChannelNotFound } @@ -92,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha return nil, fmt.Errorf("get channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) groupIDs, err := r.GetGroupIDs(ctx, id) if err != nil { @@ -120,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel if err != nil { return err } + featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig) + if err != nil { + return err + } result, err := tx.ExecContext(ctx, - `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW() - WHERE id = $9`, - channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID, + `UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW() + WHERE id = $10`, + channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID, ) if err != nil { if isUniqueViolation(err) { @@ -207,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati // 查询 channel 列表 dataQuery := fmt.Sprintf( - `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at + `SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`, whereClause, channelListOrderBy(params), argIdx, argIdx+1, ) @@ -223,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -298,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string { func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) { rows, err := r.db.QueryContext(ctx, - `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, + `SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`, ) if err != nil { return nil, fmt.Errorf("query all channels: %w", err) @@ -309,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err var channelIDs []int64 for rows.Next() { var ch service.Channel - var modelMappingJSON []byte - if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { + var modelMappingJSON, featuresConfigJSON []byte + if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil { return nil, fmt.Errorf("scan channel: %w", err) } ch.ModelMapping = unmarshalModelMapping(modelMappingJSON) + ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON) channels = append(channels, ch) channelIDs = append(channelIDs, ch.ID) } @@ -488,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string { return m } +func marshalFeaturesConfig(m map[string]any) ([]byte, error) { + if len(m) == 0 { + return []byte("{}"), nil + } + data, err := json.Marshal(m) + if err != nil { + return nil, fmt.Errorf("marshal features_config: %w", err) + } + return data, nil +} + +func unmarshalFeaturesConfig(data []byte) map[string]any { + if len(data) == 0 { + return nil + } + var m map[string]any + if err := json.Unmarshal(data, &m); err != nil { + return nil + } + return m +} + // GetGroupPlatforms 批量查询分组 ID 对应的平台 func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) { if len(groupIDs) == 0 { diff --git a/backend/internal/server/middleware/security_headers.go b/backend/internal/server/middleware/security_headers.go index 73210bfc..7021ab2e 100644 --- a/backend/internal/server/middleware/security_headers.go +++ b/backend/internal/server/middleware/security_headers.go @@ -18,6 +18,8 @@ const ( NonceTemplate = "__CSP_NONCE__" // CloudflareInsightsDomain is the domain for Cloudflare Web Analytics CloudflareInsightsDomain = "https://static.cloudflareinsights.com" + // StripeDomain is the domain for Stripe.js SDK + StripeDomain = "https://*.stripe.com" ) // GenerateNonce generates a cryptographically secure random nonce. @@ -97,8 +99,9 @@ func isAPIRoutePath(c *gin.Context) bool { strings.HasPrefix(path, "/responses") } -// enhanceCSPPolicy ensures the CSP policy includes nonce support and Cloudflare Insights domain. -// This allows the application to work correctly even if the config file has an older CSP policy. +// enhanceCSPPolicy ensures the CSP policy includes nonce support, Cloudflare Insights, +// and Stripe.js domains. This allows the application to work correctly even if the +// config file has an older CSP policy. func enhanceCSPPolicy(policy string) string { // Add nonce placeholder to script-src if not present if !strings.Contains(policy, NonceTemplate) && !strings.Contains(policy, "'nonce-") { @@ -110,6 +113,12 @@ func enhanceCSPPolicy(policy string) string { policy = addToDirective(policy, "script-src", CloudflareInsightsDomain) } + // Add Stripe.js domain to script-src and frame-src if not present + if !strings.Contains(policy, "stripe.com") { + policy = addToDirective(policy, "script-src", StripeDomain) + policy = addToDirective(policy, "frame-src", StripeDomain) + } + return policy }