fix: merge 5 PR-related improvements
- gateway_handler: pass ParsedRequest to RecordUsage + set in gin.Context - channel_handler: add FeaturesConfig to CRUD (WebSearch channel toggle) - channel_repo: features_config JSONB persistence (Create/Get/Update/List) - security_headers: add Stripe CSP domains (script-src + frame-src)
This commit is contained in:
@@ -41,10 +41,14 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
`INSERT INTO channels (name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -80,11 +84,11 @@ func (r *channelRepository) Create(ctx context.Context, channel *service.Channel
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
var modelMappingJSON []byte
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
@@ -92,6 +96,7 @@ func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Cha
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
@@ -120,10 +125,14 @@ func (r *channelRepository) Update(ctx context.Context, channel *service.Channel
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
featuresConfigJSON, err := marshalFeaturesConfig(channel.FeaturesConfig)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, apply_pricing_to_account_stats = $8, updated_at = NOW()
|
||||
WHERE id = $9`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, channel.ApplyPricingToAccountStats, channel.ID,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, model_mapping = $4, billing_model_source = $5, restrict_models = $6, features = $7, features_config = $8, apply_pricing_to_account_stats = $9, updated_at = NOW()
|
||||
WHERE id = $10`,
|
||||
channel.Name, channel.Description, channel.Status, modelMappingJSON, channel.BillingModelSource, channel.RestrictModels, channel.Features, featuresConfigJSON, channel.ApplyPricingToAccountStats, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
@@ -207,7 +216,7 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
|
||||
`SELECT c.id, c.name, c.description, c.status, c.model_mapping, c.billing_model_source, c.restrict_models, c.features, c.features_config, c.apply_pricing_to_account_stats, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY %s LIMIT $%d OFFSET $%d`,
|
||||
whereClause, channelListOrderBy(params), argIdx, argIdx+1,
|
||||
)
|
||||
@@ -223,11 +232,12 @@ func (r *channelRepository) List(ctx context.Context, params pagination.Paginati
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -298,7 +308,7 @@ func channelListOrderBy(params pagination.PaginationParams) string {
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
|
||||
`SELECT id, name, description, status, model_mapping, billing_model_source, restrict_models, features, features_config, apply_pricing_to_account_stats, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
@@ -309,11 +319,12 @@ func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, err
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
var modelMappingJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
var modelMappingJSON, featuresConfigJSON []byte
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &modelMappingJSON, &ch.BillingModelSource, &ch.RestrictModels, &ch.Features, &featuresConfigJSON, &ch.ApplyPricingToAccountStats, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
ch.ModelMapping = unmarshalModelMapping(modelMappingJSON)
|
||||
ch.FeaturesConfig = unmarshalFeaturesConfig(featuresConfigJSON)
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
@@ -488,6 +499,28 @@ func unmarshalModelMapping(data []byte) map[string]map[string]string {
|
||||
return m
|
||||
}
|
||||
|
||||
func marshalFeaturesConfig(m map[string]any) ([]byte, error) {
|
||||
if len(m) == 0 {
|
||||
return []byte("{}"), nil
|
||||
}
|
||||
data, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal features_config: %w", err)
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func unmarshalFeaturesConfig(data []byte) map[string]any {
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
var m map[string]any
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return nil
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// GetGroupPlatforms 批量查询分组 ID 对应的平台
|
||||
func (r *channelRepository) GetGroupPlatforms(ctx context.Context, groupIDs []int64) (map[int64]string, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
|
||||
Reference in New Issue
Block a user