package relay import ( "bytes" "fmt" "io" "net/http" "strings" "github.com/QuantumNous/new-api/common" "github.com/QuantumNous/new-api/constant" "github.com/QuantumNous/new-api/dto" "github.com/QuantumNous/new-api/logger" relaycommon "github.com/QuantumNous/new-api/relay/common" relayconstant "github.com/QuantumNous/new-api/relay/constant" "github.com/QuantumNous/new-api/relay/helper" "github.com/QuantumNous/new-api/service" "github.com/QuantumNous/new-api/setting/model_setting" "github.com/QuantumNous/new-api/setting/ratio_setting" "github.com/QuantumNous/new-api/types" "github.com/samber/lo" "github.com/gin-gonic/gin" ) func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) { info.InitChannelMeta(c) textReq, ok := info.Request.(*dto.GeneralOpenAIRequest) if !ok { return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } request, err := common.DeepCopy(textReq) if err != nil { return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) } if request.WebSearchOptions != nil { c.Set("chat_completion_web_search_context_size", request.WebSearchOptions.SearchContextSize) } err = helper.ModelMappedHelper(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry()) } includeUsage := true // 判断用户是否需要返回使用情况 if request.StreamOptions != nil { includeUsage = request.StreamOptions.IncludeUsage } // 如果不支持StreamOptions,将StreamOptions设置为nil if !info.SupportStreamOptions || !lo.FromPtrOr(request.Stream, false) { request.StreamOptions = nil } else { // 如果支持StreamOptions,且请求中没有设置StreamOptions,根据配置文件设置StreamOptions if constant.ForceStreamOption { request.StreamOptions = &dto.StreamOptions{ IncludeUsage: true, } } } info.ShouldIncludeUsage = includeUsage adaptor := GetAdaptor(info.ApiType) if adaptor == nil { return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry()) } adaptor.Init(info) passThroughGlobal := model_setting.GetGlobalSettings().PassThroughRequestEnabled if info.RelayMode == relayconstant.RelayModeChatCompletions && !passThroughGlobal && !info.ChannelSetting.PassThroughBodyEnabled && service.ShouldChatCompletionsUseResponsesGlobal(info.ChannelId, info.ChannelType, info.OriginModelName) { applySystemPromptIfNeeded(c, info, request) usage, newApiErr := chatCompletionsViaResponses(c, info, adaptor, request) if newApiErr != nil { return newApiErr } var containAudioTokens = usage.CompletionTokenDetails.AudioTokens > 0 || usage.PromptTokensDetails.AudioTokens > 0 var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName) if containAudioTokens && containsAudioRatios { service.PostAudioConsumeQuota(c, info, usage, "") } else { service.PostTextConsumeQuota(c, info, usage, nil) } return nil } var requestBody io.Reader if passThroughGlobal || info.ChannelSetting.PassThroughBodyEnabled { storage, err := common.GetBodyStorage(c) if err != nil { return types.NewErrorWithStatusCode(err, types.ErrorCodeReadRequestBodyFailed, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) } if common.DebugEnabled { if debugBytes, bErr := storage.Bytes(); bErr == nil { println("requestBody: ", string(debugBytes)) } } requestBody = common.ReaderOnly(storage) } else { convertedRequest, err := adaptor.ConvertOpenAIRequest(c, info, request) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } relaycommon.AppendRequestConversionFromRequest(info, convertedRequest) if info.ChannelSetting.SystemPrompt != "" { // 如果有系统提示,则将其添加到请求中 request, ok := convertedRequest.(*dto.GeneralOpenAIRequest) if ok { containSystemPrompt := false for _, message := range request.Messages { if message.Role == request.GetSystemRoleName() { containSystemPrompt = true break } } if !containSystemPrompt { // 如果没有系统提示,则添加系统提示 systemMessage := dto.Message{ Role: request.GetSystemRoleName(), Content: info.ChannelSetting.SystemPrompt, } request.Messages = append([]dto.Message{systemMessage}, request.Messages...) } else if info.ChannelSetting.SystemPromptOverride { common.SetContextKey(c, constant.ContextKeySystemPromptOverride, true) // 如果有系统提示,且允许覆盖,则拼接到前面 for i, message := range request.Messages { if message.Role == request.GetSystemRoleName() { if message.IsStringContent() { request.Messages[i].SetStringContent(info.ChannelSetting.SystemPrompt + "\n" + message.StringContent()) } else { contents := message.ParseContent() contents = append([]dto.MediaContent{ { Type: dto.ContentTypeText, Text: info.ChannelSetting.SystemPrompt, }, }, contents...) request.Messages[i].Content = contents } break } } } } } jsonData, err := common.Marshal(convertedRequest) if err != nil { return types.NewError(err, types.ErrorCodeJsonMarshalFailed, types.ErrOptionWithSkipRetry()) } // remove disabled fields for OpenAI API jsonData, err = relaycommon.RemoveDisabledFields(jsonData, info.ChannelOtherSettings, info.ChannelSetting.PassThroughBodyEnabled) if err != nil { return types.NewError(err, types.ErrorCodeConvertRequestFailed, types.ErrOptionWithSkipRetry()) } // apply param override if len(info.ParamOverride) > 0 { jsonData, err = relaycommon.ApplyParamOverrideWithRelayInfo(jsonData, info) if err != nil { return newAPIErrorFromParamOverride(err) } } logger.LogDebug(c, fmt.Sprintf("text request body: %s", string(jsonData))) requestBody = bytes.NewBuffer(jsonData) } var httpResp *http.Response resp, err := adaptor.DoRequest(c, info, requestBody) if err != nil { return types.NewOpenAIError(err, types.ErrorCodeDoRequestFailed, http.StatusInternalServerError) } statusCodeMappingStr := c.GetString("status_code_mapping") if resp != nil { httpResp = resp.(*http.Response) info.IsStream = info.IsStream || strings.HasPrefix(httpResp.Header.Get("Content-Type"), "text/event-stream") if httpResp.StatusCode != http.StatusOK { newApiErr := service.RelayErrorHandler(c.Request.Context(), httpResp, false) // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } } usage, newApiErr := adaptor.DoResponse(c, httpResp, info) if newApiErr != nil { // reset status code 重置状态码 service.ResetStatusCode(newApiErr, statusCodeMappingStr) return newApiErr } var containAudioTokens = usage.(*dto.Usage).CompletionTokenDetails.AudioTokens > 0 || usage.(*dto.Usage).PromptTokensDetails.AudioTokens > 0 var containsAudioRatios = ratio_setting.ContainsAudioRatio(info.OriginModelName) || ratio_setting.ContainsAudioCompletionRatio(info.OriginModelName) if containAudioTokens && containsAudioRatios { service.PostAudioConsumeQuota(c, info, usage.(*dto.Usage), "") } else { service.PostTextConsumeQuota(c, info, usage.(*dto.Usage), nil) } return nil }