Merge branch 'alpha' into refactor/model-pricing

This commit is contained in:
t0ng7u
2025-07-24 01:35:59 +08:00
13 changed files with 149 additions and 85 deletions

View File

@@ -56,7 +56,7 @@ func relayHandler(c *gin.Context, relayMode int) *types.NewAPIError {
userGroup := c.GetString("group")
channelId := c.GetInt("channel_id")
other := make(map[string]interface{})
other["error_type"] = err.ErrorType
other["error_type"] = err.GetErrorType()
other["error_code"] = err.GetErrorCode()
other["status_code"] = err.StatusCode
other["channel_id"] = channelId
@@ -259,10 +259,10 @@ func getChannel(c *gin.Context, group, originalModel string, retryCount int) (*m
}
channel, selectGroup, err := model.CacheGetRandomSatisfiedChannel(c, group, originalModel, retryCount)
if err != nil {
if group == "auto" {
return nil, types.NewError(errors.New(fmt.Sprintf("获取自动分组下模型 %s 的可用渠道失败: %s", originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
}
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
return nil, types.NewError(errors.New(fmt.Sprintf("获取分组 %s 下模型 %s 的可用渠道失败retry: %s", selectGroup, originalModel, err.Error())), types.ErrorCodeGetChannelFailed)
}
if channel == nil {
return nil, types.NewError(errors.New(fmt.Sprintf("分组 %s 下模型 %s 的可用渠道不存在数据库一致性已被破坏retry", selectGroup, originalModel)), types.ErrorCodeGetChannelFailed)
}
newAPIError := middleware.SetupContextForSelectedChannel(c, channel, originalModel)
if newAPIError != nil {

View File

@@ -100,6 +100,10 @@ func Distribute() func(c *gin.Context) {
}
if shouldSelectChannel {
if modelRequest.Model == "" {
abortWithOpenAiMessage(c, http.StatusBadRequest, "未指定模型名称,模型名称不能为空")
return
}
var selectGroup string
channel, selectGroup, err = model.CacheGetRandomSatisfiedChannel(c, userGroup, modelRequest.Model, 0)
if err != nil {
@@ -107,7 +111,7 @@ func Distribute() func(c *gin.Context) {
if userGroup == "auto" {
showGroup = fmt.Sprintf("auto(%s)", selectGroup)
}
message := fmt.Sprintf("当前分组 %s 下对于模型 %s 可用渠道", showGroup, modelRequest.Model)
message := fmt.Sprintf("获取分组 %s 下模型 %s 可用渠道失败distributor: %s", showGroup, modelRequest.Model, err.Error())
// 如果错误,但是渠道不为空,说明是数据库一致性问题
if channel != nil {
common.SysError(fmt.Sprintf("渠道不存在:%d", channel.Id))
@@ -118,7 +122,7 @@ func Distribute() func(c *gin.Context) {
return
}
if channel == nil {
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("当前分组 %s 下对于模型 %s 可用渠道(数据库一致性已被破坏)", userGroup, modelRequest.Model))
abortWithOpenAiMessage(c, http.StatusServiceUnavailable, fmt.Sprintf("分组 %s 下模型 %s 可用渠道不存在(数据库一致性已被破坏distributor", userGroup, modelRequest.Model))
return
}
}

View File

@@ -109,9 +109,6 @@ func CacheGetRandomSatisfiedChannel(c *gin.Context, group string, model string,
return nil, group, err
}
}
if channel == nil {
return nil, group, errors.New("channel not found")
}
return channel, selectGroup, nil
}

View File

@@ -17,10 +17,16 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
openaiAdaptor := openai.Adaptor{}
openaiRequest, err := openaiAdaptor.ConvertClaudeRequest(c, info, request)
if err != nil {
return nil, err
}
openaiRequest.(*dto.GeneralOpenAIRequest).StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
return requestOpenAI2Ollama(openaiRequest.(*dto.GeneralOpenAIRequest))
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
@@ -37,6 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return info.BaseUrl + "/v1/chat/completions", nil
}
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embed", nil
@@ -55,7 +64,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Ollama(*request)
return requestOpenAI2Ollama(request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
@@ -76,11 +85,12 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
usage, err = ollamaEmbeddingHandler(c, info, resp)
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
usage, err = ollamaEmbeddingHandler(c, info, resp)
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/gin-gonic/gin"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
func requestOpenAI2Ollama(request *dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
@@ -92,15 +92,15 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if ollamaEmbeddingResponse.Error != "" {
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
@@ -121,7 +121,7 @@ func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
}
doResponseBody, err := common.Marshal(embeddingResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil

View File

@@ -27,7 +27,7 @@ func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
if err := common.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
@@ -174,7 +174,7 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
@@ -183,7 +183,7 @@ func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
_ = helper.ClaudeData(c, *resp)
}
}
}

View File

@@ -145,8 +145,10 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
common.SysError("error handling stream format: " + err.Error())
}
}
lastStreamData = data
streamItems = append(streamItems, data)
if len(data) > 0 {
lastStreamData = data
streamItems = append(streamItems, data)
}
return true
})
@@ -154,16 +156,18 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
common.SysError("error handling last response: " + err.Error())
common.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if shouldSendLastResp {
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.SysError("error processing tokens: " + err.Error())
common.LogError(c, "error processing tokens: "+err.Error())
}
if !containStreamUsage {
@@ -176,7 +180,6 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
}
}
}
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
return usage, nil

View File

@@ -18,20 +18,19 @@ import (
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
adaptor := openai.Adaptor{}
return adaptor.ConvertClaudeRequest(c, info, req)
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
return nil, errors.New("not supported")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
adaptor := openai.Adaptor{}
return adaptor.ConvertImageRequest(c, info, request)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
@@ -47,7 +46,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
@@ -81,16 +80,19 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
switch info.RelayMode {
case constant.RelayModeRerank:
usage, err = siliconflowRerankHandler(c, info, resp)
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
fallthrough
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}

View File

@@ -234,6 +234,12 @@ func StreamScannerHandler(c *gin.Context, resp *http.Response, info *relaycommon
case <-stopChan:
return
}
} else {
// done, 处理完成标志,直接退出停止读取剩余数据防止出错
if common.DebugEnabled {
println("received [DONE], stopping scanner")
}
return
}
}

View File

@@ -251,22 +251,54 @@ func StreamResponseOpenAI2Claude(openAIResponse *dto.ChatCompletionsStreamRespon
resp.SetIndex(0)
claudeResponses = append(claudeResponses, resp)
} else {
//resp := &dto.ClaudeResponse{
// Type: "content_block_start",
// ContentBlock: &dto.ClaudeMediaMessage{
// Type: "text",
// Text: common.GetPointer[string](""),
// },
//}
//resp.SetIndex(0)
//claudeResponses = append(claudeResponses, resp)
}
// 判断首个响应是否存在内容(非标准的 OpenAI 响应)
if len(openAIResponse.Choices) > 0 && len(openAIResponse.Choices[0].Delta.GetContentString()) > 0 {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Index: &info.ClaudeConvertInfo.Index,
Type: "content_block_start",
ContentBlock: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](""),
},
})
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "content_block_delta",
Delta: &dto.ClaudeMediaMessage{
Type: "text",
Text: common.GetPointer[string](openAIResponse.Choices[0].Delta.GetContentString()),
},
})
info.ClaudeConvertInfo.LastMessagesType = relaycommon.LastMessageTypeText
}
return claudeResponses
}
if len(openAIResponse.Choices) == 0 {
// no choices
// TODO: handle this case
// 可能为非标准的 OpenAI 响应,判断是否已经完成
if info.Done {
claudeResponses = append(claudeResponses, generateStopBlock(info.ClaudeConvertInfo.Index))
oaiUsage := info.ClaudeConvertInfo.Usage
if oaiUsage != nil {
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_delta",
Usage: &dto.ClaudeUsage{
InputTokens: oaiUsage.PromptTokens,
OutputTokens: oaiUsage.CompletionTokens,
CacheCreationInputTokens: oaiUsage.PromptTokensDetails.CachedCreationTokens,
CacheReadInputTokens: oaiUsage.PromptTokensDetails.CachedTokens,
},
Delta: &dto.ClaudeMediaMessage{
StopReason: common.GetPointer[string](stopReasonOpenAI2Claude(info.FinishReason)),
},
})
}
claudeResponses = append(claudeResponses, &dto.ClaudeResponse{
Type: "message_stop",
})
}
return claudeResponses
} else {
chosenChoice := openAIResponse.Choices[0]

View File

@@ -80,10 +80,7 @@ func ClaudeErrorWrapperLocal(err error, code string, statusCode int) *dto.Claude
}
func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *types.NewAPIError) {
newApiErr = &types.NewAPIError{
StatusCode: resp.StatusCode,
ErrorType: types.ErrorTypeOpenAIError,
}
newApiErr = types.InitOpenAIError(types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
@@ -105,8 +102,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
// General format error (OpenAI, Anthropic, Gemini, etc.)
newApiErr = types.WithOpenAIError(errResponse.Error, resp.StatusCode)
} else {
newApiErr = types.NewErrorWithStatusCode(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
newApiErr.ErrorType = types.ErrorTypeOpenAIError
newApiErr = types.NewOpenAIError(errors.New(errResponse.ToMessage()), types.ErrorCodeBadResponseStatusCode, resp.StatusCode)
}
return
}

View File

@@ -75,7 +75,7 @@ const (
type NewAPIError struct {
Err error
RelayError any
ErrorType ErrorType
errorType ErrorType
errorCode ErrorCode
StatusCode int
}
@@ -87,6 +87,13 @@ func (e *NewAPIError) GetErrorCode() ErrorCode {
return e.errorCode
}
func (e *NewAPIError) GetErrorType() ErrorType {
if e == nil {
return ""
}
return e.errorType
}
func (e *NewAPIError) Error() string {
if e == nil {
return ""
@@ -103,7 +110,7 @@ func (e *NewAPIError) SetMessage(message string) {
}
func (e *NewAPIError) ToOpenAIError() OpenAIError {
switch e.ErrorType {
switch e.errorType {
case ErrorTypeOpenAIError:
if openAIError, ok := e.RelayError.(OpenAIError); ok {
return openAIError
@@ -120,14 +127,14 @@ func (e *NewAPIError) ToOpenAIError() OpenAIError {
}
return OpenAIError{
Message: e.Error(),
Type: string(e.ErrorType),
Type: string(e.errorType),
Param: "",
Code: e.errorCode,
}
}
func (e *NewAPIError) ToClaudeError() ClaudeError {
switch e.ErrorType {
switch e.errorType {
case ErrorTypeOpenAIError:
openAIError := e.RelayError.(OpenAIError)
return ClaudeError{
@@ -139,7 +146,7 @@ func (e *NewAPIError) ToClaudeError() ClaudeError {
default:
return ClaudeError{
Message: e.Error(),
Type: string(e.ErrorType),
Type: string(e.errorType),
}
}
}
@@ -148,7 +155,7 @@ func NewError(err error, errorCode ErrorCode) *NewAPIError {
return &NewAPIError{
Err: err,
RelayError: nil,
ErrorType: ErrorTypeNewAPIError,
errorType: ErrorTypeNewAPIError,
StatusCode: http.StatusInternalServerError,
errorCode: errorCode,
}
@@ -162,6 +169,13 @@ func NewOpenAIError(err error, errorCode ErrorCode, statusCode int) *NewAPIError
return WithOpenAIError(openaiError, statusCode)
}
func InitOpenAIError(errorCode ErrorCode, statusCode int) *NewAPIError {
openaiError := OpenAIError{
Type: string(errorCode),
}
return WithOpenAIError(openaiError, statusCode)
}
func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *NewAPIError {
return &NewAPIError{
Err: err,
@@ -169,7 +183,7 @@ func NewErrorWithStatusCode(err error, errorCode ErrorCode, statusCode int) *New
Message: err.Error(),
Type: string(errorCode),
},
ErrorType: ErrorTypeNewAPIError,
errorType: ErrorTypeNewAPIError,
StatusCode: statusCode,
errorCode: errorCode,
}
@@ -182,7 +196,7 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
}
return &NewAPIError{
RelayError: openAIError,
ErrorType: ErrorTypeOpenAIError,
errorType: ErrorTypeOpenAIError,
StatusCode: statusCode,
Err: errors.New(openAIError.Message),
errorCode: ErrorCode(code),
@@ -192,7 +206,7 @@ func WithOpenAIError(openAIError OpenAIError, statusCode int) *NewAPIError {
func WithClaudeError(claudeError ClaudeError, statusCode int) *NewAPIError {
return &NewAPIError{
RelayError: claudeError,
ErrorType: ErrorTypeClaudeError,
errorType: ErrorTypeClaudeError,
StatusCode: statusCode,
Err: errors.New(claudeError.Message),
errorCode: ErrorCode(claudeError.Type),
@@ -211,5 +225,5 @@ func IsLocalError(err *NewAPIError) bool {
return false
}
return err.ErrorType == ErrorTypeNewAPIError
return err.errorType == ErrorTypeNewAPIError
}

View File

@@ -704,20 +704,20 @@ const EditChannelModal = (props) => {
}
}}
>{t('批量创建')}</Checkbox>
{/*{batch && (*/}
{/* <Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {*/}
{/* setMultiToSingle(prev => !prev);*/}
{/* setInputs(prev => {*/}
{/* const newInputs = { ...prev };*/}
{/* if (!multiToSingle) {*/}
{/* newInputs.multi_key_mode = multiKeyMode;*/}
{/* } else {*/}
{/* delete newInputs.multi_key_mode;*/}
{/* }*/}
{/* return newInputs;*/}
{/* });*/}
{/* }}>{t('密钥聚合模式')}</Checkbox>*/}
{/*)}*/}
{batch && (
<Checkbox disabled={isEdit} checked={multiToSingle} onChange={() => {
setMultiToSingle(prev => !prev);
setInputs(prev => {
const newInputs = { ...prev };
if (!multiToSingle) {
newInputs.multi_key_mode = multiKeyMode;
} else {
delete newInputs.multi_key_mode;
}
return newInputs;
});
}}>{t('密钥聚合模式')}</Checkbox>
)}
</Space>
) : null;