refactor(channel): split long functions, extract shared validation, move billing validation to service
- Split Update (98→25 lines), buildCache (54→20 lines), Create (51→25 lines) into focused sub-functions: applyUpdateInput, checkGroupConflicts, fetchChannelData, populateChannelCache, storeErrorCache, getOldGroupIDs, invalidateAuthCacheForGroups - Extract validateChannelConfig to eliminate duplicated validation calls between Create and Update - Move validatePricingBillingMode from handler to service layer for proper separation of concerns - Add error logging to IsModelRestricted (was silently swallowing errors) - Add 12 new tests: ToUsageFields, billing mode validation, antigravity wildcard mapping isolation, Create/Update mapping conflict integration
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
package admin
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
||||
return result
|
||||
}
|
||||
|
||||
// validatePricingBillingMode 校验计费配置
|
||||
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
|
||||
for _, p := range pricing {
|
||||
// 按次/图片模式必须配置默认价格或区间
|
||||
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||
return errors.New("per-request price or intervals required for per_request/image billing mode")
|
||||
}
|
||||
}
|
||||
// 校验价格不能为负
|
||||
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
|
||||
return err
|
||||
}
|
||||
// 校验 interval:至少有一个价格字段非空
|
||||
for _, iv := range p.Intervals {
|
||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||
iv.PerRequestPrice == nil {
|
||||
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
|
||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validatePriceNotNegative(field string, val *float64) error {
|
||||
if val != nil && *val < 0 {
|
||||
return fmt.Errorf("%s must be >= 0", field)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func formatMaxTokens(max *int) string {
|
||||
if max == nil {
|
||||
return "∞"
|
||||
}
|
||||
return fmt.Sprintf("%d", *max)
|
||||
}
|
||||
|
||||
// --- Handlers ---
|
||||
|
||||
// List handles listing channels with pagination
|
||||
@@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
||||
}
|
||||
|
||||
pricing := pricingRequestToService(req.ModelPricing)
|
||||
if err := validatePricingBillingMode(pricing); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||
Name: req.Name,
|
||||
@@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
||||
}
|
||||
if req.ModelPricing != nil {
|
||||
pricing := pricingRequestToService(*req.ModelPricing)
|
||||
if err := validatePricingBillingMode(pricing); err != nil {
|
||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
||||
return
|
||||
}
|
||||
input.ModelPricing = &pricing
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user