refactor(channel): split long functions, extract shared validation, move billing validation to service
- Split Update (98→25 lines), buildCache (54→20 lines), Create (51→25 lines) into focused sub-functions: applyUpdateInput, checkGroupConflicts, fetchChannelData, populateChannelCache, storeErrorCache, getOldGroupIDs, invalidateAuthCacheForGroups - Extract validateChannelConfig to eliminate duplicated validation calls between Create and Update - Move validatePricingBillingMode from handler to service layer for proper separation of concerns - Add error logging to IsModelRestricted (was silently swallowing errors) - Add 12 new tests: ToUsageFields, billing mode validation, antigravity wildcard mapping isolation, Create/Update mapping conflict integration
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
// validatePricingBillingMode 校验计费配置
|
|
||||||
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
|
|
||||||
for _, p := range pricing {
|
|
||||||
// 按次/图片模式必须配置默认价格或区间
|
|
||||||
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
|
||||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
|
||||||
return errors.New("per-request price or intervals required for per_request/image billing mode")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 校验价格不能为负
|
|
||||||
if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// 校验 interval:至少有一个价格字段非空
|
|
||||||
for _, iv := range p.Intervals {
|
|
||||||
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
|
||||||
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
|
||||||
iv.PerRequestPrice == nil {
|
|
||||||
return fmt.Errorf("interval [%d, %s] has no price fields set for model %v",
|
|
||||||
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func validatePriceNotNegative(field string, val *float64) error {
|
|
||||||
if val != nil && *val < 0 {
|
|
||||||
return fmt.Errorf("%s must be >= 0", field)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func formatMaxTokens(max *int) string {
|
|
||||||
if max == nil {
|
|
||||||
return "∞"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d", *max)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- Handlers ---
|
// --- Handlers ---
|
||||||
|
|
||||||
// List handles listing channels with pagination
|
// List handles listing channels with pagination
|
||||||
@@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pricing := pricingRequestToService(req.ModelPricing)
|
pricing := pricingRequestToService(req.ModelPricing)
|
||||||
if err := validatePricingBillingMode(pricing); err != nil {
|
|
||||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||||
Name: req.Name,
|
Name: req.Name,
|
||||||
@@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
if err := validatePricingBillingMode(pricing); err != nil {
|
|
||||||
response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error()))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
input.ModelPricing = &pricing
|
input.ModelPricing = &pricing
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) {
|
|||||||
require.Nil(t, r.ImageOutputPrice)
|
require.Nil(t, r.ImageOutputPrice)
|
||||||
require.Nil(t, r.PerRequestPrice)
|
require.Nil(t, r.PerRequestPrice)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
// 3. validatePricingBillingMode
|
|
||||||
// ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
func TestValidatePricingBillingMode(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
pricing []service.ChannelModelPricing
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "token mode - valid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{BillingMode: service.BillingModeToken},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "per_request with price - valid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{
|
|
||||||
BillingMode: service.BillingModePerRequest,
|
|
||||||
PerRequestPrice: float64Ptr(0.5),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "per_request with intervals - valid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{
|
|
||||||
BillingMode: service.BillingModePerRequest,
|
|
||||||
Intervals: []service.PricingInterval{
|
|
||||||
{MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "per_request no price no intervals - invalid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{BillingMode: service.BillingModePerRequest},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "image with price - valid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{
|
|
||||||
BillingMode: service.BillingModeImage,
|
|
||||||
PerRequestPrice: float64Ptr(0.2),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "image no price no intervals - invalid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{BillingMode: service.BillingModeImage},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty list - valid",
|
|
||||||
pricing: []service.ChannelModelPricing{},
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "mixed modes with invalid image - invalid",
|
|
||||||
pricing: []service.ChannelModelPricing{
|
|
||||||
{
|
|
||||||
BillingMode: service.BillingModeToken,
|
|
||||||
InputPrice: float64Ptr(0.01),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
BillingMode: service.BillingModePerRequest,
|
|
||||||
PerRequestPrice: float64Ptr(0.5),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
BillingMode: service.BillingModeImage,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
err := validatePricingBillingMode(tt.pricing)
|
|
||||||
if tt.wantErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Contains(t, err.Error(), "per-request price or intervals required")
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。
|
||||||
|
// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。
|
||||||
|
func (s *ChannelService) storeErrorCache() {
|
||||||
|
errorCache := newEmptyChannelCache()
|
||||||
|
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
||||||
|
s.cache.Store(errorCache)
|
||||||
|
}
|
||||||
|
|
||||||
// buildCache 从数据库构建渠道缓存。
|
// buildCache 从数据库构建渠道缓存。
|
||||||
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
// 使用独立 context 避免请求取消导致空值被长期缓存。
|
||||||
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) {
|
||||||
// 断开请求取消链,避免客户端断连导致空值被长期缓存
|
|
||||||
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
channels, err := s.repo.ListAll(dbCtx)
|
channels, groupPlatforms, err := s.fetchChannelData(dbCtx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// error-TTL:失败时存入短 TTL 空缓存,防止紧密重试
|
return nil, err
|
||||||
slog.Warn("failed to build channel cache", "error", err)
|
}
|
||||||
errorCache := newEmptyChannelCache()
|
|
||||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL
|
cache := populateChannelCache(channels, groupPlatforms)
|
||||||
s.cache.Store(errorCache)
|
s.cache.Store(cache)
|
||||||
return nil, fmt.Errorf("list all channels: %w", err)
|
return cache, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fetchChannelData 从数据库加载渠道列表和分组平台映射。
|
||||||
|
func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) {
|
||||||
|
channels, err := s.repo.ListAll(ctx)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to build channel cache", "error", err)
|
||||||
|
s.storeErrorCache()
|
||||||
|
return nil, nil, fmt.Errorf("list all channels: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 收集所有 groupID,批量查询 platform
|
|
||||||
var allGroupIDs []int64
|
var allGroupIDs []int64
|
||||||
for i := range channels {
|
for i := range channels {
|
||||||
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
groupPlatforms := make(map[int64]string)
|
groupPlatforms := make(map[int64]string)
|
||||||
if len(allGroupIDs) > 0 {
|
if len(allGroupIDs) > 0 {
|
||||||
groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs)
|
groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
slog.Warn("failed to load group platforms for channel cache", "error", err)
|
||||||
errorCache := newEmptyChannelCache()
|
s.storeErrorCache()
|
||||||
errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL))
|
return nil, nil, fmt.Errorf("get group platforms: %w", err)
|
||||||
s.cache.Store(errorCache)
|
|
||||||
return nil, fmt.Errorf("get group platforms: %w", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return channels, groupPlatforms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。
|
||||||
|
func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache {
|
||||||
cache := newEmptyChannelCache()
|
cache := newEmptyChannelCache()
|
||||||
cache.groupPlatform = groupPlatforms
|
cache.groupPlatform = groupPlatforms
|
||||||
cache.byID = make(map[int64]*Channel, len(channels))
|
cache.byID = make(map[int64]*Channel, len(channels))
|
||||||
@@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
for i := range channels {
|
for i := range channels {
|
||||||
ch := &channels[i]
|
ch := &channels[i]
|
||||||
cache.byID[ch.ID] = ch
|
cache.byID[ch.ID] = ch
|
||||||
|
|
||||||
for _, gid := range ch.GroupIDs {
|
for _, gid := range ch.GroupIDs {
|
||||||
cache.channelByGroupID[gid] = ch
|
cache.channelByGroupID[gid] = ch
|
||||||
platform := groupPlatforms[gid]
|
platform := groupPlatforms[gid]
|
||||||
@@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error)
|
|||||||
expandMappingToCache(cache, ch, gid, platform)
|
expandMappingToCache(cache, ch, gid, platform)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return cache
|
||||||
// 通配符条目保持配置顺序(最先匹配到优先)
|
|
||||||
|
|
||||||
s.cache.Store(cache)
|
|
||||||
return cache, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// invalidateCache 使缓存失效,让下次读取时自然重建
|
// invalidateCache 使缓存失效,让下次读取时自然重建
|
||||||
@@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6
|
|||||||
// 返回 true 表示模型被限制(不在允许列表中)。
|
// 返回 true 表示模型被限制(不在允许列表中)。
|
||||||
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
// 如果渠道未启用模型限制或分组无渠道关联,返回 false。
|
||||||
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool {
|
||||||
lk, _ := s.lookupGroupChannel(ctx, groupID)
|
lk, err := s.lookupGroupChannel(ctx, groupID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err)
|
||||||
|
}
|
||||||
if lk == nil {
|
if lk == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte {
|
|||||||
return newBody
|
return newBody
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。
|
||||||
|
// Create 和 Update 共用此函数,避免重复。
|
||||||
|
func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error {
|
||||||
|
if err := validateNoConflictingModels(pricing); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validatePricingIntervals(pricing); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := validateNoConflictingMappings(mapping); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return validatePricingBillingMode(pricing)
|
||||||
|
}
|
||||||
|
|
||||||
|
// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。
|
||||||
|
func validatePricingBillingMode(pricing []ChannelModelPricing) error {
|
||||||
|
for _, p := range pricing {
|
||||||
|
if err := checkBillingModeRequirements(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkPricesNotNegative(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := checkIntervalsHavePrices(p); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkBillingModeRequirements(p ChannelModelPricing) error {
|
||||||
|
if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage {
|
||||||
|
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||||
|
return infraerrors.BadRequest(
|
||||||
|
"BILLING_MODE_MISSING_PRICE",
|
||||||
|
"per-request price or intervals required for per_request/image billing mode",
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkPricesNotNegative(p ChannelModelPricing) error {
|
||||||
|
checks := []struct {
|
||||||
|
field string
|
||||||
|
val *float64
|
||||||
|
}{
|
||||||
|
{"input_price", p.InputPrice},
|
||||||
|
{"output_price", p.OutputPrice},
|
||||||
|
{"cache_write_price", p.CacheWritePrice},
|
||||||
|
{"cache_read_price", p.CacheReadPrice},
|
||||||
|
{"image_output_price", p.ImageOutputPrice},
|
||||||
|
{"per_request_price", p.PerRequestPrice},
|
||||||
|
}
|
||||||
|
for _, c := range checks {
|
||||||
|
if c.val != nil && *c.val < 0 {
|
||||||
|
return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkIntervalsHavePrices(p ChannelModelPricing) error {
|
||||||
|
for _, iv := range p.Intervals {
|
||||||
|
if iv.InputPrice == nil && iv.OutputPrice == nil &&
|
||||||
|
iv.CacheWritePrice == nil && iv.CacheReadPrice == nil &&
|
||||||
|
iv.PerRequestPrice == nil {
|
||||||
|
return infraerrors.BadRequest(
|
||||||
|
"INTERVAL_MISSING_PRICE",
|
||||||
|
fmt.Sprintf("interval [%d, %s] has no price fields set for model %v",
|
||||||
|
iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func formatMaxTokens(max *int) string {
|
||||||
|
if max == nil {
|
||||||
|
return "∞"
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%d", *max)
|
||||||
|
}
|
||||||
|
|
||||||
// --- CRUD ---
|
// --- CRUD ---
|
||||||
|
|
||||||
// Create 创建渠道
|
// Create 创建渠道
|
||||||
@@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
|||||||
return nil, ErrChannelExists
|
return nil, ErrChannelExists
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查分组冲突
|
if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil {
|
||||||
if len(input.GroupIDs) > 0 {
|
return nil, err
|
||||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
|
||||||
}
|
|
||||||
if len(conflicting) > 0 {
|
|
||||||
return nil, ErrGroupAlreadyInChannel
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channel := &Channel{
|
channel := &Channel{
|
||||||
@@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput)
|
|||||||
channel.BillingModelSource = BillingModelSourceChannelMapped
|
channel.BillingModelSource = BillingModelSourceChannelMapped
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan
|
|||||||
return nil, fmt.Errorf("get channel: %w", err)
|
return nil, fmt.Errorf("get channel: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.Name != "" && input.Name != channel.Name {
|
if err := s.applyUpdateInput(ctx, channel, input); err != nil {
|
||||||
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check channel exists: %w", err)
|
|
||||||
}
|
|
||||||
if exists {
|
|
||||||
return nil, ErrChannelExists
|
|
||||||
}
|
|
||||||
channel.Name = input.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Description != nil {
|
|
||||||
channel.Description = *input.Description
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.Status != "" {
|
|
||||||
channel.Status = input.Status
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.RestrictModels != nil {
|
|
||||||
channel.RestrictModels = *input.RestrictModels
|
|
||||||
}
|
|
||||||
|
|
||||||
// 检查分组冲突
|
|
||||||
if input.GroupIDs != nil {
|
|
||||||
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("check group conflicts: %w", err)
|
|
||||||
}
|
|
||||||
if len(conflicting) > 0 {
|
|
||||||
return nil, ErrGroupAlreadyInChannel
|
|
||||||
}
|
|
||||||
channel.GroupIDs = *input.GroupIDs
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ModelPricing != nil {
|
|
||||||
channel.ModelPricing = *input.ModelPricing
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.ModelMapping != nil {
|
|
||||||
channel.ModelMapping = input.ModelMapping
|
|
||||||
}
|
|
||||||
|
|
||||||
if input.BillingModelSource != "" {
|
|
||||||
channel.BillingModelSource = input.BillingModelSource
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := validateNoConflictingModels(channel.ModelPricing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validatePricingIntervals(channel.ModelPricing); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if err := validateNoConflictingMappings(channel.ModelMapping); err != nil {
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 先获取旧分组,Update 后旧分组关联已删除,无法再查到
|
if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil {
|
||||||
var oldGroupIDs []int64
|
return nil, err
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
var err2 error
|
|
||||||
oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id)
|
|
||||||
if err2 != nil {
|
|
||||||
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
oldGroupIDs := s.getOldGroupIDs(ctx, id)
|
||||||
|
|
||||||
if err := s.repo.Update(ctx, channel); err != nil {
|
if err := s.repo.Update(ctx, channel); err != nil {
|
||||||
return nil, fmt.Errorf("update channel: %w", err)
|
return nil, fmt.Errorf("update channel: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.invalidateCache()
|
s.invalidateCache()
|
||||||
|
s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs)
|
||||||
// 失效新旧分组的 auth 缓存
|
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs))
|
|
||||||
for _, gid := range oldGroupIDs {
|
|
||||||
if _, ok := seen[gid]; !ok {
|
|
||||||
seen[gid] = struct{}{}
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, gid := range channel.GroupIDs {
|
|
||||||
if _, ok := seen[gid]; !ok {
|
|
||||||
seen[gid] = struct{}{}
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.repo.GetByID(ctx, id)
|
return s.repo.GetByID(ctx, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyUpdateInput 将更新请求的字段应用到渠道实体上。
|
||||||
|
func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error {
|
||||||
|
if input.Name != "" && input.Name != channel.Name {
|
||||||
|
exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check channel exists: %w", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
return ErrChannelExists
|
||||||
|
}
|
||||||
|
channel.Name = input.Name
|
||||||
|
}
|
||||||
|
if input.Description != nil {
|
||||||
|
channel.Description = *input.Description
|
||||||
|
}
|
||||||
|
if input.Status != "" {
|
||||||
|
channel.Status = input.Status
|
||||||
|
}
|
||||||
|
if input.RestrictModels != nil {
|
||||||
|
channel.RestrictModels = *input.RestrictModels
|
||||||
|
}
|
||||||
|
if input.GroupIDs != nil {
|
||||||
|
if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
channel.GroupIDs = *input.GroupIDs
|
||||||
|
}
|
||||||
|
if input.ModelPricing != nil {
|
||||||
|
channel.ModelPricing = *input.ModelPricing
|
||||||
|
}
|
||||||
|
if input.ModelMapping != nil {
|
||||||
|
channel.ModelMapping = input.ModelMapping
|
||||||
|
}
|
||||||
|
if input.BillingModelSource != "" {
|
||||||
|
channel.BillingModelSource = input.BillingModelSource
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。
|
||||||
|
// channelID 为当前渠道 ID(Create 时传 0)。
|
||||||
|
func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error {
|
||||||
|
if len(groupIDs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("check group conflicts: %w", err)
|
||||||
|
}
|
||||||
|
if len(conflicting) > 0 {
|
||||||
|
return ErrGroupAlreadyInChannel
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。
|
||||||
|
func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 {
|
||||||
|
if s.authCacheInvalidator == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID)
|
||||||
|
if err != nil {
|
||||||
|
slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err)
|
||||||
|
}
|
||||||
|
return oldGroupIDs
|
||||||
|
}
|
||||||
|
|
||||||
|
// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。
|
||||||
|
func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) {
|
||||||
|
if s.authCacheInvalidator == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
seen := make(map[int64]struct{})
|
||||||
|
for _, ids := range groupIDSets {
|
||||||
|
for _, gid := range ids {
|
||||||
|
if _, ok := seen[gid]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[gid] = struct{}{}
|
||||||
|
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Delete 删除渠道
|
// Delete 删除渠道
|
||||||
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
||||||
// 先获取关联分组用于失效缓存
|
|
||||||
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
groupIDs, err := s.repo.GetGroupIDs(ctx, id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err)
|
||||||
@@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.invalidateCache()
|
s.invalidateCache()
|
||||||
|
s.invalidateAuthCacheForGroups(ctx, groupIDs)
|
||||||
if s.authCacheInvalidator != nil {
|
|
||||||
for _, gid := range groupIDs {
|
|
||||||
s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) {
|
|||||||
require.Equal(t, int64(601), result.ID)
|
require.Equal(t, int64(601), result.ID)
|
||||||
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
|
require.InDelta(t, 5e-6, *result.InputPrice, 1e-12)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 10. ToUsageFields
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestToUsageFields_NoMapping(t *testing.T) {
|
||||||
|
r := ChannelMappingResult{
|
||||||
|
MappedModel: "claude-opus-4",
|
||||||
|
ChannelID: 1,
|
||||||
|
Mapped: false,
|
||||||
|
BillingModelSource: BillingModelSourceRequested,
|
||||||
|
}
|
||||||
|
fields := r.ToUsageFields("claude-opus-4", "claude-opus-4")
|
||||||
|
require.Equal(t, int64(1), fields.ChannelID)
|
||||||
|
require.Equal(t, "claude-opus-4", fields.OriginalModel)
|
||||||
|
require.Equal(t, "claude-opus-4", fields.ChannelMappedModel)
|
||||||
|
require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource)
|
||||||
|
require.Empty(t, fields.ModelMappingChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToUsageFields_WithChannelMapping(t *testing.T) {
|
||||||
|
r := ChannelMappingResult{
|
||||||
|
MappedModel: "claude-sonnet-4-20250514",
|
||||||
|
ChannelID: 2,
|
||||||
|
Mapped: true,
|
||||||
|
BillingModelSource: BillingModelSourceChannelMapped,
|
||||||
|
}
|
||||||
|
fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514")
|
||||||
|
require.Equal(t, int64(2), fields.ChannelID)
|
||||||
|
require.Equal(t, "claude-sonnet-4", fields.OriginalModel)
|
||||||
|
require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel)
|
||||||
|
require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToUsageFields_WithUpstreamDifference(t *testing.T) {
|
||||||
|
r := ChannelMappingResult{
|
||||||
|
MappedModel: "claude-sonnet-4",
|
||||||
|
ChannelID: 3,
|
||||||
|
Mapped: true,
|
||||||
|
BillingModelSource: BillingModelSourceUpstream,
|
||||||
|
}
|
||||||
|
fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514")
|
||||||
|
require.Equal(t, "my-alias", fields.OriginalModel)
|
||||||
|
require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel)
|
||||||
|
require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 11. validatePricingBillingMode (moved from handler tests)
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestValidatePricingBillingMode(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
pricing []ChannelModelPricing
|
||||||
|
wantErr bool
|
||||||
|
errMsg string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "token mode - valid",
|
||||||
|
pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "per_request with price - valid",
|
||||||
|
pricing: []ChannelModelPricing{{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
PerRequestPrice: testPtrFloat64(0.5),
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "per_request with intervals - valid",
|
||||||
|
pricing: []ChannelModelPricing{{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "per_request no price no intervals - invalid",
|
||||||
|
pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "per-request price or intervals required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "image no price no intervals - invalid",
|
||||||
|
pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "per-request price or intervals required",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty list - valid",
|
||||||
|
pricing: []ChannelModelPricing{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "negative input_price - invalid",
|
||||||
|
pricing: []ChannelModelPricing{{
|
||||||
|
BillingMode: BillingModeToken,
|
||||||
|
InputPrice: testPtrFloat64(-0.01),
|
||||||
|
}},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "input_price must be >= 0",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "interval with no price fields - invalid",
|
||||||
|
pricing: []ChannelModelPricing{{
|
||||||
|
BillingMode: BillingModePerRequest,
|
||||||
|
PerRequestPrice: testPtrFloat64(0.5),
|
||||||
|
Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}},
|
||||||
|
}},
|
||||||
|
wantErr: true,
|
||||||
|
errMsg: "has no price fields set",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
err := validatePricingBillingMode(tt.pricing)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), tt.errMsg)
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 12. Antigravity wildcard mapping isolation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) {
|
||||||
|
ch := Channel{
|
||||||
|
ID: 1,
|
||||||
|
Status: StatusActive,
|
||||||
|
GroupIDs: []int64{10, 20},
|
||||||
|
ModelMapping: map[string]map[string]string{
|
||||||
|
PlatformAnthropic: {"claude-*": "claude-override"},
|
||||||
|
PlatformGemini: {"gemini-*": "gemini-override"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic})
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
// antigravity 分组不应看到 anthropic/gemini 的通配符映射
|
||||||
|
result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4")
|
||||||
|
require.False(t, result.Mapped)
|
||||||
|
require.Equal(t, "claude-opus-4", result.MappedModel)
|
||||||
|
|
||||||
|
result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro")
|
||||||
|
require.False(t, result.Mapped)
|
||||||
|
require.Equal(t, "gemini-2.5-pro", result.MappedModel)
|
||||||
|
|
||||||
|
// anthropic 分组应该能看到 anthropic 的通配符映射
|
||||||
|
result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4")
|
||||||
|
require.True(t, result.Mapped)
|
||||||
|
require.Equal(t, "claude-override", result.MappedModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// 13. Create/Update with mapping conflict validation
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestCreate_MappingConflict(t *testing.T) {
|
||||||
|
repo := &mockChannelRepository{}
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
_, err := svc.Create(context.Background(), &CreateChannelInput{
|
||||||
|
Name: "test",
|
||||||
|
ModelMapping: map[string]map[string]string{
|
||||||
|
PlatformAnthropic: {
|
||||||
|
"claude-*": "target-a",
|
||||||
|
"claude-opus-*": "target-b",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdate_MappingConflict(t *testing.T) {
|
||||||
|
existingChannel := &Channel{
|
||||||
|
ID: 1,
|
||||||
|
Name: "existing",
|
||||||
|
Status: StatusActive,
|
||||||
|
}
|
||||||
|
repo := &mockChannelRepository{
|
||||||
|
getByIDFn: func(_ context.Context, _ int64) (*Channel, error) {
|
||||||
|
return existingChannel, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
svc := newTestChannelService(repo)
|
||||||
|
|
||||||
|
conflictMapping := map[string]map[string]string{
|
||||||
|
PlatformAnthropic: {
|
||||||
|
"claude-*": "target-a",
|
||||||
|
"claude-opus-*": "target-b",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
_, err := svc.Update(context.Background(), 1, &UpdateChannelInput{
|
||||||
|
ModelMapping: conflictMapping,
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT")
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user