diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 563a27ce..b503e5c3 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -1,8 +1,6 @@ package admin import ( - "errors" - "fmt" "strconv" "strings" @@ -235,61 +233,6 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe return result } -// validatePricingBillingMode 校验计费配置 -func validatePricingBillingMode(pricing []service.ChannelModelPricing) error { - for _, p := range pricing { - // 按次/图片模式必须配置默认价格或区间 - if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { - if p.PerRequestPrice == nil && len(p.Intervals) == 0 { - return errors.New("per-request price or intervals required for per_request/image billing mode") - } - } - // 校验价格不能为负 - if err := validatePriceNotNegative("input_price", p.InputPrice); err != nil { - return err - } - if err := validatePriceNotNegative("output_price", p.OutputPrice); err != nil { - return err - } - if err := validatePriceNotNegative("cache_write_price", p.CacheWritePrice); err != nil { - return err - } - if err := validatePriceNotNegative("cache_read_price", p.CacheReadPrice); err != nil { - return err - } - if err := validatePriceNotNegative("image_output_price", p.ImageOutputPrice); err != nil { - return err - } - if err := validatePriceNotNegative("per_request_price", p.PerRequestPrice); err != nil { - return err - } - // 校验 interval:至少有一个价格字段非空 - for _, iv := range p.Intervals { - if iv.InputPrice == nil && iv.OutputPrice == nil && - iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && - iv.PerRequestPrice == nil { - return fmt.Errorf("interval [%d, %s] has no price fields set for model %v", - iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models) - } - } - } - return nil -} - -func validatePriceNotNegative(field string, val *float64) error { - if val != nil && *val < 0 { - return fmt.Errorf("%s must be >= 0", field) - } - return nil -} - -func formatMaxTokens(max *int) string { - if max == nil { - return "∞" - } - return fmt.Sprintf("%d", *max) -} - // --- Handlers --- // List handles listing channels with pagination @@ -343,10 +286,6 @@ func (h *ChannelHandler) Create(c *gin.Context) { } pricing := pricingRequestToService(req.ModelPricing) - if err := validatePricingBillingMode(pricing); err != nil { - response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) - return - } channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ Name: req.Name, @@ -391,10 +330,6 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) - if err := validatePricingBillingMode(pricing); err != nil { - response.ErrorFrom(c, infraerrors.BadRequest("VALIDATION_ERROR", err.Error())) - return - } input.ModelPricing = &pricing } diff --git a/backend/internal/handler/admin/channel_handler_test.go b/backend/internal/handler/admin/channel_handler_test.go index 6f6ea526..2f4b4440 100644 --- a/backend/internal/handler/admin/channel_handler_test.go +++ b/backend/internal/handler/admin/channel_handler_test.go @@ -400,103 +400,3 @@ func TestPricingRequestToService_NilPriceFields(t *testing.T) { require.Nil(t, r.ImageOutputPrice) require.Nil(t, r.PerRequestPrice) } - -// --------------------------------------------------------------------------- -// 3. validatePricingBillingMode -// --------------------------------------------------------------------------- - -func TestValidatePricingBillingMode(t *testing.T) { - tests := []struct { - name string - pricing []service.ChannelModelPricing - wantErr bool - }{ - { - name: "token mode - valid", - pricing: []service.ChannelModelPricing{ - {BillingMode: service.BillingModeToken}, - }, - wantErr: false, - }, - { - name: "per_request with price - valid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModePerRequest, - PerRequestPrice: float64Ptr(0.5), - }, - }, - wantErr: false, - }, - { - name: "per_request with intervals - valid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModePerRequest, - Intervals: []service.PricingInterval{ - {MinTokens: 0, MaxTokens: intPtr(1000), PerRequestPrice: float64Ptr(0.1)}, - }, - }, - }, - wantErr: false, - }, - { - name: "per_request no price no intervals - invalid", - pricing: []service.ChannelModelPricing{ - {BillingMode: service.BillingModePerRequest}, - }, - wantErr: true, - }, - { - name: "image with price - valid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModeImage, - PerRequestPrice: float64Ptr(0.2), - }, - }, - wantErr: false, - }, - { - name: "image no price no intervals - invalid", - pricing: []service.ChannelModelPricing{ - {BillingMode: service.BillingModeImage}, - }, - wantErr: true, - }, - { - name: "empty list - valid", - pricing: []service.ChannelModelPricing{}, - wantErr: false, - }, - { - name: "mixed modes with invalid image - invalid", - pricing: []service.ChannelModelPricing{ - { - BillingMode: service.BillingModeToken, - InputPrice: float64Ptr(0.01), - }, - { - BillingMode: service.BillingModePerRequest, - PerRequestPrice: float64Ptr(0.5), - }, - { - BillingMode: service.BillingModeImage, - }, - }, - wantErr: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - err := validatePricingBillingMode(tt.pricing) - if tt.wantErr { - require.Error(t, err) - require.Contains(t, err.Error(), "per-request price or intervals required") - } else { - require.NoError(t, err) - } - }) - } -} diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 7b96084d..9667cb98 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -248,40 +248,58 @@ func expandMappingToCache(cache *channelCache, ch *Channel, gid int64, platform } } +// storeErrorCache 存入短 TTL 空缓存,防止 DB 错误后紧密重试。 +// 通过回退 loadedAt 使剩余 TTL = channelErrorTTL。 +func (s *ChannelService) storeErrorCache() { + errorCache := newEmptyChannelCache() + errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) + s.cache.Store(errorCache) +} + // buildCache 从数据库构建渠道缓存。 // 使用独立 context 避免请求取消导致空值被长期缓存。 func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) { - // 断开请求取消链,避免客户端断连导致空值被长期缓存 dbCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), channelCacheDBTimeout) defer cancel() - channels, err := s.repo.ListAll(dbCtx) + channels, groupPlatforms, err := s.fetchChannelData(dbCtx) if err != nil { - // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 - slog.Warn("failed to build channel cache", "error", err) - errorCache := newEmptyChannelCache() - errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) // 使剩余 TTL = errorTTL - s.cache.Store(errorCache) - return nil, fmt.Errorf("list all channels: %w", err) + return nil, err + } + + cache := populateChannelCache(channels, groupPlatforms) + s.cache.Store(cache) + return cache, nil +} + +// fetchChannelData 从数据库加载渠道列表和分组平台映射。 +func (s *ChannelService) fetchChannelData(ctx context.Context) ([]Channel, map[int64]string, error) { + channels, err := s.repo.ListAll(ctx) + if err != nil { + slog.Warn("failed to build channel cache", "error", err) + s.storeErrorCache() + return nil, nil, fmt.Errorf("list all channels: %w", err) } - // 收集所有 groupID,批量查询 platform var allGroupIDs []int64 for i := range channels { allGroupIDs = append(allGroupIDs, channels[i].GroupIDs...) } + groupPlatforms := make(map[int64]string) if len(allGroupIDs) > 0 { - groupPlatforms, err = s.repo.GetGroupPlatforms(dbCtx, allGroupIDs) + groupPlatforms, err = s.repo.GetGroupPlatforms(ctx, allGroupIDs) if err != nil { slog.Warn("failed to load group platforms for channel cache", "error", err) - errorCache := newEmptyChannelCache() - errorCache.loadedAt = time.Now().Add(-(channelCacheTTL - channelErrorTTL)) - s.cache.Store(errorCache) - return nil, fmt.Errorf("get group platforms: %w", err) + s.storeErrorCache() + return nil, nil, fmt.Errorf("get group platforms: %w", err) } } + return channels, groupPlatforms, nil +} +// populateChannelCache 将渠道列表和分组平台映射填充到缓存快照中。 +func populateChannelCache(channels []Channel, groupPlatforms map[int64]string) *channelCache { cache := newEmptyChannelCache() cache.groupPlatform = groupPlatforms cache.byID = make(map[int64]*Channel, len(channels)) @@ -290,7 +308,6 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) for i := range channels { ch := &channels[i] cache.byID[ch.ID] = ch - for _, gid := range ch.GroupIDs { cache.channelByGroupID[gid] = ch platform := groupPlatforms[gid] @@ -298,11 +315,7 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) expandMappingToCache(cache, ch, gid, platform) } } - - // 通配符条目保持配置顺序(最先匹配到优先) - - s.cache.Store(cache) - return cache, nil + return cache } // invalidateCache 使缓存失效,让下次读取时自然重建 @@ -466,7 +479,10 @@ func (s *ChannelService) ResolveChannelMapping(ctx context.Context, groupID int6 // 返回 true 表示模型被限制(不在允许列表中)。 // 如果渠道未启用模型限制或分组无渠道关联,返回 false。 func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { - lk, _ := s.lookupGroupChannel(ctx, groupID) + lk, err := s.lookupGroupChannel(ctx, groupID) + if err != nil { + slog.Warn("failed to load channel cache for model restriction check", "group_id", groupID, "error", err) + } if lk == nil { return false } @@ -537,6 +553,91 @@ func ReplaceModelInBody(body []byte, newModel string) []byte { return newBody } +// validateChannelConfig 校验渠道的定价和映射配置(冲突检测 + 区间校验 + 计费模式校验)。 +// Create 和 Update 共用此函数,避免重复。 +func validateChannelConfig(pricing []ChannelModelPricing, mapping map[string]map[string]string) error { + if err := validateNoConflictingModels(pricing); err != nil { + return err + } + if err := validatePricingIntervals(pricing); err != nil { + return err + } + if err := validateNoConflictingMappings(mapping); err != nil { + return err + } + return validatePricingBillingMode(pricing) +} + +// validatePricingBillingMode 校验计费模式配置:按次/图片模式必须配价格或区间,所有价格字段不能为负,区间至少有一个价格字段。 +func validatePricingBillingMode(pricing []ChannelModelPricing) error { + for _, p := range pricing { + if err := checkBillingModeRequirements(p); err != nil { + return err + } + if err := checkPricesNotNegative(p); err != nil { + return err + } + if err := checkIntervalsHavePrices(p); err != nil { + return err + } + } + return nil +} + +func checkBillingModeRequirements(p ChannelModelPricing) error { + if p.BillingMode == BillingModePerRequest || p.BillingMode == BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + return infraerrors.BadRequest( + "BILLING_MODE_MISSING_PRICE", + "per-request price or intervals required for per_request/image billing mode", + ) + } + } + return nil +} + +func checkPricesNotNegative(p ChannelModelPricing) error { + checks := []struct { + field string + val *float64 + }{ + {"input_price", p.InputPrice}, + {"output_price", p.OutputPrice}, + {"cache_write_price", p.CacheWritePrice}, + {"cache_read_price", p.CacheReadPrice}, + {"image_output_price", p.ImageOutputPrice}, + {"per_request_price", p.PerRequestPrice}, + } + for _, c := range checks { + if c.val != nil && *c.val < 0 { + return infraerrors.BadRequest("NEGATIVE_PRICE", fmt.Sprintf("%s must be >= 0", c.field)) + } + } + return nil +} + +func checkIntervalsHavePrices(p ChannelModelPricing) error { + for _, iv := range p.Intervals { + if iv.InputPrice == nil && iv.OutputPrice == nil && + iv.CacheWritePrice == nil && iv.CacheReadPrice == nil && + iv.PerRequestPrice == nil { + return infraerrors.BadRequest( + "INTERVAL_MISSING_PRICE", + fmt.Sprintf("interval [%d, %s] has no price fields set for model %v", + iv.MinTokens, formatMaxTokens(iv.MaxTokens), p.Models), + ) + } + } + return nil +} + +func formatMaxTokens(max *int) string { + if max == nil { + return "∞" + } + return fmt.Sprintf("%d", *max) +} + // --- CRUD --- // Create 创建渠道 @@ -549,15 +650,8 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) return nil, ErrChannelExists } - // 检查分组冲突 - if len(input.GroupIDs) > 0 { - conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, 0, input.GroupIDs) - if err != nil { - return nil, fmt.Errorf("check group conflicts: %w", err) - } - if len(conflicting) > 0 { - return nil, ErrGroupAlreadyInChannel - } + if err := s.checkGroupConflicts(ctx, 0, input.GroupIDs); err != nil { + return nil, err } channel := &Channel{ @@ -574,13 +668,7 @@ func (s *ChannelService) Create(ctx context.Context, input *CreateChannelInput) channel.BillingModelSource = BillingModelSourceChannelMapped } - if err := validateNoConflictingModels(channel.ModelPricing); err != nil { - return nil, err - } - if err := validatePricingIntervals(channel.ModelPricing); err != nil { - return nil, err - } - if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { return nil, err } @@ -604,102 +692,112 @@ func (s *ChannelService) Update(ctx context.Context, id int64, input *UpdateChan return nil, fmt.Errorf("get channel: %w", err) } - if input.Name != "" && input.Name != channel.Name { - exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, id) - if err != nil { - return nil, fmt.Errorf("check channel exists: %w", err) - } - if exists { - return nil, ErrChannelExists - } - channel.Name = input.Name - } - - if input.Description != nil { - channel.Description = *input.Description - } - - if input.Status != "" { - channel.Status = input.Status - } - - if input.RestrictModels != nil { - channel.RestrictModels = *input.RestrictModels - } - - // 检查分组冲突 - if input.GroupIDs != nil { - conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, id, *input.GroupIDs) - if err != nil { - return nil, fmt.Errorf("check group conflicts: %w", err) - } - if len(conflicting) > 0 { - return nil, ErrGroupAlreadyInChannel - } - channel.GroupIDs = *input.GroupIDs - } - - if input.ModelPricing != nil { - channel.ModelPricing = *input.ModelPricing - } - - if input.ModelMapping != nil { - channel.ModelMapping = input.ModelMapping - } - - if input.BillingModelSource != "" { - channel.BillingModelSource = input.BillingModelSource - } - - if err := validateNoConflictingModels(channel.ModelPricing); err != nil { - return nil, err - } - if err := validatePricingIntervals(channel.ModelPricing); err != nil { - return nil, err - } - if err := validateNoConflictingMappings(channel.ModelMapping); err != nil { + if err := s.applyUpdateInput(ctx, channel, input); err != nil { return nil, err } - // 先获取旧分组,Update 后旧分组关联已删除,无法再查到 - var oldGroupIDs []int64 - if s.authCacheInvalidator != nil { - var err2 error - oldGroupIDs, err2 = s.repo.GetGroupIDs(ctx, id) - if err2 != nil { - slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", id, "error", err2) - } + if err := validateChannelConfig(channel.ModelPricing, channel.ModelMapping); err != nil { + return nil, err } + oldGroupIDs := s.getOldGroupIDs(ctx, id) + if err := s.repo.Update(ctx, channel); err != nil { return nil, fmt.Errorf("update channel: %w", err) } s.invalidateCache() - - // 失效新旧分组的 auth 缓存 - if s.authCacheInvalidator != nil { - seen := make(map[int64]struct{}, len(oldGroupIDs)+len(channel.GroupIDs)) - for _, gid := range oldGroupIDs { - if _, ok := seen[gid]; !ok { - seen[gid] = struct{}{} - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } - for _, gid := range channel.GroupIDs { - if _, ok := seen[gid]; !ok { - seen[gid] = struct{}{} - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } - } + s.invalidateAuthCacheForGroups(ctx, oldGroupIDs, channel.GroupIDs) return s.repo.GetByID(ctx, id) } +// applyUpdateInput 将更新请求的字段应用到渠道实体上。 +func (s *ChannelService) applyUpdateInput(ctx context.Context, channel *Channel, input *UpdateChannelInput) error { + if input.Name != "" && input.Name != channel.Name { + exists, err := s.repo.ExistsByNameExcluding(ctx, input.Name, channel.ID) + if err != nil { + return fmt.Errorf("check channel exists: %w", err) + } + if exists { + return ErrChannelExists + } + channel.Name = input.Name + } + if input.Description != nil { + channel.Description = *input.Description + } + if input.Status != "" { + channel.Status = input.Status + } + if input.RestrictModels != nil { + channel.RestrictModels = *input.RestrictModels + } + if input.GroupIDs != nil { + if err := s.checkGroupConflicts(ctx, channel.ID, *input.GroupIDs); err != nil { + return err + } + channel.GroupIDs = *input.GroupIDs + } + if input.ModelPricing != nil { + channel.ModelPricing = *input.ModelPricing + } + if input.ModelMapping != nil { + channel.ModelMapping = input.ModelMapping + } + if input.BillingModelSource != "" { + channel.BillingModelSource = input.BillingModelSource + } + return nil +} + +// checkGroupConflicts 检查待关联的分组是否已属于其他渠道。 +// channelID 为当前渠道 ID(Create 时传 0)。 +func (s *ChannelService) checkGroupConflicts(ctx context.Context, channelID int64, groupIDs []int64) error { + if len(groupIDs) == 0 { + return nil + } + conflicting, err := s.repo.GetGroupsInOtherChannels(ctx, channelID, groupIDs) + if err != nil { + return fmt.Errorf("check group conflicts: %w", err) + } + if len(conflicting) > 0 { + return ErrGroupAlreadyInChannel + } + return nil +} + +// getOldGroupIDs 获取渠道更新前的关联分组 ID(用于失效 auth 缓存)。 +func (s *ChannelService) getOldGroupIDs(ctx context.Context, channelID int64) []int64 { + if s.authCacheInvalidator == nil { + return nil + } + oldGroupIDs, err := s.repo.GetGroupIDs(ctx, channelID) + if err != nil { + slog.Warn("failed to get old group IDs for cache invalidation", "channel_id", channelID, "error", err) + } + return oldGroupIDs +} + +// invalidateAuthCacheForGroups 对新旧分组去重后逐个失效 auth 缓存。 +func (s *ChannelService) invalidateAuthCacheForGroups(ctx context.Context, groupIDSets ...[]int64) { + if s.authCacheInvalidator == nil { + return + } + seen := make(map[int64]struct{}) + for _, ids := range groupIDSets { + for _, gid := range ids { + if _, ok := seen[gid]; ok { + continue + } + seen[gid] = struct{}{} + s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) + } + } +} + // Delete 删除渠道 func (s *ChannelService) Delete(ctx context.Context, id int64) error { - // 先获取关联分组用于失效缓存 groupIDs, err := s.repo.GetGroupIDs(ctx, id) if err != nil { slog.Warn("failed to get group IDs before delete", "channel_id", id, "error", err) @@ -710,12 +808,7 @@ func (s *ChannelService) Delete(ctx context.Context, id int64) error { } s.invalidateCache() - - if s.authCacheInvalidator != nil { - for _, gid := range groupIDs { - s.authCacheInvalidator.InvalidateAuthCacheByGroupID(ctx, gid) - } - } + s.invalidateAuthCacheForGroups(ctx, groupIDs) return nil } diff --git a/backend/internal/service/channel_service_test.go b/backend/internal/service/channel_service_test.go index 3a01fd80..e1345618 100644 --- a/backend/internal/service/channel_service_test.go +++ b/backend/internal/service/channel_service_test.go @@ -2199,3 +2199,207 @@ func TestGetChannelModelPricing_NonAntigravityUnaffected(t *testing.T) { require.Equal(t, int64(601), result.ID) require.InDelta(t, 5e-6, *result.InputPrice, 1e-12) } + +// --------------------------------------------------------------------------- +// 10. ToUsageFields +// --------------------------------------------------------------------------- + +func TestToUsageFields_NoMapping(t *testing.T) { + r := ChannelMappingResult{ + MappedModel: "claude-opus-4", + ChannelID: 1, + Mapped: false, + BillingModelSource: BillingModelSourceRequested, + } + fields := r.ToUsageFields("claude-opus-4", "claude-opus-4") + require.Equal(t, int64(1), fields.ChannelID) + require.Equal(t, "claude-opus-4", fields.OriginalModel) + require.Equal(t, "claude-opus-4", fields.ChannelMappedModel) + require.Equal(t, BillingModelSourceRequested, fields.BillingModelSource) + require.Empty(t, fields.ModelMappingChain) +} + +func TestToUsageFields_WithChannelMapping(t *testing.T) { + r := ChannelMappingResult{ + MappedModel: "claude-sonnet-4-20250514", + ChannelID: 2, + Mapped: true, + BillingModelSource: BillingModelSourceChannelMapped, + } + fields := r.ToUsageFields("claude-sonnet-4", "claude-sonnet-4-20250514") + require.Equal(t, int64(2), fields.ChannelID) + require.Equal(t, "claude-sonnet-4", fields.OriginalModel) + require.Equal(t, "claude-sonnet-4-20250514", fields.ChannelMappedModel) + require.Equal(t, "claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain) +} + +func TestToUsageFields_WithUpstreamDifference(t *testing.T) { + r := ChannelMappingResult{ + MappedModel: "claude-sonnet-4", + ChannelID: 3, + Mapped: true, + BillingModelSource: BillingModelSourceUpstream, + } + fields := r.ToUsageFields("my-alias", "claude-sonnet-4-20250514") + require.Equal(t, "my-alias", fields.OriginalModel) + require.Equal(t, "claude-sonnet-4", fields.ChannelMappedModel) + require.Equal(t, "my-alias→claude-sonnet-4→claude-sonnet-4-20250514", fields.ModelMappingChain) +} + +// --------------------------------------------------------------------------- +// 11. validatePricingBillingMode (moved from handler tests) +// --------------------------------------------------------------------------- + +func TestValidatePricingBillingMode(t *testing.T) { + tests := []struct { + name string + pricing []ChannelModelPricing + wantErr bool + errMsg string + }{ + { + name: "token mode - valid", + pricing: []ChannelModelPricing{{BillingMode: BillingModeToken}}, + }, + { + name: "per_request with price - valid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.5), + }}, + }, + { + name: "per_request with intervals - valid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModePerRequest, + Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000), PerRequestPrice: testPtrFloat64(0.1)}}, + }}, + }, + { + name: "per_request no price no intervals - invalid", + pricing: []ChannelModelPricing{{BillingMode: BillingModePerRequest}}, + wantErr: true, + errMsg: "per-request price or intervals required", + }, + { + name: "image no price no intervals - invalid", + pricing: []ChannelModelPricing{{BillingMode: BillingModeImage}}, + wantErr: true, + errMsg: "per-request price or intervals required", + }, + { + name: "empty list - valid", + pricing: []ChannelModelPricing{}, + }, + { + name: "negative input_price - invalid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModeToken, + InputPrice: testPtrFloat64(-0.01), + }}, + wantErr: true, + errMsg: "input_price must be >= 0", + }, + { + name: "interval with no price fields - invalid", + pricing: []ChannelModelPricing{{ + BillingMode: BillingModePerRequest, + PerRequestPrice: testPtrFloat64(0.5), + Intervals: []PricingInterval{{MinTokens: 0, MaxTokens: testPtrInt(1000)}}, + }}, + wantErr: true, + errMsg: "has no price fields set", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validatePricingBillingMode(tt.pricing) + if tt.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errMsg) + } else { + require.NoError(t, err) + } + }) + } +} + +// --------------------------------------------------------------------------- +// 12. Antigravity wildcard mapping isolation +// --------------------------------------------------------------------------- + +func TestResolveChannelMapping_AntigravityDoesNotSeeWildcardMappingFromOtherPlatforms(t *testing.T) { + ch := Channel{ + ID: 1, + Status: StatusActive, + GroupIDs: []int64{10, 20}, + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: {"claude-*": "claude-override"}, + PlatformGemini: {"gemini-*": "gemini-override"}, + }, + } + repo := makeStandardRepo(ch, map[int64]string{10: PlatformAntigravity, 20: PlatformAnthropic}) + svc := newTestChannelService(repo) + + // antigravity 分组不应看到 anthropic/gemini 的通配符映射 + result := svc.ResolveChannelMapping(context.Background(), 10, "claude-opus-4") + require.False(t, result.Mapped) + require.Equal(t, "claude-opus-4", result.MappedModel) + + result = svc.ResolveChannelMapping(context.Background(), 10, "gemini-2.5-pro") + require.False(t, result.Mapped) + require.Equal(t, "gemini-2.5-pro", result.MappedModel) + + // anthropic 分组应该能看到 anthropic 的通配符映射 + result = svc.ResolveChannelMapping(context.Background(), 20, "claude-opus-4") + require.True(t, result.Mapped) + require.Equal(t, "claude-override", result.MappedModel) +} + +// --------------------------------------------------------------------------- +// 13. Create/Update with mapping conflict validation +// --------------------------------------------------------------------------- + +func TestCreate_MappingConflict(t *testing.T) { + repo := &mockChannelRepository{} + svc := newTestChannelService(repo) + + _, err := svc.Create(context.Background(), &CreateChannelInput{ + Name: "test", + ModelMapping: map[string]map[string]string{ + PlatformAnthropic: { + "claude-*": "target-a", + "claude-opus-*": "target-b", + }, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT") +} + +func TestUpdate_MappingConflict(t *testing.T) { + existingChannel := &Channel{ + ID: 1, + Name: "existing", + Status: StatusActive, + } + repo := &mockChannelRepository{ + getByIDFn: func(_ context.Context, _ int64) (*Channel, error) { + return existingChannel, nil + }, + } + svc := newTestChannelService(repo) + + conflictMapping := map[string]map[string]string{ + PlatformAnthropic: { + "claude-*": "target-a", + "claude-opus-*": "target-b", + }, + } + _, err := svc.Update(context.Background(), 1, &UpdateChannelInput{ + ModelMapping: conflictMapping, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "MAPPING_PATTERN_CONFLICT") +}