diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 3beaef33..0705494f 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -180,7 +180,7 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { modelPricingResolver := service.NewModelPricingResolver(channelService, billingService) gatewayService := service.NewGatewayService(accountRepository, groupRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, identityService, httpUpstream, deferredService, claudeTokenProvider, sessionLimitCache, rpmCache, digestSessionStore, settingService, tlsFingerprintProfileService, channelService, modelPricingResolver) openAITokenProvider := service.ProvideOpenAITokenProvider(accountRepository, geminiTokenCache, openAIOAuthService, oauthRefreshAPI) - openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver) + openAIGatewayService := service.NewOpenAIGatewayService(accountRepository, usageLogRepository, usageBillingRepository, userRepository, userSubscriptionRepository, userGroupRateRepository, gatewayCache, configConfig, schedulerSnapshotService, concurrencyService, billingService, rateLimitService, billingCacheService, httpUpstream, deferredService, openAITokenProvider, modelPricingResolver, channelService) geminiMessagesCompatService := service.NewGeminiMessagesCompatService(accountRepository, groupRepository, gatewayCache, schedulerSnapshotService, geminiTokenProvider, rateLimitService, httpUpstream, antigravityGatewayService, configConfig) opsSystemLogSink := service.ProvideOpsSystemLogSink(opsRepository) opsService := service.NewOpsService(opsRepository, settingRepository, configConfig, accountRepository, userRepository, concurrencyService, gatewayService, openAIGatewayService, geminiMessagesCompatService, antigravityGatewayService, opsSystemLogSink) diff --git a/backend/internal/handler/admin/channel_handler.go b/backend/internal/handler/admin/channel_handler.go index 77540d3d..e5a3eac5 100644 --- a/backend/internal/handler/admin/channel_handler.go +++ b/backend/internal/handler/admin/channel_handler.go @@ -276,11 +276,21 @@ func (h *ChannelHandler) Create(c *gin.Context) { return } + pricing := pricingRequestToService(req.ModelPricing) + for _, p := range pricing { + if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode") + return + } + } + } + channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{ Name: req.Name, Description: req.Description, GroupIDs: req.GroupIDs, - ModelPricing: pricingRequestToService(req.ModelPricing), + ModelPricing: pricing, ModelMapping: req.ModelMapping, BillingModelSource: req.BillingModelSource, RestrictModels: req.RestrictModels, @@ -319,6 +329,14 @@ func (h *ChannelHandler) Update(c *gin.Context) { } if req.ModelPricing != nil { pricing := pricingRequestToService(*req.ModelPricing) + for _, p := range pricing { + if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage { + if p.PerRequestPrice == nil && len(p.Intervals) == 0 { + response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode") + return + } + } + } input.ModelPricing = &pricing } diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index ae70cee4..7f68a56b 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -185,6 +185,20 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + var channelMapping service.ChannelMappingResult + if apiKey.GroupID != nil { + channelMapping = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) + } + + // 渠道模型限制检查 + if apiKey.GroupID != nil { + if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) { + h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + } + // 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。 if !h.validateFunctionCallOutputRequest(c, body, reqLog) { return @@ -379,6 +393,21 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelID: channelMapping.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMapping.BillingModelSource, + ModelMappingChain: func() string { + if !channelMapping.Mapped { + if result.UpstreamModel != "" && result.UpstreamModel != result.Model { + return reqModel + "→" + result.UpstreamModel + } + return "" + } + if result.UpstreamModel != "" && result.UpstreamModel != channelMapping.MappedModel { + return reqModel + "→" + channelMapping.MappedModel + "→" + result.UpstreamModel + } + return reqModel + "→" + channelMapping.MappedModel + }(), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.responses"), @@ -549,6 +578,20 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { setOpsRequestContext(c, reqModel, reqStream, body) setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false))) + // 解析渠道级模型映射 + var channelMappingMsg service.ChannelMappingResult + if apiKey.GroupID != nil { + channelMappingMsg = h.gatewayService.ResolveChannelMapping(c.Request.Context(), *apiKey.GroupID, reqModel) + } + + // 渠道模型限制检查 + if apiKey.GroupID != nil { + if h.gatewayService.IsModelRestricted(c.Request.Context(), *apiKey.GroupID, reqModel) { + h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts") + return + } + } + // 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。 if h.errorPassthroughService != nil { service.BindErrorPassthroughService(c, h.errorPassthroughService) @@ -759,6 +802,21 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: requestPayloadHash, APIKeyService: h.apiKeyService, + ChannelID: channelMappingMsg.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMappingMsg.BillingModelSource, + ModelMappingChain: func() string { + if !channelMappingMsg.Mapped { + if result.UpstreamModel != "" && result.UpstreamModel != result.Model { + return reqModel + "→" + result.UpstreamModel + } + return "" + } + if result.UpstreamModel != "" && result.UpstreamModel != channelMappingMsg.MappedModel { + return reqModel + "→" + channelMappingMsg.MappedModel + "→" + result.UpstreamModel + } + return reqModel + "→" + channelMappingMsg.MappedModel + }(), }); err != nil { logger.L().With( zap.String("component", "handler.openai_gateway.messages"), @@ -1101,6 +1159,20 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { setOpsRequestContext(c, reqModel, true, firstMessage) setOpsEndpointContext(c, "", int16(service.RequestTypeWSV2)) + // 解析渠道级模型映射 + var channelMappingWS service.ChannelMappingResult + if apiKey.GroupID != nil { + channelMappingWS = h.gatewayService.ResolveChannelMapping(ctx, *apiKey.GroupID, reqModel) + } + + // 渠道模型限制检查 + if apiKey.GroupID != nil { + if h.gatewayService.IsModelRestricted(ctx, *apiKey.GroupID, reqModel) { + closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model not allowed") + return + } + } + var currentUserRelease func() var currentAccountRelease func() releaseTurnSlots := func() { @@ -1259,6 +1331,21 @@ func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) { IPAddress: clientIP, RequestPayloadHash: service.HashUsageRequestPayload(firstMessage), APIKeyService: h.apiKeyService, + ChannelID: channelMappingWS.ChannelID, + OriginalModel: reqModel, + BillingModelSource: channelMappingWS.BillingModelSource, + ModelMappingChain: func() string { + if !channelMappingWS.Mapped { + if result.UpstreamModel != "" && result.UpstreamModel != result.Model { + return reqModel + "→" + result.UpstreamModel + } + return "" + } + if result.UpstreamModel != "" && result.UpstreamModel != channelMappingWS.MappedModel { + return reqModel + "→" + channelMappingWS.MappedModel + "→" + result.UpstreamModel + } + return reqModel + "→" + channelMappingWS.MappedModel + }(), }); err != nil { reqLog.Error("openai.websocket_record_usage_failed", zap.Int64("account_id", account.ID), diff --git a/backend/internal/service/channel_service.go b/backend/internal/service/channel_service.go index 6025ffcf..29f4b615 100644 --- a/backend/internal/service/channel_service.go +++ b/backend/internal/service/channel_service.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "log/slog" + "sort" "strings" "sync/atomic" "time" @@ -57,13 +58,26 @@ type channelModelKey struct { model string // lowercase } +// channelGroupPlatformKey 通配符定价缓存键 +type channelGroupPlatformKey struct { + groupID int64 + platform string +} + +// wildcardPricingEntry 通配符定价条目 +type wildcardPricingEntry struct { + prefix string + pricing *ChannelModelPricing +} + // channelCache 渠道缓存快照(扁平化哈希结构,热路径 O(1) 查找) type channelCache struct { // 热路径查找 - pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 - mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 - channelByGroupID map[int64]*Channel // groupID → 渠道 - groupPlatform map[int64]string // groupID → platform + pricingByGroupModel map[channelModelKey]*ChannelModelPricing // (groupID, platform, model) → 定价 + wildcardByGroupPlatform map[channelGroupPlatformKey][]*wildcardPricingEntry // (groupID, platform) → 通配符定价(前缀长度降序) + mappingByGroupModel map[channelModelKey]string // (groupID, platform, model) → 映射目标 + channelByGroupID map[int64]*Channel // groupID → 渠道 + groupPlatform map[int64]string // groupID → platform // 冷路径(CRUD 操作) byID map[int64]*Channel @@ -137,12 +151,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) // error-TTL:失败时存入短 TTL 空缓存,防止紧密重试 slog.Warn("failed to build channel cache", "error", err) errorCache := &channelCache{ - pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), - mappingByGroupModel: make(map[channelModelKey]string), - channelByGroupID: make(map[int64]*Channel), - groupPlatform: make(map[int64]string), - byID: make(map[int64]*Channel), - loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), + mappingByGroupModel: make(map[channelModelKey]string), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: make(map[int64]string), + byID: make(map[int64]*Channel), + loadedAt: time.Now().Add(channelCacheTTL - channelErrorTTL), // 使剩余 TTL = errorTTL } s.cache.Store(errorCache) return nil, fmt.Errorf("list all channels: %w", err) @@ -163,12 +178,13 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) } cache := &channelCache{ - pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), - mappingByGroupModel: make(map[channelModelKey]string), - channelByGroupID: make(map[int64]*Channel), - groupPlatform: groupPlatforms, - byID: make(map[int64]*Channel, len(channels)), - loadedAt: time.Now(), + pricingByGroupModel: make(map[channelModelKey]*ChannelModelPricing), + wildcardByGroupPlatform: make(map[channelGroupPlatformKey][]*wildcardPricingEntry), + mappingByGroupModel: make(map[channelModelKey]string), + channelByGroupID: make(map[int64]*Channel), + groupPlatform: groupPlatforms, + byID: make(map[int64]*Channel, len(channels)), + loadedAt: time.Now(), } for i := range channels { @@ -187,8 +203,18 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) continue // 跳过非本平台的定价 } for _, model := range pricing.Models { - key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} - cache.pricingByGroupModel[key] = pricing + if strings.HasSuffix(model, "*") { + // 通配符模型 → 存入 wildcardByGroupPlatform + prefix := strings.ToLower(strings.TrimSuffix(model, "*")) + gpKey := channelGroupPlatformKey{groupID: gid, platform: platform} + cache.wildcardByGroupPlatform[gpKey] = append(cache.wildcardByGroupPlatform[gpKey], &wildcardPricingEntry{ + prefix: prefix, + pricing: pricing, + }) + } else { + key := channelModelKey{groupID: gid, platform: platform, model: strings.ToLower(model)} + cache.pricingByGroupModel[key] = pricing + } } } @@ -202,6 +228,14 @@ func (s *ChannelService) buildCache(ctx context.Context) (*channelCache, error) } } + // 通配符条目按前缀长度降序排列(最长前缀优先匹配) + for gpKey, entries := range cache.wildcardByGroupPlatform { + sort.Slice(entries, func(i, j int) bool { + return len(entries[i].prefix) > len(entries[j].prefix) + }) + cache.wildcardByGroupPlatform[gpKey] = entries + } + s.cache.Store(cache) return cache, nil } @@ -212,6 +246,18 @@ func (s *ChannelService) invalidateCache() { s.cacheSF.Forget("channel_cache") } +// matchWildcard 在通配符定价中查找匹配项(最长前缀优先) +func (c *channelCache) matchWildcard(groupID int64, platform, modelLower string) *ChannelModelPricing { + gpKey := channelGroupPlatformKey{groupID: groupID, platform: platform} + wildcards := c.wildcardByGroupPlatform[gpKey] + for _, wc := range wildcards { + if strings.HasPrefix(modelLower, wc.prefix) { + return wc.pricing + } + } + return nil +} + // GetChannelForGroup 获取分组关联的渠道(热路径 O(1)) func (s *ChannelService) GetChannelForGroup(ctx context.Context, groupID int64) (*Channel, error) { cache, err := s.loadCache(ctx) @@ -245,7 +291,11 @@ func (s *ChannelService) GetChannelModelPricing(ctx context.Context, groupID int key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} pricing, ok := cache.pricingByGroupModel[key] if !ok { - return nil + // 精确查找失败,尝试通配符匹配 + pricing = cache.matchWildcard(groupID, platform, strings.ToLower(model)) + if pricing == nil { + return nil + } } cp := pricing.Clone() @@ -302,7 +352,14 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m platform := cache.groupPlatform[groupID] key := channelModelKey{groupID: groupID, platform: platform, model: strings.ToLower(model)} _, exists := cache.pricingByGroupModel[key] - return !exists + if exists { + return false + } + // 精确查找失败,尝试通配符匹配 + if cache.matchWildcard(groupID, platform, strings.ToLower(model)) != nil { + return false + } + return true } // --- CRUD --- diff --git a/backend/internal/service/openai_gateway_record_usage_test.go b/backend/internal/service/openai_gateway_record_usage_test.go index bfcff444..e2b164c0 100644 --- a/backend/internal/service/openai_gateway_record_usage_test.go +++ b/backend/internal/service/openai_gateway_record_usage_test.go @@ -146,6 +146,7 @@ func newOpenAIRecordUsageServiceForTest(usageRepo UsageLogRepository, userRepo U &DeferredService{}, nil, nil, + nil, ) svc.userGroupRateResolver = newUserGroupRateResolver( rateRepo, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 06c36b0f..3818af02 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -323,6 +323,7 @@ type OpenAIGatewayService struct { toolCorrector *CodexToolCorrector openaiWSResolver OpenAIWSProtocolResolver resolver *ModelPricingResolver + channelService *ChannelService openaiWSPoolOnce sync.Once openaiWSStateStoreOnce sync.Once @@ -359,6 +360,7 @@ func NewOpenAIGatewayService( deferredService *DeferredService, openAITokenProvider *OpenAITokenProvider, resolver *ModelPricingResolver, + channelService *ChannelService, ) *OpenAIGatewayService { svc := &OpenAIGatewayService{ accountRepo: accountRepo, @@ -387,6 +389,7 @@ func NewOpenAIGatewayService( toolCorrector: NewCodexToolCorrector(), openaiWSResolver: NewOpenAIWSProtocolResolver(cfg), resolver: resolver, + channelService: channelService, responseHeaderFilter: compileResponseHeaderFilter(cfg), codexSnapshotThrottle: newAccountWriteThrottle(openAICodexSnapshotPersistMinInterval), } @@ -394,6 +397,22 @@ func NewOpenAIGatewayService( return svc } +// ResolveChannelMapping 解析渠道级模型映射(代理到 ChannelService) +func (s *OpenAIGatewayService) ResolveChannelMapping(ctx context.Context, groupID int64, model string) ChannelMappingResult { + if s.channelService == nil { + return ChannelMappingResult{MappedModel: model} + } + return s.channelService.ResolveChannelMapping(ctx, groupID, model) +} + +// IsModelRestricted 检查模型是否被渠道限制(代理到 ChannelService) +func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID int64, model string) bool { + if s.channelService == nil { + return false + } + return s.channelService.IsModelRestricted(ctx, groupID, model) +} + func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle { if s != nil && s.codexSnapshotThrottle != nil { return s.codexSnapshotThrottle @@ -4113,6 +4132,10 @@ type OpenAIRecordUsageInput struct { IPAddress string // 请求的客户端 IP 地址 RequestPayloadHash string APIKeyService APIKeyQuotaUpdater + ChannelID int64 + OriginalModel string + BillingModelSource string + ModelMappingChain string } // RecordUsage records usage and deducts balance @@ -4158,6 +4181,12 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec var cost *CostBreakdown var err error billingModel := forwardResultBillingModel(result.Model, result.UpstreamModel) + if result.BillingModel != "" { + billingModel = strings.TrimSpace(result.BillingModel) + } + if input.BillingModelSource == "requested" && input.OriginalModel != "" { + billingModel = input.OriginalModel + } serviceTier := "" if result.ServiceTier != nil { serviceTier = strings.TrimSpace(*result.ServiceTier) @@ -4223,6 +4252,9 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec FirstTokenMs: result.FirstTokenMs, CreatedAt: time.Now(), } + // 设置渠道信息 + usageLog.ChannelID = optionalInt64Ptr(input.ChannelID) + usageLog.ModelMappingChain = optionalTrimmedStringPtr(input.ModelMappingChain) // 设置计费模式 if cost != nil && cost.BillingMode != "" { billingMode := cost.BillingMode diff --git a/backend/internal/service/openai_ws_protocol_forward_test.go b/backend/internal/service/openai_ws_protocol_forward_test.go index e8d9f8f7..3834dcb7 100644 --- a/backend/internal/service/openai_ws_protocol_forward_test.go +++ b/backend/internal/service/openai_ws_protocol_forward_test.go @@ -616,6 +616,7 @@ func TestNewOpenAIGatewayService_InitializesOpenAIWSResolver(t *testing.T) { nil, nil, nil, + nil, ) decision := svc.getOpenAIWSProtocolResolver().Resolve(nil) diff --git a/frontend/src/i18n/locales/en.ts b/frontend/src/i18n/locales/en.ts index ee66db90..73df077f 100644 --- a/frontend/src/i18n/locales/en.ts +++ b/frontend/src/i18n/locales/en.ts @@ -1789,6 +1789,7 @@ export default { noTiersYet: 'No tiers yet. Click add to configure per-request pricing.', noPricingRules: 'No pricing rules yet. Click "Add" to create one.', perRequestPrice: 'Price per Request', + perRequestPriceRequired: 'Per-request price or billing tiers required for per-request/image billing mode', tierLabel: 'Tier', resolution: 'Resolution', modelMapping: 'Model Mapping', diff --git a/frontend/src/i18n/locales/zh.ts b/frontend/src/i18n/locales/zh.ts index 60ddb0d1..d5dd769e 100644 --- a/frontend/src/i18n/locales/zh.ts +++ b/frontend/src/i18n/locales/zh.ts @@ -1869,6 +1869,7 @@ export default { noTiersYet: '暂无层级,点击添加配置按次计费价格', noPricingRules: '暂无定价规则,点击"添加"创建', perRequestPrice: '单次价格', + perRequestPriceRequired: '按次/图片计费模式必须设置默认价格或至少一个计费层级', tierLabel: '层级', resolution: '分辨率', modelMapping: '模型映射', diff --git a/frontend/src/views/admin/ChannelsView.vue b/frontend/src/views/admin/ChannelsView.vue index 940b6d3a..c26a6fcf 100644 --- a/frontend/src/views/admin/ChannelsView.vue +++ b/frontend/src/views/admin/ChannelsView.vue @@ -876,6 +876,19 @@ async function handleSubmit() { return } + // 校验 per_request/image 模式必须有价格 + for (const section of form.platforms) { + for (const entry of section.model_pricing) { + if (entry.models.length === 0) continue + if ((entry.billing_mode === 'per_request' || entry.billing_mode === 'image') && + (entry.per_request_price == null || entry.per_request_price === '') && + (!entry.intervals || entry.intervals.length === 0)) { + appStore.showError(t('admin.channels.perRequestPriceRequired', '按次/图片计费模式必须设置默认价格或至少一个计费层级')) + return + } + } + } + const { group_ids, model_pricing, model_mapping } = formToAPI() console.log('[handleSubmit] model_pricing to send:', JSON.stringify(model_pricing)) diff --git a/frontend/src/views/user/UsageView.vue b/frontend/src/views/user/UsageView.vue index 3b8ef2e0..bbf5163b 100644 --- a/frontend/src/views/user/UsageView.vue +++ b/frontend/src/views/user/UsageView.vue @@ -181,6 +181,13 @@ + +