- Split Update (98→25 lines), buildCache (54→20 lines), and Create (51→25 lines) into focused sub-functions: applyUpdateInput, checkGroupConflicts, fetchChannelData, populateChannelCache, storeErrorCache, getOldGroupIDs, and invalidateAuthCacheForGroups. - Extract validateChannelConfig to eliminate the duplicated validation calls in Create and Update. - Move validatePricingBillingMode from the handler to the service layer for proper separation of concerns. - Add error logging to IsModelRestricted (it previously swallowed errors silently). - Add 12 new tests: ToUsageFields, billing-mode validation, antigravity wildcard-mapping isolation, and Create/Update mapping-conflict integration.
388 lines
14 KiB
Go
388 lines
14 KiB
Go
package admin
|
|
|
|
import (
|
|
"strconv"
|
|
"strings"
|
|
|
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/response"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// ChannelHandler handles admin channel management
// (the /api/v1/admin/channels endpoints).
//
// Among the handlers shown here, channelService backs List, GetByID, Create,
// Update and Delete, while billingService is only used by
// GetModelDefaultPricing.
type ChannelHandler struct {
	channelService *service.ChannelService
	billingService *service.BillingService
}
|
|
|
|
// NewChannelHandler creates a new admin channel handler
|
|
func NewChannelHandler(channelService *service.ChannelService, billingService *service.BillingService) *ChannelHandler {
|
|
return &ChannelHandler{channelService: channelService, billingService: billingService}
|
|
}
|
|
|
|
// --- Request / Response types ---
|
|
|
|
type createChannelRequest struct {
|
|
Name string `json:"name" binding:"required,max=100"`
|
|
Description string `json:"description"`
|
|
GroupIDs []int64 `json:"group_ids"`
|
|
ModelPricing []channelModelPricingRequest `json:"model_pricing"`
|
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
|
RestrictModels bool `json:"restrict_models"`
|
|
}
|
|
|
|
type updateChannelRequest struct {
|
|
Name string `json:"name" binding:"omitempty,max=100"`
|
|
Description *string `json:"description"`
|
|
Status string `json:"status" binding:"omitempty,oneof=active disabled"`
|
|
GroupIDs *[]int64 `json:"group_ids"`
|
|
ModelPricing *[]channelModelPricingRequest `json:"model_pricing"`
|
|
ModelMapping map[string]map[string]string `json:"model_mapping"`
|
|
BillingModelSource string `json:"billing_model_source" binding:"omitempty,oneof=requested upstream channel_mapped"`
|
|
RestrictModels *bool `json:"restrict_models"`
|
|
}
|
|
|
|
// channelModelPricingRequest is one pricing rule in a create/update request;
// pricingRequestToService converts it to service.ChannelModelPricing.
//
// Note: go-playground/validator only evaluates these binding tags when the
// slice field that contains this struct is tagged with `dive`; it does not
// descend into slice elements otherwise.
type channelModelPricingRequest struct {
	// Platform is optional (max 50 chars); pricingRequestToService defaults an
	// empty value to service.PlatformAnthropic.
	Platform string `json:"platform" binding:"omitempty,max=50"`
	// Models lists the model names this rule applies to (1..100 entries).
	Models []string `json:"models" binding:"required,min=1,max=100"`
	// BillingMode is token, per_request or image; pricingRequestToService
	// defaults an empty value to service.BillingModeToken.
	BillingMode string `json:"billing_mode" binding:"omitempty,oneof=token per_request image"`
	// Optional prices; nil means "not provided", non-nil values must be >= 0.
	InputPrice       *float64 `json:"input_price" binding:"omitempty,min=0"`
	OutputPrice      *float64 `json:"output_price" binding:"omitempty,min=0"`
	CacheWritePrice  *float64 `json:"cache_write_price" binding:"omitempty,min=0"`
	CacheReadPrice   *float64 `json:"cache_read_price" binding:"omitempty,min=0"`
	ImageOutputPrice *float64 `json:"image_output_price" binding:"omitempty,min=0"`
	PerRequestPrice  *float64 `json:"per_request_price" binding:"omitempty,min=0"`
	// Intervals carries no binding tag, so its elements are not validated.
	Intervals []pricingIntervalRequest `json:"intervals"`
}
|
|
|
|
// pricingIntervalRequest is one pricing tier inside a
// channelModelPricingRequest. It has no binding tags, so no request-level
// validation is applied to these fields.
type pricingIntervalRequest struct {
	// MinTokens/MaxTokens bound the tier; MaxTokens may be omitted (nil),
	// presumably meaning "no upper bound" (TODO confirm in the service layer).
	MinTokens int  `json:"min_tokens"`
	MaxTokens *int `json:"max_tokens"`
	// TierLabel is an optional label; the response omits it when empty.
	TierLabel string `json:"tier_label"`
	// Optional per-tier prices; nil means the key was absent from the request.
	InputPrice      *float64 `json:"input_price"`
	OutputPrice     *float64 `json:"output_price"`
	CacheWritePrice *float64 `json:"cache_write_price"`
	CacheReadPrice  *float64 `json:"cache_read_price"`
	PerRequestPrice *float64 `json:"per_request_price"`
	// SortOrder is copied through unchanged to service.PricingInterval.
	SortOrder int `json:"sort_order"`
}
|
|
|
|
// channelResponse is the JSON shape returned for a channel by the admin
// endpoints; it is built by channelToResponse. Field order here is the field
// order of the encoded JSON object.
type channelResponse struct {
	ID          int64  `json:"id"`
	Name        string `json:"name"`
	Description string `json:"description"`
	Status      string `json:"status"`
	// BillingModelSource is reported as channel_mapped when the stored value
	// is empty (see channelToResponse).
	BillingModelSource string `json:"billing_model_source"`
	RestrictModels     bool   `json:"restrict_models"`
	// GroupIDs, ModelPricing and ModelMapping are populated with empty values
	// by channelToResponse instead of nil, so they encode as []/{} not null.
	GroupIDs     []int64                       `json:"group_ids"`
	ModelPricing []channelModelPricingResponse `json:"model_pricing"`
	ModelMapping map[string]map[string]string  `json:"model_mapping"`
	// CreatedAt/UpdatedAt are pre-formatted timestamp strings.
	CreatedAt string `json:"created_at"`
	UpdatedAt string `json:"updated_at"`
}
|
|
|
|
// channelModelPricingResponse is the JSON shape of one pricing rule in a
// channelResponse; it is built by pricingToResponse, which fills in default
// Platform/BillingMode values and replaces nil Models with an empty slice.
type channelModelPricingResponse struct {
	ID          int64    `json:"id"`
	Platform    string   `json:"platform"`
	Models      []string `json:"models"`
	BillingMode string   `json:"billing_mode"`
	// Prices are pointers so an unset price encodes as null rather than 0.
	InputPrice       *float64                  `json:"input_price"`
	OutputPrice      *float64                  `json:"output_price"`
	CacheWritePrice  *float64                  `json:"cache_write_price"`
	CacheReadPrice   *float64                  `json:"cache_read_price"`
	ImageOutputPrice *float64                  `json:"image_output_price"`
	PerRequestPrice  *float64                  `json:"per_request_price"`
	Intervals        []pricingIntervalResponse `json:"intervals"`
}
|
|
|
|
// pricingIntervalResponse is the JSON shape of one pricing tier; it is built
// by intervalToResponse as a field-for-field copy of service.PricingInterval.
type pricingIntervalResponse struct {
	ID        int64 `json:"id"`
	MinTokens int   `json:"min_tokens"`
	// MaxTokens encodes as null when unset.
	MaxTokens *int `json:"max_tokens"`
	// TierLabel is dropped from the JSON output when empty.
	TierLabel       string   `json:"tier_label,omitempty"`
	InputPrice      *float64 `json:"input_price"`
	OutputPrice     *float64 `json:"output_price"`
	CacheWritePrice *float64 `json:"cache_write_price"`
	CacheReadPrice  *float64 `json:"cache_read_price"`
	PerRequestPrice *float64 `json:"per_request_price"`
	SortOrder       int      `json:"sort_order"`
}
|
|
|
|
func channelToResponse(ch *service.Channel) *channelResponse {
|
|
if ch == nil {
|
|
return nil
|
|
}
|
|
resp := &channelResponse{
|
|
ID: ch.ID,
|
|
Name: ch.Name,
|
|
Description: ch.Description,
|
|
Status: ch.Status,
|
|
RestrictModels: ch.RestrictModels,
|
|
GroupIDs: ch.GroupIDs,
|
|
ModelMapping: ch.ModelMapping,
|
|
CreatedAt: ch.CreatedAt.Format("2006-01-02T15:04:05Z"),
|
|
UpdatedAt: ch.UpdatedAt.Format("2006-01-02T15:04:05Z"),
|
|
}
|
|
resp.BillingModelSource = ch.BillingModelSource
|
|
if resp.BillingModelSource == "" {
|
|
resp.BillingModelSource = service.BillingModelSourceChannelMapped
|
|
}
|
|
if resp.GroupIDs == nil {
|
|
resp.GroupIDs = []int64{}
|
|
}
|
|
if resp.ModelMapping == nil {
|
|
resp.ModelMapping = map[string]map[string]string{}
|
|
}
|
|
|
|
resp.ModelPricing = make([]channelModelPricingResponse, 0, len(ch.ModelPricing))
|
|
for _, p := range ch.ModelPricing {
|
|
resp.ModelPricing = append(resp.ModelPricing, pricingToResponse(&p))
|
|
}
|
|
return resp
|
|
}
|
|
|
|
func pricingToResponse(p *service.ChannelModelPricing) channelModelPricingResponse {
|
|
models := p.Models
|
|
if models == nil {
|
|
models = []string{}
|
|
}
|
|
billingMode := string(p.BillingMode)
|
|
if billingMode == "" {
|
|
billingMode = string(service.BillingModeToken)
|
|
}
|
|
platform := p.Platform
|
|
if platform == "" {
|
|
platform = service.PlatformAnthropic
|
|
}
|
|
intervals := make([]pricingIntervalResponse, 0, len(p.Intervals))
|
|
for _, iv := range p.Intervals {
|
|
intervals = append(intervals, intervalToResponse(iv))
|
|
}
|
|
return channelModelPricingResponse{
|
|
ID: p.ID,
|
|
Platform: platform,
|
|
Models: models,
|
|
BillingMode: billingMode,
|
|
InputPrice: p.InputPrice,
|
|
OutputPrice: p.OutputPrice,
|
|
CacheWritePrice: p.CacheWritePrice,
|
|
CacheReadPrice: p.CacheReadPrice,
|
|
ImageOutputPrice: p.ImageOutputPrice,
|
|
PerRequestPrice: p.PerRequestPrice,
|
|
Intervals: intervals,
|
|
}
|
|
}
|
|
|
|
func intervalToResponse(iv service.PricingInterval) pricingIntervalResponse {
|
|
return pricingIntervalResponse{
|
|
ID: iv.ID,
|
|
MinTokens: iv.MinTokens,
|
|
MaxTokens: iv.MaxTokens,
|
|
TierLabel: iv.TierLabel,
|
|
InputPrice: iv.InputPrice,
|
|
OutputPrice: iv.OutputPrice,
|
|
CacheWritePrice: iv.CacheWritePrice,
|
|
CacheReadPrice: iv.CacheReadPrice,
|
|
PerRequestPrice: iv.PerRequestPrice,
|
|
SortOrder: iv.SortOrder,
|
|
}
|
|
}
|
|
|
|
func pricingRequestToService(reqs []channelModelPricingRequest) []service.ChannelModelPricing {
|
|
result := make([]service.ChannelModelPricing, 0, len(reqs))
|
|
for _, r := range reqs {
|
|
billingMode := service.BillingMode(r.BillingMode)
|
|
if billingMode == "" {
|
|
billingMode = service.BillingModeToken
|
|
}
|
|
platform := r.Platform
|
|
if platform == "" {
|
|
platform = service.PlatformAnthropic
|
|
}
|
|
intervals := make([]service.PricingInterval, 0, len(r.Intervals))
|
|
for _, iv := range r.Intervals {
|
|
intervals = append(intervals, service.PricingInterval{
|
|
MinTokens: iv.MinTokens,
|
|
MaxTokens: iv.MaxTokens,
|
|
TierLabel: iv.TierLabel,
|
|
InputPrice: iv.InputPrice,
|
|
OutputPrice: iv.OutputPrice,
|
|
CacheWritePrice: iv.CacheWritePrice,
|
|
CacheReadPrice: iv.CacheReadPrice,
|
|
PerRequestPrice: iv.PerRequestPrice,
|
|
SortOrder: iv.SortOrder,
|
|
})
|
|
}
|
|
result = append(result, service.ChannelModelPricing{
|
|
Platform: platform,
|
|
Models: r.Models,
|
|
BillingMode: billingMode,
|
|
InputPrice: r.InputPrice,
|
|
OutputPrice: r.OutputPrice,
|
|
CacheWritePrice: r.CacheWritePrice,
|
|
CacheReadPrice: r.CacheReadPrice,
|
|
ImageOutputPrice: r.ImageOutputPrice,
|
|
PerRequestPrice: r.PerRequestPrice,
|
|
Intervals: intervals,
|
|
})
|
|
}
|
|
return result
|
|
}
|
|
|
|
// --- Handlers ---
|
|
|
|
// List handles listing channels with pagination
|
|
// GET /api/v1/admin/channels
|
|
func (h *ChannelHandler) List(c *gin.Context) {
|
|
page, pageSize := response.ParsePagination(c)
|
|
status := c.Query("status")
|
|
search := strings.TrimSpace(c.Query("search"))
|
|
if len(search) > 100 {
|
|
search = search[:100]
|
|
}
|
|
|
|
channels, pag, err := h.channelService.List(c.Request.Context(), pagination.PaginationParams{Page: page, PageSize: pageSize}, status, search)
|
|
if err != nil {
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
out := make([]*channelResponse, 0, len(channels))
|
|
for i := range channels {
|
|
out = append(out, channelToResponse(&channels[i]))
|
|
}
|
|
response.Paginated(c, out, pag.Total, page, pageSize)
|
|
}
|
|
|
|
// GetByID returns a single channel.
// GET /api/v1/admin/channels/:id
//
// A non-integer :id yields a 400 INVALID_CHANNEL_ID error; service errors are
// forwarded through response.ErrorFrom.
func (h *ChannelHandler) GetByID(c *gin.Context) {
	channelID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
		return
	}

	found, lookupErr := h.channelService.GetByID(c.Request.Context(), channelID)
	if lookupErr != nil {
		response.ErrorFrom(c, lookupErr)
		return
	}
	response.Success(c, channelToResponse(found))
}
|
|
|
|
// Create handles creating a new channel
|
|
// POST /api/v1/admin/channels
|
|
func (h *ChannelHandler) Create(c *gin.Context) {
|
|
var req createChannelRequest
|
|
if err := c.ShouldBindJSON(&req); err != nil {
|
|
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
|
return
|
|
}
|
|
|
|
pricing := pricingRequestToService(req.ModelPricing)
|
|
|
|
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
|
Name: req.Name,
|
|
Description: req.Description,
|
|
GroupIDs: req.GroupIDs,
|
|
ModelPricing: pricing,
|
|
ModelMapping: req.ModelMapping,
|
|
BillingModelSource: req.BillingModelSource,
|
|
RestrictModels: req.RestrictModels,
|
|
})
|
|
if err != nil {
|
|
response.ErrorFrom(c, err)
|
|
return
|
|
}
|
|
|
|
response.Success(c, channelToResponse(channel))
|
|
}
|
|
|
|
// Update applies an updateChannelRequest body to an existing channel.
// PUT /api/v1/admin/channels/:id
//
// ModelPricing is converted only when present in the body; otherwise the
// service input's ModelPricing stays nil. A non-integer :id yields 400
// INVALID_CHANNEL_ID, a binding failure yields 400 VALIDATION_ERROR, and
// service errors are forwarded through response.ErrorFrom.
func (h *ChannelHandler) Update(c *gin.Context) {
	channelID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
		return
	}

	var req updateChannelRequest
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", bindErr.Error()))
		return
	}

	var pricing *[]service.ChannelModelPricing
	if req.ModelPricing != nil {
		converted := pricingRequestToService(*req.ModelPricing)
		pricing = &converted
	}

	updated, err := h.channelService.Update(c.Request.Context(), channelID, &service.UpdateChannelInput{
		Name:               req.Name,
		Description:        req.Description,
		Status:             req.Status,
		GroupIDs:           req.GroupIDs,
		ModelPricing:       pricing,
		ModelMapping:       req.ModelMapping,
		BillingModelSource: req.BillingModelSource,
		RestrictModels:     req.RestrictModels,
	})
	if err != nil {
		response.ErrorFrom(c, err)
		return
	}
	response.Success(c, channelToResponse(updated))
}
|
|
|
|
// Delete removes a channel.
// DELETE /api/v1/admin/channels/:id
//
// A non-integer :id yields 400 INVALID_CHANNEL_ID; service errors are forwarded
// through response.ErrorFrom; success returns a confirmation message.
func (h *ChannelHandler) Delete(c *gin.Context) {
	channelID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		response.ErrorFrom(c, infraerrors.BadRequest("INVALID_CHANNEL_ID", "Invalid channel ID"))
		return
	}

	if delErr := h.channelService.Delete(c.Request.Context(), channelID); delErr != nil {
		response.ErrorFrom(c, delErr)
		return
	}
	response.Success(c, gin.H{"message": "Channel deleted successfully"})
}
|
|
|
|
// GetModelDefaultPricing returns the default pricing for a model (used by the
// frontend to auto-fill pricing fields).
// GET /api/v1/admin/channels/model-pricing?model=claude-sonnet-4
//
// A missing/blank "model" query yields 400 MISSING_PARAMETER. Any error from
// the billing service lookup is reported as {"found": false} rather than as
// an error response.
func (h *ChannelHandler) GetModelDefaultPricing(c *gin.Context) {
	modelName := strings.TrimSpace(c.Query("model"))
	if len(modelName) == 0 {
		missing := infraerrors.BadRequest("MISSING_PARAMETER", "model parameter is required").
			WithMetadata(map[string]string{"param": "model"})
		response.ErrorFrom(c, missing)
		return
	}

	p, lookupErr := h.billingService.GetModelPricing(modelName)
	if lookupErr != nil {
		// The model is not in the pricing list.
		response.Success(c, gin.H{"found": false})
		return
	}

	payload := gin.H{
		"found":              true,
		"input_price":        p.InputPricePerToken,
		"output_price":       p.OutputPricePerToken,
		"cache_write_price":  p.CacheCreationPricePerToken,
		"cache_read_price":   p.CacheReadPricePerToken,
		"image_output_price": p.ImageOutputPricePerToken,
	}
	response.Success(c, payload)
}
|