Merge remote-tracking branch 'origin/alpha' into refactor/model-pricing
@@ -161,7 +161,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 	logInfo.ApiKey = ""
 	common.SysLog(fmt.Sprintf("testing channel %d with model %s , info %+v ", channel.Id, testModel, logInfo))

-	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.MaxTokens))
+	priceData, err := helper.ModelPriceHelper(c, info, 0, int(request.GetMaxTokens()))
 	if err != nil {
 		return testResult{
 			context: c,
@@ -275,7 +275,7 @@ func testChannel(channel *model.Channel, testModel string) testResult {
 		Quota:          quota,
 		Content:        "模型测试",
 		UseTimeSeconds: int(consumedTime),
-		IsStream:       false,
+		IsStream:       info.IsStream,
 		Group:          info.UsingGroup,
 		Other:          other,
 	})
@@ -36,30 +36,11 @@ type OpenAIModel struct {
 	Parent string `json:"parent"`
 }

-type GoogleOpenAICompatibleModels []struct {
-	Name                       string   `json:"name"`
-	Version                    string   `json:"version"`
-	DisplayName                string   `json:"displayName"`
-	Description                string   `json:"description,omitempty"`
-	InputTokenLimit            int      `json:"inputTokenLimit"`
-	OutputTokenLimit           int      `json:"outputTokenLimit"`
-	SupportedGenerationMethods []string `json:"supportedGenerationMethods"`
-	Temperature                float64  `json:"temperature,omitempty"`
-	TopP                       float64  `json:"topP,omitempty"`
-	TopK                       int      `json:"topK,omitempty"`
-	MaxTemperature             int      `json:"maxTemperature,omitempty"`
-}
-
 type OpenAIModelsResponse struct {
 	Data    []OpenAIModel `json:"data"`
 	Success bool          `json:"success"`
 }

-type GoogleOpenAICompatibleResponse struct {
-	Models        []GoogleOpenAICompatibleModels `json:"models"`
-	NextPageToken string                         `json:"nextPageToken"`
-}
-
 func parseStatusFilter(statusParam string) int {
 	switch strings.ToLower(statusParam) {
 	case "enabled", "1":
@@ -203,7 +184,7 @@ func FetchUpstreamModels(c *gin.Context) {
 	switch channel.Type {
 	case constant.ChannelTypeGemini:
 		// curl https://example.com/v1beta/models?key=$GEMINI_API_KEY
-		url = fmt.Sprintf("%s/v1beta/openai/models?key=%s", baseURL, channel.Key)
+		url = fmt.Sprintf("%s/v1beta/openai/models", baseURL) // Remove key in url since we need to use AuthHeader
 	case constant.ChannelTypeAli:
 		url = fmt.Sprintf("%s/compatible-mode/v1/models", baseURL)
 	default:
@@ -212,10 +193,11 @@ func FetchUpstreamModels(c *gin.Context) {

 	// Fetch the response body; whether to add an AuthHeader depends on the channel type
 	var body []byte
+	key := strings.Split(channel.Key, "\n")[0]
 	if channel.Type == constant.ChannelTypeGemini {
-		body, err = GetResponseBody("GET", url, channel, nil) // I don't know why, but Gemini requires no AuthHeader
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key)) // Use AuthHeader since Gemini now forces it
 	} else {
-		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(channel.Key))
+		body, err = GetResponseBody("GET", url, channel, GetAuthHeader(key))
 	}
 	if err != nil {
 		common.ApiError(c, err)
@@ -223,34 +205,12 @@ func FetchUpstreamModels(c *gin.Context) {
 	}

 	var result OpenAIModelsResponse
-	var parseSuccess bool
-
-	// Adapt special response formats
-	switch channel.Type {
-	case constant.ChannelTypeGemini:
-		var googleResult GoogleOpenAICompatibleResponse
-		if err = json.Unmarshal(body, &googleResult); err == nil {
-			// Convert the Google format to the OpenAI format
-			for _, model := range googleResult.Models {
-				for _, gModel := range model {
-					result.Data = append(result.Data, OpenAIModel{
-						ID: gModel.Name,
-					})
-				}
-			}
-			parseSuccess = true
-		}
-	}
-
-	// If parsing failed, try the OpenAI format
-	if !parseSuccess {
-		if err = json.Unmarshal(body, &result); err != nil {
-			c.JSON(http.StatusOK, gin.H{
-				"success": false,
-				"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
-			})
-			return
-		}
+	if err = json.Unmarshal(body, &result); err != nil {
+		c.JSON(http.StatusOK, gin.H{
+			"success": false,
+			"message": fmt.Sprintf("解析响应失败: %s", err.Error()),
+		})
+		return
 	}

 	var ids []string
@@ -361,7 +361,7 @@ type ClaudeUsage struct {
 	CacheCreationInputTokens int                  `json:"cache_creation_input_tokens"`
 	CacheReadInputTokens     int                  `json:"cache_read_input_tokens"`
 	OutputTokens             int                  `json:"output_tokens"`
-	ServerToolUse            *ClaudeServerToolUse `json:"server_tool_use"`
+	ServerToolUse            *ClaudeServerToolUse `json:"server_tool_use,omitempty"`
 }

 type ClaudeServerToolUse struct {
@@ -99,8 +99,11 @@ type StreamOptions struct {
 	IncludeUsage bool `json:"include_usage,omitempty"`
 }

-func (r *GeneralOpenAIRequest) GetMaxTokens() int {
-	return int(r.MaxTokens)
+func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
+	if r.MaxCompletionTokens != 0 {
+		return r.MaxCompletionTokens
+	}
+	return r.MaxTokens
 }

 func (r *GeneralOpenAIRequest) ParseInput() []string {
@@ -3,16 +3,17 @@ package ali
 import (
 	"errors"
 	"fmt"
+	"github.com/gin-gonic/gin"
 	"io"
 	"net/http"
 	"one-api/dto"
 	"one-api/relay/channel"
+	"one-api/relay/channel/claude"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/constant"
 	"one-api/types"
-
-	"github.com/gin-gonic/gin"
+	"strings"
 )

 type Adaptor struct {
@@ -23,10 +24,8 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }

-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	return req, nil
 }

 func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -34,18 +33,24 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {

 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	var fullRequestURL string
-	switch info.RelayMode {
-	case constant.RelayModeEmbeddings:
-		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
-	case constant.RelayModeRerank:
-		fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
-	case constant.RelayModeImagesGenerations:
-		fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
-	case constant.RelayModeCompletions:
-		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
+	switch info.RelayFormat {
+	case relaycommon.RelayFormatClaude:
+		fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
 	default:
-		fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
+		switch info.RelayMode {
+		case constant.RelayModeEmbeddings:
+			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
+		case constant.RelayModeRerank:
+			fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
+		case constant.RelayModeImagesGenerations:
+			fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
+		case constant.RelayModeCompletions:
+			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
+		default:
+			fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
+		}
 	}

 	return fullRequestURL, nil
 }

@@ -65,7 +70,13 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
 	if request == nil {
 		return nil, errors.New("request is nil")
 	}
+	// docs: https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712216
+	// fix: InternalError.Algo.InvalidParameter: The value of the enable_thinking parameter is restricted to True.
+	if strings.Contains(request.Model, "thinking") {
+		request.EnableThinking = true
+		request.Stream = true
+		info.IsStream = true
+	}
 	// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
 	if !info.IsStream {
 		request.EnableThinking = false
@@ -106,18 +117,27 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	switch info.RelayMode {
-	case constant.RelayModeImagesGenerations:
-		err, usage = aliImageHandler(c, resp, info)
-	case constant.RelayModeEmbeddings:
-		err, usage = aliEmbeddingHandler(c, resp)
-	case constant.RelayModeRerank:
-		err, usage = RerankHandler(c, resp, info)
-	default:
+	switch info.RelayFormat {
+	case relaycommon.RelayFormatClaude:
 		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
+			err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
 		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
+			err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
+		}
+	default:
+		switch info.RelayMode {
+		case constant.RelayModeImagesGenerations:
+			err, usage = aliImageHandler(c, resp, info)
+		case constant.RelayModeEmbeddings:
+			err, usage = aliEmbeddingHandler(c, resp)
+		case constant.RelayModeRerank:
+			err, usage = RerankHandler(c, resp, info)
+		default:
+			if info.IsStream {
+				usage, err = openai.OaiStreamHandler(c, info, resp)
+			} else {
+				usage, err = openai.OpenaiHandler(c, info, resp)
+			}
 		}
 	}
 	return
@@ -34,9 +34,9 @@ func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
 		EnableCitation: false,
 		UserId:         request.User,
 	}
-	if request.MaxTokens != 0 {
-		maxTokens := int(request.MaxTokens)
-		if request.MaxTokens == 1 {
+	if request.GetMaxTokens() != 0 {
+		maxTokens := int(request.GetMaxTokens())
+		if request.GetMaxTokens() == 1 {
 			maxTokens = 2
 		}
 		baiduRequest.MaxOutputTokens = &maxTokens
@@ -9,6 +9,7 @@ import (
 	"one-api/relay/channel"
 	"one-api/relay/channel/openai"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/constant"
 	"one-api/types"
 	"strings"

@@ -23,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }

-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	adaptor := openai.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }

 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -43,7 +43,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
 }

 func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
-	return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+	switch info.RelayMode {
+	case constant.RelayModeChatCompletions:
+		return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
+	case constant.RelayModeEmbeddings:
+		return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil
+	case constant.RelayModeImagesGenerations:
+		return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil
+	case constant.RelayModeImagesEdits:
+		return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil
+	case constant.RelayModeRerank:
+		return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil
+	default:
+	}
+	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
 }

 func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -99,11 +112,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	if info.IsStream {
-		usage, err = openai.OaiStreamHandler(c, info, resp)
-	} else {
-		usage, err = openai.OpenaiHandler(c, info, resp)
-	}
+	adaptor := openai.Adaptor{}
+	usage, err = adaptor.DoResponse(c, resp, info)
 	return
 }

@@ -104,7 +104,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	if info.IsStream {
 		err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
 	} else {
-		err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
+		err, usage = ClaudeHandler(c, resp, info, a.RequestMode)
 	}
 	return
 }
@@ -149,7 +149,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla

 	claudeRequest := dto.ClaudeRequest{
 		Model:         textRequest.Model,
-		MaxTokens:     textRequest.MaxTokens,
+		MaxTokens:     textRequest.GetMaxTokens(),
 		StopSequences: nil,
 		Temperature:   textRequest.Temperature,
 		TopP:          textRequest.TopP,
@@ -740,7 +740,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 	return nil
 }

-func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
+func ClaudeHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
 	defer common.CloseResponseBodyGracefully(resp)

 	claudeInfo := &ClaudeResponseInfo{
@@ -5,7 +5,7 @@ import "one-api/dto"
 type CfRequest struct {
 	Messages  []dto.Message `json:"messages,omitempty"`
 	Lora      string        `json:"lora,omitempty"`
-	MaxTokens int           `json:"max_tokens,omitempty"`
+	MaxTokens uint          `json:"max_tokens,omitempty"`
 	Prompt    string        `json:"prompt,omitempty"`
 	Raw       bool          `json:"raw,omitempty"`
 	Stream    bool          `json:"stream,omitempty"`
@@ -7,7 +7,7 @@ type CohereRequest struct {
 	ChatHistory []ChatHistory `json:"chat_history"`
 	Message     string        `json:"message"`
 	Stream      bool          `json:"stream"`
-	MaxTokens   int           `json:"max_tokens"`
+	MaxTokens   uint          `json:"max_tokens"`
 	SafetyMode  string        `json:"safety_mode,omitempty"`
 }

@@ -24,10 +24,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }

-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	adaptor := openai.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }

 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -120,6 +120,9 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 	action := "generateContent"
 	if info.IsStream {
 		action = "streamGenerateContent?alt=sse"
+		if info.RelayMode == constant.RelayModeGemini {
+			info.DisablePing = true
+		}
 	}
 	return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
 }
@@ -193,7 +196,6 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
 	if info.RelayMode == constant.RelayModeGemini {
 		if info.IsStream {
-			info.DisablePing = true
 			return GeminiTextGenerationStreamHandler(c, info, resp)
 		} else {
 			return GeminiTextGenerationHandler(c, info, resp)
@@ -184,7 +184,7 @@ func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon
 		GenerationConfig: dto.GeminiChatGenerationConfig{
 			Temperature:     textRequest.Temperature,
 			TopP:            textRequest.TopP,
-			MaxOutputTokens: textRequest.MaxTokens,
+			MaxOutputTokens: textRequest.GetMaxTokens(),
 			Seed:            int64(textRequest.Seed),
 		},
 	}
@@ -71,7 +71,7 @@ func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAI
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 		Tools:       request.Tools,
 		ToolChoice:  request.ToolChoice,
 	}
@@ -60,7 +60,7 @@ func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, er
 		TopK:             request.TopK,
 		Stop:             Stop,
 		Tools:            request.Tools,
-		MaxTokens:        request.MaxTokens,
+		MaxTokens:        request.GetMaxTokens(),
 		ResponseFormat:   request.ResponseFormat,
 		FrequencyPenalty: request.FrequencyPenalty,
 		PresencePenalty:  request.PresencePenalty,
@@ -18,30 +18,6 @@ import (
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
 // https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body

-func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
-	palmRequest := PaLMChatRequest{
-		Prompt: PaLMPrompt{
-			Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
-		},
-		Temperature:    textRequest.Temperature,
-		CandidateCount: textRequest.N,
-		TopP:           textRequest.TopP,
-		TopK:           textRequest.MaxTokens,
-	}
-	for _, message := range textRequest.Messages {
-		palmMessage := PaLMChatMessage{
-			Content: message.StringContent(),
-		}
-		if message.Role == "user" {
-			palmMessage.Author = "0"
-		} else {
-			palmMessage.Author = "1"
-		}
-		palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
-	}
-	return &palmRequest
-}
-
 func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
 	fullTextResponse := dto.OpenAITextResponse{
 		Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
@@ -16,6 +16,6 @@ func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpen
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 	}
 }
@@ -238,7 +238,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
 	} else {
 		switch a.RequestMode {
 		case RequestModeClaude:
-			err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
+			err, usage = claude.ClaudeHandler(c, resp, info, claude.RequestModeMessage)
 		case RequestModeGemini:
 			if info.RelayMode == constant.RelayModeGemini {
 				usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
@@ -28,10 +28,9 @@ func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dt
 	return nil, errors.New("not implemented")
 }

-func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
-	//TODO implement me
-	panic("implement me")
-	return nil, nil
+func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
+	adaptor := openai.Adaptor{}
+	return adaptor.ConvertClaudeRequest(c, info, req)
 }

 func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -196,6 +195,10 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
 		return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
 	case constant.RelayModeImagesGenerations:
 		return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
+	case constant.RelayModeImagesEdits:
+		return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil
+	case constant.RelayModeRerank:
+		return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil
 	default:
 	}
 	return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
@@ -232,18 +235,8 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 }

 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
-	switch info.RelayMode {
-	case constant.RelayModeChatCompletions:
-		if info.IsStream {
-			usage, err = openai.OaiStreamHandler(c, info, resp)
-		} else {
-			usage, err = openai.OpenaiHandler(c, info, resp)
-		}
-	case constant.RelayModeEmbeddings:
-		usage, err = openai.OpenaiHandler(c, info, resp)
-	case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
-		usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
-	}
+	adaptor := openai.Adaptor{}
+	usage, err = adaptor.DoResponse(c, resp, info)
 	return
 }

@@ -48,7 +48,7 @@ func requestOpenAI2Xunfei(request dto.GeneralOpenAIRequest, xunfeiAppId string,
 	xunfeiRequest.Parameter.Chat.Domain = domain
 	xunfeiRequest.Parameter.Chat.Temperature = request.Temperature
 	xunfeiRequest.Parameter.Chat.TopK = request.N
-	xunfeiRequest.Parameter.Chat.MaxTokens = request.MaxTokens
+	xunfeiRequest.Parameter.Chat.MaxTokens = request.GetMaxTokens()
 	xunfeiRequest.Payload.Message.Text = messages
 	return &xunfeiRequest
 }
@@ -105,7 +105,7 @@ func requestOpenAI2Zhipu(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIReq
 		Messages:    messages,
 		Temperature: request.Temperature,
 		TopP:        request.TopP,
-		MaxTokens:   request.MaxTokens,
+		MaxTokens:   request.GetMaxTokens(),
 		Stop:        Stop,
 		Tools:       request.Tools,
 		ToolChoice:  request.ToolChoice,
@@ -225,6 +225,9 @@ func GenRelayInfo(c *gin.Context) *RelayInfo {
 	userId := common.GetContextKeyInt(c, constant.ContextKeyUserId)
 	tokenUnlimited := common.GetContextKeyBool(c, constant.ContextKeyTokenUnlimited)
 	startTime := common.GetContextKeyTime(c, constant.ContextKeyRequestStartTime)
+	if startTime.IsZero() {
+		startTime = time.Now()
+	}
 	// firstResponseTime = time.Now() - 1 second

 	apiType, _ := common.ChannelType2APIType(channelType)
@@ -283,7 +283,9 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
 	if chosenChoice.FinishReason != nil && *chosenChoice.FinishReason != "" {
 		// should be done
 		info.FinishReason = *chosenChoice.FinishReason
-		return claudeResponses
+		if !info.Done {
+			return claudeResponses
+		}
 	}
 	if info.Done {
 		claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
@@ -432,6 +434,8 @@ func stopReasonOpenAI2Claude(reason string) string {
 		return "end_turn"
 	case "stop_sequence":
 		return "stop_sequence"
+	case "length":
+		fallthrough
 	case "max_tokens":
 		return "max_tokens"
 	case "tool_calls":
@@ -93,6 +93,9 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
 	if showBodyWhenFail {
 		newApiErr.Err = fmt.Errorf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody))
 	} else {
+		if common.DebugEnabled {
+			println(fmt.Sprintf("bad response status code %d, body: %s", resp.StatusCode, string(responseBody)))
+		}
 		newApiErr.Err = fmt.Errorf("bad response status code %d", resp.StatusCode)
 	}
 	return