diff --git a/controller/channel-test.go b/controller/channel-test.go index 3a7c582b..026a863b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { logInfo.ApiKey = "" common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) - priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) + priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens())) if err != nil { return testResult{ context: c, @@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult { Quota: quota, Content: "模型测试", UseTimeSeconds: int(consumedTime), - IsStream: false, + IsStream: info.IsStream, Group: info.UsingGroup, Other: other, }) diff --git a/controller/channel.go b/controller/channel.go index 9f46ca35..020a3327 100644 --- a/controller/channel.go +++ b/controller/channel.go @@ -36,30 +36,11 @@ type OpenAIModel struct { Parent string `json:"parent"` } -type GoogleOpenAICompatibleModels []struct { - Name string `json:"name"` - Version string `json:"version"` - DisplayName string `json:"displayName"` - Description string `json:"description,omitempty"` - InputTokenLimit int `json:"inputTokenLimit"` - OutputTokenLimit int `json:"outputTokenLimit"` - SupportedGenerationMethods []string `json:"supportedGenerationMethods"` - Temperature float64 `json:"temperature,omitempty"` - TopP float64 `json:"topP,omitempty"` - TopK int `json:"topK,omitempty"` - MaxTemperature int `json:"maxTemperature,omitempty"` -} - type OpenAIModelsResponse struct { Data []OpenAIModel `json:"data"` Success bool `json:"success"` } -type GoogleOpenAICompatibleResponse struct { - Models []GoogleOpenAICompatibleModels `json:"models"` - NextPageToken string `json:"nextPageToken"` -} - func parseStatusFilter(statusParam string) int { switch strings.ToLower(statusParam) { case "enabled", "1": @@ -203,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) { switch channel.Type { case constant.ChannelTypeGemini: // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY - url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key) + url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader case constant.ChannelTypeAli: url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) default: @@ -212,10 +193,11 @@ func FetchUpstreamModels(c *gin.Context) { // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader var body []byte + key := strings.Split(channel.Key, "\n")[0] if channel.Type == constant.ChannelTypeGemini { - body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it } else { - body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) + body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) } if err != nil { common.ApiError(c, err) @@ -223,34 +205,12 @@ func FetchUpstreamModels(c *gin.Context) { } var result OpenAIModelsResponse - var parseSuccess bool - - // 适配特殊格式 - switch channel.Type { - case constant.ChannelTypeGemini: - var googleResult GoogleOpenAICompatibleResponse - if err = json.Unmarshal(body, &googleResult); err == nil { - // 转换Google格式到OpenAI格式 - for _, model := range googleResult.Models { - for _, gModel := range model { - result.Data = append(result.Data, OpenAIModel{ - ID: gModel.Name, - }) - } - } - parseSuccess = true - } - } - - // 如果解析失败,尝试OpenAI格式 - if !parseSuccess { - if err = json.Unmarshal(body, &result); err != nil { - c.JSON(http.StatusOK, gin.H{ - "success": false, - "message": fmt.Sprintf("解析响应失败: %s", err.Error()), - }) - return - } + if err = json.Unmarshal(body, &result); err != nil { + c.JSON(http.StatusOK, gin.H{ + "success": false, + "message": fmt.Sprintf("解析响应失败: %s", err.Error()), + }) + return } var ids []string diff --git a/dto/claude.go b/dto/claude.go index ea099df4..7b5f348e 100644 --- a/dto/claude.go +++ b/dto/claude.go @@ -361,7 +361,7 @@ type ClaudeUsage struct { CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"` OutputTokens int `json:"output_tokens"` - ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"` + ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"` } type ClaudeServerToolUse struct { diff --git a/dto/openai_request.go b/dto/openai_request.go index 29076ef6..fcd47d07 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -99,8 +99,11 @@ type StreamOptions struct { IncludeUsage bool `json:"include_usage,omitempty"` } -func (r *GeneralOpenAIRequest) GetMaxTokens() int { - return int(r.MaxTokens) +func (r *GeneralOpenAIRequest) GetMaxTokens() uint { + if r.MaxCompletionTokens != 0 { + return r.MaxCompletionTokens + } + return r.MaxTokens } func (r *GeneralOpenAIRequest) ParseInput() []string { diff --git a/relay/channel/ali/adaptor.go b/relay/channel/ali/adaptor.go index 067fac37..f3c5cee6 100644 --- a/relay/channel/ali/adaptor.go +++ b/relay/channel/ali/adaptor.go @@ -3,16 +3,17 @@ package ali import ( "errors" "fmt" + "github.com/gin-gonic/gin" "io" "net/http" "one-api/dto" "one-api/relay/channel" + "one-api/relay/channel/claude" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" "one-api/relay/constant" "one-api/types" - - "github.com/gin-gonic/gin" + "strings" ) type Adaptor struct { @@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + return req, nil } func (a *Adaptor) Init(info *relaycommon.RelayInfo) { @@ -34,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { var fullRequestURL string - switch info.RelayMode { - case constant.RelayModeEmbeddings: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) - case constant.RelayModeRerank: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) - case constant.RelayModeImagesGenerations: - fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) - case constant.RelayModeCompletions: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + switch info.RelayFormat { + case relaycommon.RelayFormatClaude: + fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl) default: - fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + switch info.RelayMode { + case constant.RelayModeEmbeddings: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) + case constant.RelayModeRerank: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl) + case constant.RelayModeImagesGenerations: + fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl) + case constant.RelayModeCompletions: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl) + default: + fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) + } } + return fullRequestURL, nil } @@ -65,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - + // docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216 + // fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True. + if strings.Contains(request.Model, "thinking") { + request.EnableThinking = true + request.Stream = true + info.IsStream = true + } // fix: ali parameter.enable_thinking must be set to false for non-streaming calls if !info.IsStream { request.EnableThinking = false @@ -106,18 +117,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - switch info.RelayMode { - case constant.RelayModeImagesGenerations: - err, usage = aliImageHandler(c, resp, info) - case constant.RelayModeEmbeddings: - err, usage = aliEmbeddingHandler(c, resp) - case constant.RelayModeRerank: - err, usage = RerankHandler(c, resp, info) - default: + switch info.RelayFormat { + case relaycommon.RelayFormatClaude: if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) + err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage) } else { - usage, err = openai.OpenaiHandler(c, info, resp) + err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) + } + default: + switch info.RelayMode { + case constant.RelayModeImagesGenerations: + err, usage = aliImageHandler(c, resp, info) + case constant.RelayModeEmbeddings: + err, usage = aliEmbeddingHandler(c, resp) + case constant.RelayModeRerank: + err, usage = RerankHandler(c, resp, info) + default: + if info.IsStream { + usage, err = openai.OaiStreamHandler(c, info, resp) + } else { + usage, err = openai.OpenaiHandler(c, info, resp) + } } } return diff --git a/relay/channel/baidu/relay-baidu.go b/relay/channel/baidu/relay-baidu.go index 06b48c20..a7cd5996 100644 --- a/relay/channel/baidu/relay-baidu.go +++ b/relay/channel/baidu/relay-baidu.go @@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest { EnableCitation: false, UserId: request.User, } - if request.MaxTokens != 0 { - maxTokens := int(request.MaxTokens) - if request.MaxTokens == 1 { + if request.GetMaxTokens() != 0 { + maxTokens := int(request.GetMaxTokens()) + if request.GetMaxTokens() == 1 { maxTokens = 2 } baiduRequest.MaxOutputTokens = &maxTokens diff --git a/relay/channel/baidu_v2/adaptor.go b/relay/channel/baidu_v2/adaptor.go index b8a4ac2f..ba59e307 100644 --- a/relay/channel/baidu_v2/adaptor.go +++ b/relay/channel/baidu_v2/adaptor.go @@ -9,6 +9,7 @@ import ( "one-api/relay/channel" "one-api/relay/channel/openai" relaycommon "one-api/relay/common" + "one-api/relay/constant" "one-api/types" "strings" @@ -23,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -43,7 +43,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) { } func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { - return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + switch info.RelayMode { + case constant.RelayModeChatCompletions: + return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil + case constant.RelayModeEmbeddings: + return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil + case constant.RelayModeImagesGenerations: + return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil + case constant.RelayModeRerank: + return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil + default: + } + return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) } func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { @@ -99,11 +112,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) - } + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) return } diff --git a/relay/channel/claude/adaptor.go b/relay/channel/claude/adaptor.go index 0f7a9414..39b8ce2f 100644 --- a/relay/channel/claude/adaptor.go +++ b/relay/channel/claude/adaptor.go @@ -104,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom if info.IsStream { err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) } else { - err, usage = ClaudeHandler(c, resp, a.RequestMode, info) + err, usage = ClaudeHandler(c, resp, info, a.RequestMode) } return } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 64739aa9..e4d3975e 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla claudeRequest := dto.ClaudeRequest{ Model: textRequest.Model, - MaxTokens: textRequest.MaxTokens, + MaxTokens: textRequest.GetMaxTokens(), StopSequences: nil, Temperature: textRequest.Temperature, TopP: textRequest.TopP, @@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud return nil } -func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { +func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) { defer common.CloseResponseBodyGracefully(resp) claudeInfo := &ClaudeResponseInfo{ diff --git a/relay/channel/cloudflare/dto.go b/relay/channel/cloudflare/dto.go index 62a45c40..72b40615 100644 --- a/relay/channel/cloudflare/dto.go +++ b/relay/channel/cloudflare/dto.go @@ -5,7 +5,7 @@ import "one-api/dto" type CfRequest struct { Messages []dto.Message `json:"messages,omitempty"` Lora string `json:"lora,omitempty"` - MaxTokens int `json:"max_tokens,omitempty"` + MaxTokens uint `json:"max_tokens,omitempty"` Prompt string `json:"prompt,omitempty"` Raw bool `json:"raw,omitempty"` Stream bool `json:"stream,omitempty"` diff --git a/relay/channel/cohere/dto.go b/relay/channel/cohere/dto.go index 410540c0..d5127963 100644 --- a/relay/channel/cohere/dto.go +++ b/relay/channel/cohere/dto.go @@ -7,7 +7,7 @@ type CohereRequest struct { ChatHistory []ChatHistory `json:"chat_history"` Message string `json:"message"` Stream bool `json:"stream"` - MaxTokens int `json:"max_tokens"` + MaxTokens uint `json:"max_tokens"` SafetyMode string `json:"safety_mode,omitempty"` } diff --git a/relay/channel/deepseek/adaptor.go b/relay/channel/deepseek/adaptor.go index ac8ea18f..be8de0c8 100644 --- a/relay/channel/deepseek/adaptor.go +++ b/relay/channel/deepseek/adaptor.go @@ -24,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { diff --git a/relay/channel/gemini/adaptor.go b/relay/channel/gemini/adaptor.go index 14fd278d..01dfea2c 100644 --- a/relay/channel/gemini/adaptor.go +++ b/relay/channel/gemini/adaptor.go @@ -120,6 +120,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { action := "generateContent" if info.IsStream { action = "streamGenerateContent?alt=sse" + if info.RelayMode == constant.RelayModeGemini { + info.DisablePing = true + } } return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil } @@ -193,7 +196,6 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { if info.RelayMode == constant.RelayModeGemini { if info.IsStream { - info.DisablePing = true return GeminiTextGenerationStreamHandler(c, info, resp) } else { return GeminiTextGenerationHandler(c, info, resp) diff --git a/relay/channel/gemini/relay-gemini.go b/relay/channel/gemini/relay-gemini.go index 18524afb..698a972c 100644 --- a/relay/channel/gemini/relay-gemini.go +++ b/relay/channel/gemini/relay-gemini.go @@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon GenerationConfig: dto.GeminiChatGenerationConfig{ Temperature: textRequest.Temperature, TopP: textRequest.TopP, - MaxOutputTokens: textRequest.MaxTokens, + MaxOutputTokens: textRequest.GetMaxTokens(), Seed: int64(textRequest.Seed), }, } diff --git a/relay/channel/mistral/text.go b/relay/channel/mistral/text.go index e26c6101..aa925781 100644 --- a/relay/channel/mistral/text.go +++ b/relay/channel/mistral/text.go @@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Tools: request.Tools, ToolChoice: request.ToolChoice, } diff --git a/relay/channel/ollama/relay-ollama.go b/relay/channel/ollama/relay-ollama.go index f98dfc73..d4686ce3 100644 --- a/relay/channel/ollama/relay-ollama.go +++ b/relay/channel/ollama/relay-ollama.go @@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er TopK: request.TopK, Stop: Stop, Tools: request.Tools, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), ResponseFormat: request.ResponseFormat, FrequencyPenalty: request.FrequencyPenalty, PresencePenalty: request.PresencePenalty, diff --git a/relay/channel/palm/relay-palm.go b/relay/channel/palm/relay-palm.go index cbd60f5e..9b8bce7d 100644 --- a/relay/channel/palm/relay-palm.go +++ b/relay/channel/palm/relay-palm.go @@ -18,30 +18,6 @@ import ( // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body -func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest { - palmRequest := PaLMChatRequest{ - Prompt: PaLMPrompt{ - Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)), - }, - Temperature: textRequest.Temperature, - CandidateCount: textRequest.N, - TopP: textRequest.TopP, - TopK: textRequest.MaxTokens, - } - for _, message := range textRequest.Messages { - palmMessage := PaLMChatMessage{ - Content: message.StringContent(), - } - if message.Role == "user" { - palmMessage.Author = "0" - } else { - palmMessage.Author = "1" - } - palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage) - } - return &palmRequest -} - func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { fullTextResponse := dto.OpenAITextResponse{ Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), diff --git a/relay/channel/perplexity/relay-perplexity.go b/relay/channel/perplexity/relay-perplexity.go index 9772aead..7ebadd0f 100644 --- a/relay/channel/perplexity/relay-perplexity.go +++ b/relay/channel/perplexity/relay-perplexity.go @@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), } } diff --git a/relay/channel/vertex/adaptor.go b/relay/channel/vertex/adaptor.go index 4648a384..35e4490b 100644 --- a/relay/channel/vertex/adaptor.go +++ b/relay/channel/vertex/adaptor.go @@ -238,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom } else { switch a.RequestMode { case RequestModeClaude: - err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) + err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage) case RequestModeGemini: if info.RelayMode == constant.RelayModeGemini { usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) diff --git a/relay/channel/volcengine/adaptor.go b/relay/channel/volcengine/adaptor.go index 225b3895..2cc4f663 100644 --- a/relay/channel/volcengine/adaptor.go +++ b/relay/channel/volcengine/adaptor.go @@ -28,10 +28,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt return nil, errors.New("not implemented") } -func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { - //TODO implement me - panic("implement me") - return nil, nil +func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) { + adaptor := openai.Adaptor{} + return adaptor.ConvertClaudeRequest(c, info, req) } func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { @@ -196,6 +195,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil case constant.RelayModeImagesGenerations: return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil + case constant.RelayModeImagesEdits: + return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil + case constant.RelayModeRerank: + return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil default: } return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) @@ -232,18 +235,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request } func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { - switch info.RelayMode { - case constant.RelayModeChatCompletions: - if info.IsStream { - usage, err = openai.OaiStreamHandler(c, info, resp) - } else { - usage, err = openai.OpenaiHandler(c, info, resp) - } - case constant.RelayModeEmbeddings: - usage, err = openai.OpenaiHandler(c, info, resp) - case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits: - usage, err = openai.OpenaiHandlerWithUsage(c, info, resp) - } + adaptor := openai.Adaptor{} + usage, err = adaptor.DoResponse(c, resp, info) return } diff --git a/relay/channel/xunfei/relay-xunfei.go b/relay/channel/xunfei/relay-xunfei.go index 373ad605..1a426d50 100644 --- a/relay/channel/xunfei/relay-xunfei.go +++ b/relay/channel/xunfei/relay-xunfei.go @@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string, xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.TopK = request.N - xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens + xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens() xunfeiRequest.Payload.Message.Text = messages return &xunfeiRequest } diff --git a/relay/channel/zhipu_4v/relay-zhipu_v4.go b/relay/channel/zhipu_4v/relay-zhipu_v4.go index 271dda8f..98a852f5 100644 --- a/relay/channel/zhipu_4v/relay-zhipu_v4.go +++ b/relay/channel/zhipu_4v/relay-zhipu_v4.go @@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq Messages: messages, Temperature: request.Temperature, TopP: request.TopP, - MaxTokens: request.MaxTokens, + MaxTokens: request.GetMaxTokens(), Stop: Stop, Tools: request.Tools, ToolChoice: request.ToolChoice, diff --git a/relay/common/relay_info.go b/relay/common/relay_info.go index 266486c4..743070ca 100644 --- a/relay/common/relay_info.go +++ b/relay/common/relay_info.go @@ -225,6 +225,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo { userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) + if startTime.IsZero() { + startTime = time.Now() + } // firstResponseTime = time.Now() - 1 second apiType, _ := common.ChannelType2APIType(channelType) diff --git a/service/convert.go b/service/convert.go index ee8ecee5..967e4682 100644 --- a/service/convert.go +++ b/service/convert.go @@ -283,7 +283,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { // should be done info.FinishReason = *chosenChoice.FinishReason - return claudeResponses + if !info.Done { + return claudeResponses + } } if info.Done { claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) @@ -432,6 +434,8 @@ func stopReasonOpenAI2Claude(reason string) string { return "end_turn" case "stop_sequence": return "stop_sequence" + case "length": + fallthrough case "max_tokens": return "max_tokens" case "tool_calls": diff --git a/service/error.go b/service/error.go index ad28c90f..9672402d 100644 --- a/service/error.go +++ b/service/error.go @@ -93,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t if showBodyWhenFail { newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) } else { + if common.DebugEnabled { + println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))) + } newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) } return