feat(channel): 渠道管理系统 — 多模式定价 + 统一计费解析

Cherry-picked from release/custom-0.1.106: a9117600
This commit is contained in:
erio
2026-04-04 11:00:55 +08:00
parent b384570de3
commit 91c9b8d062
27 changed files with 3682 additions and 8 deletions

View File

@@ -0,0 +1,392 @@
package repository
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
type channelRepository struct {
db *sql.DB
}
// NewChannelRepository 创建渠道数据访问实例
func NewChannelRepository(db *sql.DB) service.ChannelRepository {
return &channelRepository{db: db}
}
// runInTx 在事务中执行 fn成功 commit失败 rollback。
func (r *channelRepository) runInTx(ctx context.Context, fn func(tx *sql.Tx) error) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
if err := fn(tx); err != nil {
return err
}
return tx.Commit()
}
func (r *channelRepository) Create(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
err := tx.QueryRowContext(ctx,
`INSERT INTO channels (name, description, status) VALUES ($1, $2, $3)
RETURNING id, created_at, updated_at`,
channel.Name, channel.Description, channel.Status,
).Scan(&channel.ID, &channel.CreatedAt, &channel.UpdatedAt)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("insert channel: %w", err)
}
// 设置分组关联
if len(channel.GroupIDs) > 0 {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 设置模型定价
if len(channel.ModelPricing) > 0 {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) GetByID(ctx context.Context, id int64) (*service.Channel, error) {
ch := &service.Channel{}
err := r.db.QueryRowContext(ctx,
`SELECT id, name, description, status, created_at, updated_at
FROM channels WHERE id = $1`, id,
).Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt)
if err == sql.ErrNoRows {
return nil, service.ErrChannelNotFound
}
if err != nil {
return nil, fmt.Errorf("get channel: %w", err)
}
groupIDs, err := r.GetGroupIDs(ctx, id)
if err != nil {
return nil, err
}
ch.GroupIDs = groupIDs
pricing, err := r.ListModelPricing(ctx, id)
if err != nil {
return nil, err
}
ch.ModelPricing = pricing
return ch, nil
}
func (r *channelRepository) Update(ctx context.Context, channel *service.Channel) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
result, err := tx.ExecContext(ctx,
`UPDATE channels SET name = $1, description = $2, status = $3, updated_at = NOW()
WHERE id = $4`,
channel.Name, channel.Description, channel.Status, channel.ID,
)
if err != nil {
if isUniqueViolation(err) {
return service.ErrChannelExists
}
return fmt.Errorf("update channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
// 更新分组关联
if channel.GroupIDs != nil {
if err := setGroupIDsTx(ctx, tx, channel.ID, channel.GroupIDs); err != nil {
return err
}
}
// 更新模型定价
if channel.ModelPricing != nil {
if err := replaceModelPricingTx(ctx, tx, channel.ID, channel.ModelPricing); err != nil {
return err
}
}
return nil
})
}
func (r *channelRepository) Delete(ctx context.Context, id int64) error {
result, err := r.db.ExecContext(ctx, `DELETE FROM channels WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete channel: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return service.ErrChannelNotFound
}
return nil
}
func (r *channelRepository) List(ctx context.Context, params pagination.PaginationParams, status, search string) ([]service.Channel, *pagination.PaginationResult, error) {
where := []string{"1=1"}
args := []any{}
argIdx := 1
if status != "" {
where = append(where, fmt.Sprintf("c.status = $%d", argIdx))
args = append(args, status)
argIdx++
}
if search != "" {
where = append(where, fmt.Sprintf("(c.name ILIKE $%d OR c.description ILIKE $%d)", argIdx, argIdx))
args = append(args, "%"+escapeLike(search)+"%")
argIdx++
}
whereClause := strings.Join(where, " AND ")
// 计数
var total int64
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM channels c WHERE %s", whereClause)
if err := r.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
return nil, nil, fmt.Errorf("count channels: %w", err)
}
pageSize := params.Limit() // 约束在 [1, 100]
page := params.Page
if page < 1 {
page = 1
}
offset := (page - 1) * pageSize
// 查询 channel 列表
dataQuery := fmt.Sprintf(
`SELECT c.id, c.name, c.description, c.status, c.created_at, c.updated_at
FROM channels c WHERE %s ORDER BY c.id DESC LIMIT $%d OFFSET $%d`,
whereClause, argIdx, argIdx+1,
)
args = append(args, pageSize, offset)
rows, err := r.db.QueryContext(ctx, dataQuery, args...)
if err != nil {
return nil, nil, fmt.Errorf("query channels: %w", err)
}
defer rows.Close()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, nil, fmt.Errorf("scan channel: %w", err)
}
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate channels: %w", err)
}
// 批量加载分组 ID 和模型定价(避免 N+1
if len(channelIDs) > 0 {
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
}
pages := 0
if total > 0 {
pages = int((total + int64(pageSize) - 1) / int64(pageSize))
}
paginationResult := &pagination.PaginationResult{
Total: total,
Page: page,
PageSize: pageSize,
Pages: pages,
}
return channels, paginationResult, nil
}
func (r *channelRepository) ListAll(ctx context.Context) ([]service.Channel, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, name, description, status, created_at, updated_at FROM channels ORDER BY id`,
)
if err != nil {
return nil, fmt.Errorf("query all channels: %w", err)
}
defer rows.Close()
var channels []service.Channel
var channelIDs []int64
for rows.Next() {
var ch service.Channel
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Description, &ch.Status, &ch.CreatedAt, &ch.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan channel: %w", err)
}
channels = append(channels, ch)
channelIDs = append(channelIDs, ch.ID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate channels: %w", err)
}
if len(channelIDs) == 0 {
return channels, nil
}
// 批量加载分组 ID
groupMap, err := r.batchLoadGroupIDs(ctx, channelIDs)
if err != nil {
return nil, err
}
// 批量加载模型定价
pricingMap, err := r.batchLoadModelPricing(ctx, channelIDs)
if err != nil {
return nil, err
}
for i := range channels {
channels[i].GroupIDs = groupMap[channels[i].ID]
channels[i].ModelPricing = pricingMap[channels[i].ID]
}
return channels, nil
}
// --- 批量加载辅助方法 ---
// batchLoadGroupIDs 批量加载多个渠道的分组 ID
func (r *channelRepository) batchLoadGroupIDs(ctx context.Context, channelIDs []int64) (map[int64][]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT channel_id, group_id FROM channel_groups
WHERE channel_id = ANY($1) ORDER BY channel_id, group_id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load group ids: %w", err)
}
defer rows.Close()
groupMap := make(map[int64][]int64, len(channelIDs))
for rows.Next() {
var channelID, groupID int64
if err := rows.Scan(&channelID, &groupID); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
groupMap[channelID] = append(groupMap[channelID], groupID)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return groupMap, nil
}
func (r *channelRepository) ExistsByName(ctx context.Context, name string) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1)`, name,
).Scan(&exists)
return exists, err
}
func (r *channelRepository) ExistsByNameExcluding(ctx context.Context, name string, excludeID int64) (bool, error) {
var exists bool
err := r.db.QueryRowContext(ctx,
`SELECT EXISTS(SELECT 1 FROM channels WHERE name = $1 AND id != $2)`, name, excludeID,
).Scan(&exists)
return exists, err
}
// --- 分组关联 ---
func (r *channelRepository) GetGroupIDs(ctx context.Context, channelID int64) ([]int64, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE channel_id = $1 ORDER BY group_id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("get group ids: %w", err)
}
defer rows.Close()
var ids []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan group id: %w", err)
}
ids = append(ids, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate group ids: %w", err)
}
return ids, nil
}
func (r *channelRepository) SetGroupIDs(ctx context.Context, channelID int64, groupIDs []int64) error {
return setGroupIDsTx(ctx, r.db, channelID, groupIDs)
}
func (r *channelRepository) GetChannelIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
var channelID int64
err := r.db.QueryRowContext(ctx,
`SELECT channel_id FROM channel_groups WHERE group_id = $1`, groupID,
).Scan(&channelID)
if err == sql.ErrNoRows {
return 0, nil
}
return channelID, err
}
func (r *channelRepository) GetGroupsInOtherChannels(ctx context.Context, channelID int64, groupIDs []int64) ([]int64, error) {
if len(groupIDs) == 0 {
return nil, nil
}
rows, err := r.db.QueryContext(ctx,
`SELECT group_id FROM channel_groups WHERE group_id = ANY($1) AND channel_id != $2`,
pq.Array(groupIDs), channelID,
)
if err != nil {
return nil, fmt.Errorf("get groups in other channels: %w", err)
}
defer rows.Close()
var conflicting []int64
for rows.Next() {
var id int64
if err := rows.Scan(&id); err != nil {
return nil, fmt.Errorf("scan conflicting group id: %w", err)
}
conflicting = append(conflicting, id)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate conflicting group ids: %w", err)
}
return conflicting, nil
}

View File

@@ -0,0 +1,285 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/lib/pq"
)
// --- 模型定价 ---
func (r *channelRepository) ListModelPricing(ctx context.Context, channelID int64) ([]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = $1 ORDER BY id`, channelID,
)
if err != nil {
return nil, fmt.Errorf("list model pricing: %w", err)
}
defer rows.Close()
result, pricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
if len(pricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, pricingIDs)
if err != nil {
return nil, err
}
for i := range result {
result[i].Intervals = intervalMap[result[i].ID]
}
}
return result, nil
}
func (r *channelRepository) CreateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
return createModelPricingExec(ctx, r.db, pricing)
}
func (r *channelRepository) UpdateModelPricing(ctx context.Context, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
result, err := r.db.ExecContext(ctx,
`UPDATE channel_model_pricing
SET models = $1, billing_mode = $2, input_price = $3, output_price = $4, cache_write_price = $5, cache_read_price = $6, image_output_price = $7, updated_at = NOW()
WHERE id = $8`,
modelsJSON, billingMode, pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice, pricing.ID,
)
if err != nil {
return fmt.Errorf("update model pricing: %w", err)
}
rows, _ := result.RowsAffected()
if rows == 0 {
return fmt.Errorf("pricing entry not found: %d", pricing.ID)
}
return nil
}
func (r *channelRepository) DeleteModelPricing(ctx context.Context, id int64) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE id = $1`, id)
if err != nil {
return fmt.Errorf("delete model pricing: %w", err)
}
return nil
}
func (r *channelRepository) ReplaceModelPricing(ctx context.Context, channelID int64, pricingList []service.ChannelModelPricing) error {
return r.runInTx(ctx, func(tx *sql.Tx) error {
return replaceModelPricingTx(ctx, tx, channelID, pricingList)
})
}
// --- 批量加载辅助方法 ---
// batchLoadModelPricing 批量加载多个渠道的模型定价(含区间)
func (r *channelRepository) batchLoadModelPricing(ctx context.Context, channelIDs []int64) (map[int64][]service.ChannelModelPricing, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price, created_at, updated_at
FROM channel_model_pricing WHERE channel_id = ANY($1) ORDER BY channel_id, id`,
pq.Array(channelIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load model pricing: %w", err)
}
defer rows.Close()
allPricing, allPricingIDs, err := scanModelPricingRows(rows)
if err != nil {
return nil, err
}
// 按 channelID 分组
pricingMap := make(map[int64][]service.ChannelModelPricing, len(channelIDs))
for _, p := range allPricing {
pricingMap[p.ChannelID] = append(pricingMap[p.ChannelID], p)
}
// 批量加载所有区间
if len(allPricingIDs) > 0 {
intervalMap, err := r.batchLoadIntervals(ctx, allPricingIDs)
if err != nil {
return nil, err
}
for chID := range pricingMap {
for i := range pricingMap[chID] {
pricingMap[chID][i].Intervals = intervalMap[pricingMap[chID][i].ID]
}
}
}
return pricingMap, nil
}
// batchLoadIntervals 批量加载多个定价条目的区间
func (r *channelRepository) batchLoadIntervals(ctx context.Context, pricingIDs []int64) (map[int64][]service.PricingInterval, error) {
rows, err := r.db.QueryContext(ctx,
`SELECT id, pricing_id, min_tokens, max_tokens, tier_label,
input_price, output_price, cache_write_price, cache_read_price,
per_request_price, sort_order, created_at, updated_at
FROM channel_pricing_intervals
WHERE pricing_id = ANY($1) ORDER BY pricing_id, sort_order, id`,
pq.Array(pricingIDs),
)
if err != nil {
return nil, fmt.Errorf("batch load intervals: %w", err)
}
defer rows.Close()
intervalMap := make(map[int64][]service.PricingInterval, len(pricingIDs))
for rows.Next() {
var iv service.PricingInterval
if err := rows.Scan(
&iv.ID, &iv.PricingID, &iv.MinTokens, &iv.MaxTokens, &iv.TierLabel,
&iv.InputPrice, &iv.OutputPrice, &iv.CacheWritePrice, &iv.CacheReadPrice,
&iv.PerRequestPrice, &iv.SortOrder, &iv.CreatedAt, &iv.UpdatedAt,
); err != nil {
return nil, fmt.Errorf("scan interval: %w", err)
}
intervalMap[iv.PricingID] = append(intervalMap[iv.PricingID], iv)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate intervals: %w", err)
}
return intervalMap, nil
}
// --- 共享 scan 辅助 ---
// scanModelPricingRows 扫描 model pricing 行,返回结果列表和 ID 列表
func scanModelPricingRows(rows *sql.Rows) ([]service.ChannelModelPricing, []int64, error) {
var result []service.ChannelModelPricing
var pricingIDs []int64
for rows.Next() {
var p service.ChannelModelPricing
var modelsJSON []byte
if err := rows.Scan(
&p.ID, &p.ChannelID, &modelsJSON, &p.BillingMode,
&p.InputPrice, &p.OutputPrice, &p.CacheWritePrice, &p.CacheReadPrice,
&p.ImageOutputPrice, &p.CreatedAt, &p.UpdatedAt,
); err != nil {
return nil, nil, fmt.Errorf("scan model pricing: %w", err)
}
if err := json.Unmarshal(modelsJSON, &p.Models); err != nil {
p.Models = []string{}
}
pricingIDs = append(pricingIDs, p.ID)
result = append(result, p)
}
if err := rows.Err(); err != nil {
return nil, nil, fmt.Errorf("iterate model pricing: %w", err)
}
return result, pricingIDs, nil
}
// --- 事务内辅助方法 ---
// dbExec 是 *sql.DB 和 *sql.Tx 共享的最小 SQL 执行接口
type dbExec interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
func setGroupIDsTx(ctx context.Context, exec dbExec, channelID int64, groupIDs []int64) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_groups WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old group associations: %w", err)
}
if len(groupIDs) == 0 {
return nil
}
_, err := exec.ExecContext(ctx,
`INSERT INTO channel_groups (channel_id, group_id)
SELECT $1, unnest($2::bigint[])`,
channelID, pq.Array(groupIDs),
)
if err != nil {
return fmt.Errorf("insert group associations: %w", err)
}
return nil
}
func createModelPricingExec(ctx context.Context, exec dbExec, pricing *service.ChannelModelPricing) error {
modelsJSON, err := json.Marshal(pricing.Models)
if err != nil {
return fmt.Errorf("marshal models: %w", err)
}
billingMode := pricing.BillingMode
if billingMode == "" {
billingMode = service.BillingModeToken
}
err = exec.QueryRowContext(ctx,
`INSERT INTO channel_model_pricing (channel_id, models, billing_mode, input_price, output_price, cache_write_price, cache_read_price, image_output_price)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8) RETURNING id, created_at, updated_at`,
pricing.ChannelID, modelsJSON, billingMode,
pricing.InputPrice, pricing.OutputPrice, pricing.CacheWritePrice, pricing.CacheReadPrice,
pricing.ImageOutputPrice,
).Scan(&pricing.ID, &pricing.CreatedAt, &pricing.UpdatedAt)
if err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
for i := range pricing.Intervals {
pricing.Intervals[i].PricingID = pricing.ID
if err := createIntervalExec(ctx, exec, &pricing.Intervals[i]); err != nil {
return err
}
}
return nil
}
func createIntervalExec(ctx context.Context, exec dbExec, iv *service.PricingInterval) error {
return exec.QueryRowContext(ctx,
`INSERT INTO channel_pricing_intervals
(pricing_id, min_tokens, max_tokens, tier_label, input_price, output_price, cache_write_price, cache_read_price, per_request_price, sort_order)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at`,
iv.PricingID, iv.MinTokens, iv.MaxTokens, iv.TierLabel,
iv.InputPrice, iv.OutputPrice, iv.CacheWritePrice, iv.CacheReadPrice,
iv.PerRequestPrice, iv.SortOrder,
).Scan(&iv.ID, &iv.CreatedAt, &iv.UpdatedAt)
}
func replaceModelPricingTx(ctx context.Context, exec dbExec, channelID int64, pricingList []service.ChannelModelPricing) error {
if _, err := exec.ExecContext(ctx, `DELETE FROM channel_model_pricing WHERE channel_id = $1`, channelID); err != nil {
return fmt.Errorf("delete old model pricing: %w", err)
}
for i := range pricingList {
pricingList[i].ChannelID = channelID
if err := createModelPricingExec(ctx, exec, &pricingList[i]); err != nil {
return fmt.Errorf("insert model pricing: %w", err)
}
}
return nil
}
// isUniqueViolation 检查 pq 唯一约束违反错误
func isUniqueViolation(err error) bool {
if pqErr, ok := err.(*pq.Error); ok {
return pqErr.Code == "23505"
}
return false
}
// escapeLike 转义 LIKE/ILIKE 模式中的特殊字符
func escapeLike(s string) string {
s = strings.ReplaceAll(s, `\`, `\\`)
s = strings.ReplaceAll(s, `%`, `\%`)
s = strings.ReplaceAll(s, `_`, `\_`)
return s
}

View File

@@ -74,6 +74,7 @@ var ProviderSet = wire.NewSet(
NewUserGroupRateRepository,
NewErrorPassthroughRepository,
NewTLSFingerprintProfileRepository,
NewChannelRepository,
// Cache implementations
NewGatewayCache,