feat(channel): 渠道管理系统 — 多模式定价 + 统一计费解析
Cherry-picked from release/custom-0.1.106: a9117600
This commit is contained in:
392
backend/internal/repository/channel_repo.go
Normal file
392
backend/internal/repository/channel_repo.go
Normal file
@@ -0,0 +1,392 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
type channelRepository struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
// NewChannelRepository 创建渠道数据访问实例
|
||||
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
|
||||
return &channelRepository{db: db}
|
||||
}
|
||||
|
||||
// runInTx 在事务中执行 fn,成功 commit,失败 rollback。
|
||||
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
return err
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
err := tx.QueryRowContext(ctx,
|
||||
`INSERT INTO channels (name, description, status) VALUES ($1, $2, $3)
|
||||
RETURNING id, created_at, updated_at`,
|
||||
channel.Name, channel.Description, channel.Status,
|
||||
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return service.ErrChannelExists
|
||||
}
|
||||
return fmt.Errorf("insert channel: %w", err)
|
||||
}
|
||||
|
||||
// 设置分组关联
|
||||
if len(channel.GroupIDs) > 0 {
|
||||
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 设置模型定价
|
||||
if len(channel.ModelPricing) > 0 {
|
||||
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
|
||||
ch := &service.Channel{}
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT id, name, description, status, created_at, updated_at
|
||||
FROM channels WHERE id = $1`, id,
|
||||
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, service.ErrChannelNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get channel: %w", err)
|
||||
}
|
||||
|
||||
groupIDs, err := r.GetGroupIDs(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.GroupIDs = groupIDs
|
||||
|
||||
pricing, err := r.ListModelPricing(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ch.ModelPricing = pricing
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
result, err := tx.ExecContext(ctx,
|
||||
`UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW()
|
||||
WHERE id = $4`,
|
||||
channel.Name, channel.Description, channel.Status, channel.ID,
|
||||
)
|
||||
if err != nil {
|
||||
if isUniqueViolation(err) {
|
||||
return service.ErrChannelExists
|
||||
}
|
||||
return fmt.Errorf("update channel: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return service.ErrChannelNotFound
|
||||
}
|
||||
|
||||
// 更新分组关联
|
||||
if channel.GroupIDs != nil {
|
||||
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// 更新模型定价
|
||||
if channel.ModelPricing != nil {
|
||||
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
|
||||
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete channel: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return service.ErrChannelNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
|
||||
where := []string{"1=1"}
|
||||
args := []any{}
|
||||
argIdx := 1
|
||||
|
||||
if status != "" {
|
||||
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
|
||||
args = append(args, status)
|
||||
argIdx++
|
||||
}
|
||||
if search != "" {
|
||||
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
|
||||
args = append(args, "%"+escapeLike(search)+"%")
|
||||
argIdx++
|
||||
}
|
||||
|
||||
whereClause := strings.Join(where, " AND ")
|
||||
|
||||
// 计数
|
||||
var total int64
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
|
||||
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
|
||||
return nil, nil, fmt.Errorf("count channels: %w", err)
|
||||
}
|
||||
|
||||
pageSize := params.Limit() // 约束在 [1, 100]
|
||||
page := params.Page
|
||||
if page < 1 {
|
||||
page = 1
|
||||
}
|
||||
offset := (page - 1) * pageSize
|
||||
|
||||
// 查询 channel 列表
|
||||
dataQuery := fmt.Sprintf(
|
||||
`SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at
|
||||
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`,
|
||||
whereClause, argIdx, argIdx+1,
|
||||
)
|
||||
args = append(args, pageSize, offset)
|
||||
|
||||
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("query channels: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []service.Channel
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate channels: %w", err)
|
||||
}
|
||||
|
||||
// 批量加载分组 ID 和模型定价(避免 N+1)
|
||||
if len(channelIDs) > 0 {
|
||||
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
pages := 0
|
||||
if total > 0 {
|
||||
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
|
||||
}
|
||||
|
||||
paginationResult := &pagination.PaginationResult{
|
||||
Total: total,
|
||||
Page: page,
|
||||
PageSize: pageSize,
|
||||
Pages: pages,
|
||||
}
|
||||
|
||||
return channels, paginationResult, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("query all channels: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var channels []service.Channel
|
||||
var channelIDs []int64
|
||||
for rows.Next() {
|
||||
var ch service.Channel
|
||||
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan channel: %w", err)
|
||||
}
|
||||
channels = append(channels, ch)
|
||||
channelIDs = append(channelIDs, ch.ID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate channels: %w", err)
|
||||
}
|
||||
|
||||
if len(channelIDs) == 0 {
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// 批量加载分组 ID
|
||||
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 批量加载模型定价
|
||||
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i := range channels {
|
||||
channels[i].GroupIDs = groupMap[channels[i].ID]
|
||||
channels[i].ModelPricing = pricingMap[channels[i].ID]
|
||||
}
|
||||
|
||||
return channels, nil
|
||||
}
|
||||
|
||||
// --- 批量加载辅助方法 ---
|
||||
|
||||
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
|
||||
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT channel_id, group_id FROM channel_groups
|
||||
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load group ids: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
groupMap := make(map[int64][]int64, len(channelIDs))
|
||||
for rows.Next() {
|
||||
var channelID, groupID int64
|
||||
if err := rows.Scan(&channelID, &groupID); err != nil {
|
||||
return nil, fmt.Errorf("scan group id: %w", err)
|
||||
}
|
||||
groupMap[channelID] = append(groupMap[channelID], groupID)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group ids: %w", err)
|
||||
}
|
||||
return groupMap, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
|
||||
var exists bool
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
|
||||
).Scan(&exists)
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// --- 分组关联 ---
|
||||
|
||||
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group ids: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var ids []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan group id: %w", err)
|
||||
}
|
||||
ids = append(ids, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate group ids: %w", err)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
var channelID int64
|
||||
err := r.db.QueryRowContext(ctx,
|
||||
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
|
||||
).Scan(&channelID)
|
||||
if err == sql.ErrNoRows {
|
||||
return 0, nil
|
||||
}
|
||||
return channelID, err
|
||||
}
|
||||
|
||||
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
|
||||
pq.Array(groupIDs), channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get groups in other channels: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var conflicting []int64
|
||||
for rows.Next() {
|
||||
var id int64
|
||||
if err := rows.Scan(&id); err != nil {
|
||||
return nil, fmt.Errorf("scan conflicting group id: %w", err)
|
||||
}
|
||||
conflicting = append(conflicting, id)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
|
||||
}
|
||||
return conflicting, nil
|
||||
}
|
||||
285
backend/internal/repository/channel_repo_pricing.go
Normal file
285
backend/internal/repository/channel_repo_pricing.go
Normal file
@@ -0,0 +1,285 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// --- 模型定价 ---
|
||||
|
||||
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list model pricing: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
result, pricingIDs, err := scanModelPricingRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(pricingIDs) > 0 {
|
||||
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range result {
|
||||
result[i].Intervals = intervalMap[result[i].ID]
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
||||
return createModelPricingExec(ctx, r.db, pricing)
|
||||
}
|
||||
|
||||
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
result, err := r.db.ExecContext(ctx,
|
||||
`UPDATE channel_model_pricing
|
||||
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW()
|
||||
WHERE id = $8`,
|
||||
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice, pricing.ID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update model pricing: %w", err)
|
||||
}
|
||||
rows, _ := result.RowsAffected()
|
||||
if rows == 0 {
|
||||
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete model pricing: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
|
||||
return r.runInTx(ctx, func(tx *sql.Tx) error {
|
||||
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
|
||||
})
|
||||
}
|
||||
|
||||
// --- 批量加载辅助方法 ---
|
||||
|
||||
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
|
||||
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
|
||||
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
|
||||
pq.Array(channelIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load model pricing: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 按 channelID 分组
|
||||
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
|
||||
for _, p := range allPricing {
|
||||
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
|
||||
}
|
||||
|
||||
// 批量加载所有区间
|
||||
if len(allPricingIDs) > 0 {
|
||||
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for chID := range pricingMap {
|
||||
for i := range pricingMap[chID] {
|
||||
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pricingMap, nil
|
||||
}
|
||||
|
||||
// batchLoadIntervals 批量加载多个定价条目的区间
|
||||
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
|
||||
rows, err := r.db.QueryContext(ctx,
|
||||
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
|
||||
input_price, output_price, cache_write_price, cache_read_price,
|
||||
per_request_price, sort_order, created_at, updated_at
|
||||
FROM channel_pricing_intervals
|
||||
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
|
||||
pq.Array(pricingIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("batch load intervals: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
|
||||
for rows.Next() {
|
||||
var iv service.PricingInterval
|
||||
if err := rows.Scan(
|
||||
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
|
||||
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
|
||||
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, fmt.Errorf("scan interval: %w", err)
|
||||
}
|
||||
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, fmt.Errorf("iterate intervals: %w", err)
|
||||
}
|
||||
return intervalMap, nil
|
||||
}
|
||||
|
||||
// --- 共享 scan 辅助 ---
|
||||
|
||||
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
|
||||
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
|
||||
var result []service.ChannelModelPricing
|
||||
var pricingIDs []int64
|
||||
for rows.Next() {
|
||||
var p service.ChannelModelPricing
|
||||
var modelsJSON []byte
|
||||
if err := rows.Scan(
|
||||
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode,
|
||||
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
|
||||
&p.ImageOutputPrice, &p.CreatedAt, &p.UpdatedAt,
|
||||
); err != nil {
|
||||
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
|
||||
p.Models = []string{}
|
||||
}
|
||||
pricingIDs = append(pricingIDs, p.ID)
|
||||
result = append(result, p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
|
||||
}
|
||||
return result, pricingIDs, nil
|
||||
}
|
||||
|
||||
// --- 事务内辅助方法 ---
|
||||
|
||||
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
|
||||
type dbExec interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
|
||||
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
|
||||
return fmt.Errorf("delete old group associations: %w", err)
|
||||
}
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
_, err := exec.ExecContext(ctx,
|
||||
`INSERT INTO channel_groups (channel_id, group_id)
|
||||
SELECT $1, unnest($2::bigint[])`,
|
||||
channelID, pq.Array(groupIDs),
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert group associations: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
|
||||
modelsJSON, err := json.Marshal(pricing.Models)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal models: %w", err)
|
||||
}
|
||||
billingMode := pricing.BillingMode
|
||||
if billingMode == "" {
|
||||
billingMode = service.BillingModeToken
|
||||
}
|
||||
err = exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`,
|
||||
pricing.ChannelID, modelsJSON, billingMode,
|
||||
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
|
||||
pricing.ImageOutputPrice,
|
||||
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("insert model pricing: %w", err)
|
||||
}
|
||||
|
||||
for i := range pricing.Intervals {
|
||||
pricing.Intervals[i].PricingID = pricing.ID
|
||||
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
|
||||
return exec.QueryRowContext(ctx,
|
||||
`INSERT INTO channel_pricing_intervals
|
||||
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
|
||||
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
|
||||
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
|
||||
iv.PerRequestPrice, iv.SortOrder,
|
||||
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
|
||||
}
|
||||
|
||||
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
|
||||
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
|
||||
return fmt.Errorf("delete old model pricing: %w", err)
|
||||
}
|
||||
for i := range pricingList {
|
||||
pricingList[i].ChannelID = channelID
|
||||
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
|
||||
return fmt.Errorf("insert model pricing: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isUniqueViolation 检查 pq 唯一约束违反错误
|
||||
func isUniqueViolation(err error) bool {
|
||||
if pqErr, ok := err.(*pq.Error); ok {
|
||||
return pqErr.Code == "23505"
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
|
||||
func escapeLike(s string) string {
|
||||
s = strings.ReplaceAll(s, `\`, `\\`)
|
||||
s = strings.ReplaceAll(s, `%`, `\%`)
|
||||
s = strings.ReplaceAll(s, `_`, `\_`)
|
||||
return s
|
||||
}
|
||||
@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
|
||||
NewUserGroupRateRepository,
|
||||
NewErrorPassthroughRepository,
|
||||
NewTLSFingerprintProfileRepository,
|
||||
NewChannelRepository,
|
||||
|
||||
// Cache implementations
|
||||
NewGatewayCache,
|
||||
|
||||
Reference in New Issue
Block a user