Merge remote-tracking branch 'origin/alpha' into refactor/model-pricing

This commit is contained in:
t0ng7u
2025-08-08 02:34:44 +08:00
25 changed files with 126 additions and 153 deletions

View File

@@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
logInfo.ApiKey = "" logInfo.ApiKey = ""
common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo)) common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))
priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens)) priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
if err != nil { if err != nil {
return testResult{ return testResult{
context: c, context: c,
@@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
Quota: quota, Quota: quota,
Content: "模型测试", Content: "模型测试",
UseTimeSeconds: int(consumedTime), UseTimeSeconds: int(consumedTime),
IsStream: false, IsStream: info.IsStream,
Group: info.UsingGroup, Group: info.UsingGroup,
Other: other, Other: other,
}) })

View File

@@ -36,30 +36,11 @@ type OpenAIModel struct {
Parent string `json:"parent"` Parent string `json:"parent"`
} }
type GoogleOpenAICompatibleModels []struct {
Name string `json:"name"`
Version string `json:"version"`
DisplayName string `json:"displayName"`
Description string `json:"description,omitempty"`
InputTokenLimit int `json:"inputTokenLimit"`
OutputTokenLimit int `json:"outputTokenLimit"`
SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
MaxTemperature int `json:"maxTemperature,omitempty"`
}
type OpenAIModelsResponse struct { type OpenAIModelsResponse struct {
Data []OpenAIModel `json:"data"` Data []OpenAIModel `json:"data"`
Success bool `json:"success"` Success bool `json:"success"`
} }
type GoogleOpenAICompatibleResponse struct {
Models []GoogleOpenAICompatibleModels `json:"models"`
NextPageToken string `json:"nextPageToken"`
}
func parseStatusFilter(statusParam string) int { func parseStatusFilter(statusParam string) int {
switch strings.ToLower(statusParam) { switch strings.ToLower(statusParam) {
case "enabled", "1": case "enabled", "1":
@@ -203,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) {
switch channel.Type { switch channel.Type {
case constant.ChannelTypeGemini: case constant.ChannelTypeGemini:
// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY // curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key) url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
case constant.ChannelTypeAli: case constant.ChannelTypeAli:
url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL) url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
default: default:
@@ -212,10 +193,11 @@ func FetchUpstreamModels(c *gin.Context) {
// 获取响应体 - 根据渠道类型决定是否添加 AuthHeader // 获取响应体 - 根据渠道类型决定是否添加 AuthHeader
var body []byte var body []byte
key := strings.Split(channel.Key, "\n")[0]
if channel.Type == constant.ChannelTypeGemini { if channel.Type == constant.ChannelTypeGemini {
body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
} else { } else {
body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key)) body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
} }
if err != nil { if err != nil {
common.ApiError(c, err) common.ApiError(c, err)
@@ -223,34 +205,12 @@ func FetchUpstreamModels(c *gin.Context) {
} }
var result OpenAIModelsResponse var result OpenAIModelsResponse
var parseSuccess bool if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
// 适配特殊格式 "success": false,
switch channel.Type { "message": fmt.Sprintf("解析响应失败: %s", err.Error()),
case constant.ChannelTypeGemini: })
var googleResult GoogleOpenAICompatibleResponse return
if err = json.Unmarshal(body, &googleResult); err == nil {
// 转换Google格式到OpenAI格式
for _, model := range googleResult.Models {
for _, gModel := range model {
result.Data = append(result.Data, OpenAIModel{
ID: gModel.Name,
})
}
}
parseSuccess = true
}
}
// 如果解析失败尝试OpenAI格式
if !parseSuccess {
if err = json.Unmarshal(body, &result); err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
})
return
}
} }
var ids []string var ids []string

View File

@@ -361,7 +361,7 @@ type ClaudeUsage struct {
CacheCreationInputTokens int `json:"cache_creation_input_tokens"` CacheCreationInputTokens int `json:"cache_creation_input_tokens"`
CacheReadInputTokens int `json:"cache_read_input_tokens"` CacheReadInputTokens int `json:"cache_read_input_tokens"`
OutputTokens int `json:"output_tokens"` OutputTokens int `json:"output_tokens"`
ServerToolUse *ClaudeServerToolUse `json:"server_tool_use"` ServerToolUse *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
} }
type ClaudeServerToolUse struct { type ClaudeServerToolUse struct {

View File

@@ -99,8 +99,11 @@ type StreamOptions struct {
IncludeUsage bool `json:"include_usage,omitempty"` IncludeUsage bool `json:"include_usage,omitempty"`
} }
func (r *GeneralOpenAIRequest) GetMaxTokens() int { func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
return int(r.MaxTokens) if r.MaxCompletionTokens != 0 {
return r.MaxCompletionTokens
}
return r.MaxTokens
} }
func (r *GeneralOpenAIRequest) ParseInput() []string { func (r *GeneralOpenAIRequest) ParseInput() []string {

View File

@@ -3,16 +3,17 @@ package ali
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin"
"io" "io"
"net/http" "net/http"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant" "one-api/relay/constant"
"one-api/types" "one-api/types"
"strings"
"github.com/gin-gonic/gin"
) )
type Adaptor struct { type Adaptor struct {
@@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
//TODO implement me return req, nil
panic("implement me")
return nil, nil
} }
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -34,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string var fullRequestURL string
switch info.RelayMode { switch info.RelayFormat {
case constant.RelayModeEmbeddings: case relaycommon.RelayFormatClaude:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl) fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
default: default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl) switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
}
} }
return fullRequestURL, nil return fullRequestURL, nil
} }
@@ -65,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
// docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
// fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
if strings.Contains(request.Model, "thinking") {
request.EnableThinking = true
request.Stream = true
info.IsStream = true
}
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls // fix: ali parameter.enable_thinking must be set to false for non-streaming calls
if !info.IsStream { if !info.IsStream {
request.EnableThinking = false request.EnableThinking = false
@@ -106,18 +117,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode { switch info.RelayFormat {
case constant.RelayModeImagesGenerations: case relaycommon.RelayFormatClaude:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default:
if info.IsStream { if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp) err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else { } else {
usage, err = openai.OpenaiHandler(c, info, resp) err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
}
default:
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
} }
} }
return return

View File

@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
EnableCitation: false, EnableCitation: false,
UserId: request.User, UserId: request.User,
} }
if request.MaxTokens != 0 { if request.GetMaxTokens() != 0 {
maxTokens := int(request.MaxTokens) maxTokens := int(request.GetMaxTokens())
if request.MaxTokens == 1 { if request.GetMaxTokens() == 1 {
maxTokens = 2 maxTokens = 2
} }
baiduRequest.MaxOutputTokens = &maxTokens baiduRequest.MaxOutputTokens = &maxTokens

View File

@@ -9,6 +9,7 @@ import (
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types" "one-api/types"
"strings" "strings"
@@ -23,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
//TODO implement me adaptor := openai.Adaptor{}
panic("implement me") return adaptor.ConvertClaudeRequest(c, info, req)
return nil, nil
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -43,7 +43,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
} }
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -99,11 +112,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream { adaptor := openai.Adaptor{}
usage, err = openai.OaiStreamHandler(c, info, resp) usage, err = adaptor.DoResponse(c, resp, info)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return return
} }

View File

@@ -104,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.IsStream { if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode) err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else { } else {
err, usage = ClaudeHandler(c, resp, a.RequestMode, info) err, usage = ClaudeHandler(c, resp, info, a.RequestMode)
} }
return return
} }

View File

@@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
claudeRequest := dto.ClaudeRequest{ claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model, Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens, MaxTokens: textRequest.GetMaxTokens(),
StopSequences: nil, StopSequences: nil,
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
TopP: textRequest.TopP, TopP: textRequest.TopP,
@@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
return nil return nil
} }
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp) defer common.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{ claudeInfo := &ClaudeResponseInfo{

View File

@@ -5,7 +5,7 @@ import "one-api/dto"
type CfRequest struct { type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"` Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"` Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"` MaxTokens uint `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"` Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"` Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"` Stream bool `json:"stream,omitempty"`

View File

@@ -7,7 +7,7 @@ type CohereRequest struct {
ChatHistory []ChatHistory `json:"chat_history"` ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"` Message string `json:"message"`
Stream bool `json:"stream"` Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"` MaxTokens uint `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"` SafetyMode string `json:"safety_mode,omitempty"`
} }

View File

@@ -24,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
//TODO implement me adaptor := openai.Adaptor{}
panic("implement me") return adaptor.ConvertClaudeRequest(c, info, req)
return nil, nil
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {

View File

@@ -120,6 +120,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
action := "generateContent" action := "generateContent"
if info.IsStream { if info.IsStream {
action = "streamGenerateContent?alt=sse" action = "streamGenerateContent?alt=sse"
if info.RelayMode == constant.RelayModeGemini {
info.DisablePing = true
}
} }
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
} }
@@ -193,7 +196,6 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini { if info.RelayMode == constant.RelayModeGemini {
if info.IsStream { if info.IsStream {
info.DisablePing = true
return GeminiTextGenerationStreamHandler(c, info, resp) return GeminiTextGenerationStreamHandler(c, info, resp)
} else { } else {
return GeminiTextGenerationHandler(c, info, resp) return GeminiTextGenerationHandler(c, info, resp)

View File

@@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
GenerationConfig: dto.GeminiChatGenerationConfig{ GenerationConfig: dto.GeminiChatGenerationConfig{
Temperature: textRequest.Temperature, Temperature: textRequest.Temperature,
TopP: textRequest.TopP, TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens, MaxOutputTokens: textRequest.GetMaxTokens(),
Seed: int64(textRequest.Seed), Seed: int64(textRequest.Seed),
}, },
} }

View File

@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
Messages: messages, Messages: messages,
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
MaxTokens: request.MaxTokens, MaxTokens: request.GetMaxTokens(),
Tools: request.Tools, Tools: request.Tools,
ToolChoice: request.ToolChoice, ToolChoice: request.ToolChoice,
} }

View File

@@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
TopK: request.TopK, TopK: request.TopK,
Stop: Stop, Stop: Stop,
Tools: request.Tools, Tools: request.Tools,
MaxTokens: request.MaxTokens, MaxTokens: request.GetMaxTokens(),
ResponseFormat: request.ResponseFormat, ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty, FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty, PresencePenalty: request.PresencePenalty,

View File

@@ -18,30 +18,6 @@ import (
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
TopP: textRequest.TopP,
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
} else {
palmMessage.Author = "1"
}
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
}
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse { func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{ fullTextResponse := dto.OpenAITextResponse{
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)), Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),

View File

@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
Messages: messages, Messages: messages,
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
MaxTokens: request.MaxTokens, MaxTokens: request.GetMaxTokens(),
} }
} }

View File

@@ -238,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
} else { } else {
switch a.RequestMode { switch a.RequestMode {
case RequestModeClaude: case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info) err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini: case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini { if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)

View File

@@ -28,10 +28,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
return nil, errors.New("not implemented") return nil, errors.New("not implemented")
} }
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) { func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
//TODO implement me adaptor := openai.Adaptor{}
panic("implement me") return adaptor.ConvertClaudeRequest(c, info, req)
return nil, nil
} }
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -196,6 +195,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
case constant.RelayModeImagesGenerations: case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil
default: default:
} }
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode) return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -232,18 +235,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
} }
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode { adaptor := openai.Adaptor{}
case constant.RelayModeChatCompletions: usage, err = adaptor.DoResponse(c, resp, info)
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
}
return return
} }

View File

@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
xunfeiRequest.Parameter.Chat.Domain = domain xunfeiRequest.Parameter.Chat.Domain = domain
xunfeiRequest.Parameter.Chat.Temperature = request.Temperature xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
xunfeiRequest.Parameter.Chat.TopK = request.N xunfeiRequest.Parameter.Chat.TopK = request.N
xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
xunfeiRequest.Payload.Message.Text = messages xunfeiRequest.Payload.Message.Text = messages
return &xunfeiRequest return &xunfeiRequest
} }

View File

@@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
Messages: messages, Messages: messages,
Temperature: request.Temperature, Temperature: request.Temperature,
TopP: request.TopP, TopP: request.TopP,
MaxTokens: request.MaxTokens, MaxTokens: request.GetMaxTokens(),
Stop: Stop, Stop: Stop,
Tools: request.Tools, Tools: request.Tools,
ToolChoice: request.ToolChoice, ToolChoice: request.ToolChoice,

View File

@@ -225,6 +225,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
userId := common.GetContextKeyInt(c, constant.ContextKeyUserId) userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited) tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime) startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
if startTime.IsZero() {
startTime = time.Now()
}
// firstResponseTime = time.Now() - 1 second // firstResponseTime = time.Now() - 1 second
apiType, _ := common.ChannelType2APIType(channelType) apiType, _ := common.ChannelType2APIType(channelType)

View File

@@ -283,7 +283,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" { if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
// should be done // should be done
info.FinishReason = *chosenChoice.FinishReason info.FinishReason = *chosenChoice.FinishReason
return claudeResponses if !info.Done {
return claudeResponses
}
} }
if info.Done { if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index)) claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
@@ -432,6 +434,8 @@ func stopReasonOpenAI2Claude(reason string) string {
return "end_turn" return "end_turn"
case "stop_sequence": case "stop_sequence":
return "stop_sequence" return "stop_sequence"
case "length":
fallthrough
case "max_tokens": case "max_tokens":
return "max_tokens" return "max_tokens"
case "tool_calls": case "tool_calls":

View File

@@ -93,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
if showBodyWhenFail { if showBodyWhenFail {
newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)) newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
} else { } else {
if common.DebugEnabled {
println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
}
newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode) newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
} }
return return