fix(channel): 全平台渠道映射覆盖 + 公共函数抽取 + 死代码清理
- 4个缺失handler入口添加渠道映射+限制检查(ChatCompletions/Responses/Gemini) - 模型限制错误信息优化,区分"模型不可用"和"无账号" - OpenAI RecordUsage RequestedModel 改用 OriginalModel - ResolveChannelMappingAndRestrict/ReplaceModelInBody 抽取到 ChannelService 消除跨service重复 - validateNoDuplicateModels 按 platform:model 去重 - 删除 Channel.ResolveMappedModel 死代码和 CalculateCostWithChannel Deprecated方法 - 移除冗余nil检查,抽取 validatePricingBillingMode 公共校验
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package admin
|
package admin
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -224,6 +225,18 @@ func pricingRequestToService(reqs []channelModelPricingRequest) []service.Channe
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// validatePricingBillingMode 校验按次/图片计费模式必须配置 PerRequestPrice 或 Intervals
|
||||||
|
func validatePricingBillingMode(pricing []service.ChannelModelPricing) error {
|
||||||
|
for _, p := range pricing {
|
||||||
|
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
||||||
|
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
||||||
|
return errors.New("Per-request price or intervals required for per_request/image billing mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// --- Handlers ---
|
// --- Handlers ---
|
||||||
|
|
||||||
// List handles listing channels with pagination
|
// List handles listing channels with pagination
|
||||||
@@ -277,13 +290,9 @@ func (h *ChannelHandler) Create(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pricing := pricingRequestToService(req.ModelPricing)
|
pricing := pricingRequestToService(req.ModelPricing)
|
||||||
for _, p := range pricing {
|
if err := validatePricingBillingMode(pricing); err != nil {
|
||||||
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
response.BadRequest(c, err.Error())
|
||||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
return
|
||||||
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
channel, err := h.channelService.Create(c.Request.Context(), &service.CreateChannelInput{
|
||||||
@@ -329,13 +338,9 @@ func (h *ChannelHandler) Update(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
if req.ModelPricing != nil {
|
if req.ModelPricing != nil {
|
||||||
pricing := pricingRequestToService(*req.ModelPricing)
|
pricing := pricingRequestToService(*req.ModelPricing)
|
||||||
for _, p := range pricing {
|
if err := validatePricingBillingMode(pricing); err != nil {
|
||||||
if p.BillingMode == service.BillingModePerRequest || p.BillingMode == service.BillingModeImage {
|
response.BadRequest(c, err.Error())
|
||||||
if p.PerRequestPrice == nil && len(p.Intervals) == 0 {
|
return
|
||||||
response.BadRequest(c, "Per-request price or intervals required for per_request/image billing mode")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
input.ModelPricing = &pricing
|
input.ModelPricing = &pricing
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -161,7 +161,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
if restricted {
|
if restricted {
|
||||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -80,6 +80,13 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射 + 限制检查
|
||||||
|
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
if restricted {
|
||||||
|
h.chatCompletionsErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Claude Code only restriction
|
// Claude Code only restriction
|
||||||
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
if apiKey.Group != nil && apiKey.Group.ClaudeCodeOnly {
|
||||||
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
|
h.chatCompletionsErrorResponse(c, http.StatusForbidden, "permission_error",
|
||||||
@@ -203,7 +210,11 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
// 5. Forward request
|
// 5. Forward request
|
||||||
writerSizeBeforeForward := c.Writer.Size()
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, parsedReq)
|
forwardBody := body
|
||||||
|
if channelMapping.Mapped {
|
||||||
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||||
|
}
|
||||||
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||||
|
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
@@ -255,6 +266,10 @@ func (h *GatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMapping.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMapping.BillingModelSource,
|
||||||
|
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("gateway.cc.record_usage_failed",
|
reqLog.Error("gateway.cc.record_usage_failed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
|||||||
@@ -80,6 +80,13 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射 + 限制检查
|
||||||
|
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
if restricted {
|
||||||
|
h.responsesErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// Claude Code only restriction:
|
// Claude Code only restriction:
|
||||||
// /v1/responses is never a Claude Code endpoint.
|
// /v1/responses is never a Claude Code endpoint.
|
||||||
// When claude_code_only is enabled, this endpoint is rejected.
|
// When claude_code_only is enabled, this endpoint is rejected.
|
||||||
@@ -208,7 +215,11 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
// 5. Forward request
|
// 5. Forward request
|
||||||
writerSizeBeforeForward := c.Writer.Size()
|
writerSizeBeforeForward := c.Writer.Size()
|
||||||
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, body, parsedReq)
|
forwardBody := body
|
||||||
|
if channelMapping.Mapped {
|
||||||
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||||
|
}
|
||||||
|
result, err := h.gatewayService.ForwardAsResponses(c.Request.Context(), c, account, forwardBody, parsedReq)
|
||||||
|
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
@@ -261,6 +272,10 @@ func (h *GatewayHandler) Responses(c *gin.Context) {
|
|||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
RequestPayloadHash: requestPayloadHash,
|
RequestPayloadHash: requestPayloadHash,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMapping.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMapping.BillingModelSource,
|
||||||
|
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
reqLog.Error("gateway.responses.record_usage_failed",
|
reqLog.Error("gateway.responses.record_usage_failed",
|
||||||
zap.Int64("account_id", account.ID),
|
zap.Int64("account_id", account.ID),
|
||||||
|
|||||||
@@ -184,6 +184,17 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, modelName, stream, body)
|
setOpsRequestContext(c, modelName, stream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(stream, false)))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射 + 限制检查
|
||||||
|
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, modelName)
|
||||||
|
if restricted {
|
||||||
|
googleError(c, http.StatusServiceUnavailable, "The requested model is not available for this API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
reqModel := modelName // 保存映射前的原始模型名
|
||||||
|
if channelMapping.Mapped {
|
||||||
|
modelName = channelMapping.MappedModel
|
||||||
|
}
|
||||||
|
|
||||||
// Get subscription (may be nil)
|
// Get subscription (may be nil)
|
||||||
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
subscription, _ := middleware.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
@@ -523,6 +534,10 @@ func (h *GatewayHandler) GeminiV1BetaModels(c *gin.Context) {
|
|||||||
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
LongContextMultiplier: 2.0, // 超出部分双倍计费
|
||||||
ForceCacheBilling: fs.ForceCacheBilling,
|
ForceCacheBilling: fs.ForceCacheBilling,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMapping.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMapping.BillingModelSource,
|
||||||
|
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.gemini_v1beta.models"),
|
zap.String("component", "handler.gemini_v1beta.models"),
|
||||||
|
|||||||
@@ -79,6 +79,13 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||||
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
setOpsEndpointContext(c, "", int16(service.RequestTypeFromLegacy(reqStream, false)))
|
||||||
|
|
||||||
|
// 解析渠道级模型映射 + 限制检查
|
||||||
|
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
|
if restricted {
|
||||||
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if h.errorPassthroughService != nil {
|
if h.errorPassthroughService != nil {
|
||||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||||
}
|
}
|
||||||
@@ -183,7 +190,11 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
forwardStart := time.Now()
|
forwardStart := time.Now()
|
||||||
|
|
||||||
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
defaultMappedModel := resolveOpenAIForwardDefaultMappedModel(apiKey, c.GetString("openai_chat_completions_fallback_model"))
|
||||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
forwardBody := body
|
||||||
|
if channelMapping.Mapped {
|
||||||
|
forwardBody = h.gatewayService.ReplaceModelInBody(body, channelMapping.MappedModel)
|
||||||
|
}
|
||||||
|
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, forwardBody, promptCacheKey, defaultMappedModel)
|
||||||
|
|
||||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
@@ -257,16 +268,20 @@ func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||||
Result: result,
|
Result: result,
|
||||||
APIKey: apiKey,
|
APIKey: apiKey,
|
||||||
User: apiKey.User,
|
User: apiKey.User,
|
||||||
Account: account,
|
Account: account,
|
||||||
Subscription: subscription,
|
Subscription: subscription,
|
||||||
InboundEndpoint: GetInboundEndpoint(c),
|
InboundEndpoint: GetInboundEndpoint(c),
|
||||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||||
UserAgent: userAgent,
|
UserAgent: userAgent,
|
||||||
IPAddress: clientIP,
|
IPAddress: clientIP,
|
||||||
APIKeyService: h.apiKeyService,
|
APIKeyService: h.apiKeyService,
|
||||||
|
ChannelID: channelMapping.ChannelID,
|
||||||
|
OriginalModel: reqModel,
|
||||||
|
BillingModelSource: channelMapping.BillingModelSource,
|
||||||
|
ModelMappingChain: channelMapping.BuildModelMappingChain(reqModel, result.UpstreamModel),
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
logger.L().With(
|
logger.L().With(
|
||||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMapping, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
if restricted {
|
if restricted {
|
||||||
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
h.errorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -568,7 +568,7 @@ func (h *OpenAIGatewayHandler) Messages(c *gin.Context) {
|
|||||||
// 解析渠道级模型映射 + 限制检查
|
// 解析渠道级模型映射 + 限制检查
|
||||||
channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
channelMappingMsg, restricted := h.gatewayService.ResolveChannelMappingAndRestrict(c.Request.Context(), apiKey.GroupID, reqModel)
|
||||||
if restricted {
|
if restricted {
|
||||||
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "No available accounts")
|
h.anthropicErrorResponse(c, http.StatusServiceUnavailable, "api_error", "The requested model is not available for this API key")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -402,12 +402,6 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
|
|||||||
return pricing, nil
|
return pricing, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// CalculateCostWithChannel 使用渠道定价计算费用
|
|
||||||
// Deprecated: 使用 CalculateCostUnified 代替
|
|
||||||
func (s *BillingService) CalculateCostWithChannel(model string, tokens UsageTokens, rateMultiplier float64, channelPricing *ChannelModelPricing) (*CostBreakdown, error) {
|
|
||||||
return s.calculateCostInternal(model, tokens, rateMultiplier, "", channelPricing)
|
|
||||||
}
|
|
||||||
|
|
||||||
// --- 统一计费入口 ---
|
// --- 统一计费入口 ---
|
||||||
|
|
||||||
// CostInput 统一计费输入
|
// CostInput 统一计费输入
|
||||||
|
|||||||
@@ -82,38 +82,6 @@ type PricingInterval struct {
|
|||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveMappedModel 解析渠道级模型映射,返回映射后的模型名。
|
|
||||||
// platform 指定查找哪个平台的映射规则。
|
|
||||||
// 支持通配符(如 "claude-*" → "claude-sonnet-4")。
|
|
||||||
// 如果没有匹配的映射规则,返回原始模型名。
|
|
||||||
func (c *Channel) ResolveMappedModel(platform, requestedModel string) string {
|
|
||||||
if len(c.ModelMapping) == 0 {
|
|
||||||
return requestedModel
|
|
||||||
}
|
|
||||||
platformMapping, ok := c.ModelMapping[platform]
|
|
||||||
if !ok || len(platformMapping) == 0 {
|
|
||||||
return requestedModel
|
|
||||||
}
|
|
||||||
lower := strings.ToLower(requestedModel)
|
|
||||||
// 精确匹配优先
|
|
||||||
for src, dst := range platformMapping {
|
|
||||||
if strings.ToLower(src) == lower {
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// 通配符匹配
|
|
||||||
for src, dst := range platformMapping {
|
|
||||||
srcLower := strings.ToLower(src)
|
|
||||||
if strings.HasSuffix(srcLower, "*") {
|
|
||||||
prefix := strings.TrimSuffix(srcLower, "*")
|
|
||||||
if strings.HasPrefix(lower, prefix) {
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return requestedModel
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsActive 判断渠道是否启用
|
// IsActive 判断渠道是否启用
|
||||||
func (c *Channel) IsActive() bool {
|
func (c *Channel) IsActive() bool {
|
||||||
return c.Status == StatusActive
|
return c.Status == StatusActive
|
||||||
|
|||||||
@@ -11,6 +11,8 @@ import (
|
|||||||
|
|
||||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
"golang.org/x/sync/singleflight"
|
"golang.org/x/sync/singleflight"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -379,6 +381,34 @@ func (s *ChannelService) IsModelRestricted(ctx context.Context, groupID int64, m
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制(组合方法)。
|
||||||
|
// 返回映射结果和是否被限制。groupID 为 nil 时跳过。
|
||||||
|
func (s *ChannelService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||||
|
var mapping ChannelMappingResult
|
||||||
|
mapping.MappedModel = model
|
||||||
|
if groupID == nil {
|
||||||
|
return mapping, false
|
||||||
|
}
|
||||||
|
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
|
||||||
|
restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel)
|
||||||
|
return mapping, restricted
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReplaceModelInBody 替换请求体 JSON 中的 model 字段。
|
||||||
|
func ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
newBody, err := sjson.SetBytes(body, "model", newModel)
|
||||||
|
if err != nil {
|
||||||
|
return body
|
||||||
|
}
|
||||||
|
return newBody
|
||||||
|
}
|
||||||
|
|
||||||
// --- CRUD ---
|
// --- CRUD ---
|
||||||
|
|
||||||
// Create 创建渠道
|
// Create 创建渠道
|
||||||
@@ -539,16 +569,16 @@ func (s *ChannelService) List(ctx context.Context, params pagination.PaginationP
|
|||||||
return s.repo.List(ctx, params, status, search)
|
return s.repo.List(ctx, params, status, search)
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateNoDuplicateModels 检查定价列表中是否有重复模型
|
// validateNoDuplicateModels 检查定价列表中是否有重复模型(同一平台下不允许重复)
|
||||||
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
|
func validateNoDuplicateModels(pricingList []ChannelModelPricing) error {
|
||||||
seen := make(map[string]bool)
|
seen := make(map[string]bool)
|
||||||
for _, p := range pricingList {
|
for _, p := range pricingList {
|
||||||
for _, model := range p.Models {
|
for _, model := range p.Models {
|
||||||
lower := strings.ToLower(model)
|
key := p.Platform + ":" + strings.ToLower(model)
|
||||||
if seen[lower] {
|
if seen[key] {
|
||||||
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries", model))
|
return infraerrors.BadRequest("DUPLICATE_MODEL", fmt.Sprintf("model '%s' appears in multiple pricing entries for platform '%s'", model, p.Platform))
|
||||||
}
|
}
|
||||||
seen[lower] = true
|
seen[key] = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -872,17 +872,7 @@ type anthropicMetadataPayload struct {
|
|||||||
// replaceModelInBody 替换请求体中的model字段
|
// replaceModelInBody 替换请求体中的model字段
|
||||||
// 优先使用定点修改,尽量保持客户端原始字段顺序。
|
// 优先使用定点修改,尽量保持客户端原始字段顺序。
|
||||||
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
func (s *GatewayService) replaceModelInBody(body []byte, newModel string) []byte {
|
||||||
if len(body) == 0 {
|
return ReplaceModelInBody(body, newModel)
|
||||||
return body
|
|
||||||
}
|
|
||||||
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
newBody, err := sjson.SetBytes(body, "model", newModel)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return newBody
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type claudeOAuthNormalizeOptions struct {
|
type claudeOAuthNormalizeOptions struct {
|
||||||
@@ -7794,11 +7784,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
|||||||
}
|
}
|
||||||
var err error
|
var err error
|
||||||
if s.resolver != nil && apiKey.Group != nil {
|
if s.resolver != nil && apiKey.Group != nil {
|
||||||
var groupID *int64
|
gid := apiKey.Group.ID
|
||||||
if apiKey.Group != nil {
|
groupID := &gid
|
||||||
gid := apiKey.Group.ID
|
|
||||||
groupID = &gid
|
|
||||||
}
|
|
||||||
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
cost, err = s.billingService.CalculateCostUnified(CostInput{
|
||||||
Ctx: ctx,
|
Ctx: ctx,
|
||||||
Model: billingModel,
|
Model: billingModel,
|
||||||
@@ -8184,7 +8171,7 @@ func (s *GatewayService) ResolveChannelMapping(ctx context.Context, groupID int6
|
|||||||
|
|
||||||
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
|
// ReplaceModelInBody 替换请求体中的模型名(导出供 handler 使用)
|
||||||
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
func (s *GatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||||
return s.replaceModelInBody(body, newModel)
|
return ReplaceModelInBody(body, newModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsModelRestricted 检查模型是否被渠道限制
|
// IsModelRestricted 检查模型是否被渠道限制
|
||||||
@@ -8198,14 +8185,10 @@ func (s *GatewayService) IsModelRestricted(ctx context.Context, groupID int64, m
|
|||||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
|
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
|
||||||
// 返回映射结果和是否被限制。
|
// 返回映射结果和是否被限制。
|
||||||
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
func (s *GatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||||
var mapping ChannelMappingResult
|
if s.channelService == nil {
|
||||||
mapping.MappedModel = model
|
return ChannelMappingResult{MappedModel: model}, false
|
||||||
if groupID == nil {
|
|
||||||
return mapping, false
|
|
||||||
}
|
}
|
||||||
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
|
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||||
restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel)
|
|
||||||
return mapping, restricted
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
// ForwardCountTokens 转发 count_tokens 请求到上游 API
|
||||||
|
|||||||
@@ -416,29 +416,15 @@ func (s *OpenAIGatewayService) IsModelRestricted(ctx context.Context, groupID in
|
|||||||
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
|
// ResolveChannelMappingAndRestrict 解析渠道映射并检查模型限制。
|
||||||
// 返回映射结果和是否被限制。
|
// 返回映射结果和是否被限制。
|
||||||
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
func (s *OpenAIGatewayService) ResolveChannelMappingAndRestrict(ctx context.Context, groupID *int64, model string) (ChannelMappingResult, bool) {
|
||||||
var mapping ChannelMappingResult
|
if s.channelService == nil {
|
||||||
mapping.MappedModel = model
|
return ChannelMappingResult{MappedModel: model}, false
|
||||||
if groupID == nil {
|
|
||||||
return mapping, false
|
|
||||||
}
|
}
|
||||||
mapping = s.ResolveChannelMapping(ctx, *groupID, model)
|
return s.channelService.ResolveChannelMappingAndRestrict(ctx, groupID, model)
|
||||||
restricted := s.IsModelRestricted(ctx, *groupID, mapping.MappedModel)
|
|
||||||
return mapping, restricted
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
|
// ReplaceModelInBody 替换请求体中的 JSON model 字段(通用 gjson/sjson 实现)。
|
||||||
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
func (s *OpenAIGatewayService) ReplaceModelInBody(body []byte, newModel string) []byte {
|
||||||
if len(body) == 0 {
|
return ReplaceModelInBody(body, newModel)
|
||||||
return body
|
|
||||||
}
|
|
||||||
if current := gjson.GetBytes(body, "model"); current.Exists() && current.String() == newModel {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
newBody, err := sjson.SetBytes(body, "model", newModel)
|
|
||||||
if err != nil {
|
|
||||||
return body
|
|
||||||
}
|
|
||||||
return newBody
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
func (s *OpenAIGatewayService) getCodexSnapshotThrottle() *accountWriteThrottle {
|
||||||
@@ -4249,13 +4235,20 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
|||||||
durationMs := int(result.Duration.Milliseconds())
|
durationMs := int(result.Duration.Milliseconds())
|
||||||
accountRateMultiplier := account.BillingRateMultiplier()
|
accountRateMultiplier := account.BillingRateMultiplier()
|
||||||
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
requestID := resolveUsageBillingRequestID(ctx, result.RequestID)
|
||||||
|
|
||||||
|
// 确定 RequestedModel(渠道映射前的原始模型)
|
||||||
|
requestedModel := result.Model
|
||||||
|
if input.OriginalModel != "" {
|
||||||
|
requestedModel = input.OriginalModel
|
||||||
|
}
|
||||||
|
|
||||||
usageLog := &UsageLog{
|
usageLog := &UsageLog{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
APIKeyID: apiKey.ID,
|
APIKeyID: apiKey.ID,
|
||||||
AccountID: account.ID,
|
AccountID: account.ID,
|
||||||
RequestID: requestID,
|
RequestID: requestID,
|
||||||
Model: result.Model,
|
Model: result.Model,
|
||||||
RequestedModel: result.Model,
|
RequestedModel: requestedModel,
|
||||||
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
UpstreamModel: optionalNonEqualStringPtr(result.UpstreamModel, result.Model),
|
||||||
ServiceTier: result.ServiceTier,
|
ServiceTier: result.ServiceTier,
|
||||||
ReasoningEffort: result.ReasoningEffort,
|
ReasoningEffort: result.ReasoningEffort,
|
||||||
|
|||||||
Reference in New Issue
Block a user