refactor: centralize logging and update resource initialization

This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include:

- Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices.
- Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability.
- Minor adjustments to improve code clarity and organization throughout various modules.

This change aims to streamline logging and improve the overall architecture of the codebase.
This commit is contained in:
CaIon
2025-08-14 21:10:04 +08:00
parent e2037ad756
commit 6748b006b7
101 changed files with 537 additions and 568 deletions

View File

@@ -34,20 +34,20 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayFormat {
case relaycommon.RelayFormatClaude:
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.BaseUrl)
case types.RelayFormatClaude:
fullRequestURL = fmt.Sprintf("%s/api/v2/apps/claude-code-proxy/v1/messages", info.ChannelBaseUrl)
default:
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.ChannelBaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.ChannelBaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.ChannelBaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.ChannelBaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.ChannelBaseUrl)
}
}
@@ -118,7 +118,7 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
if info.IsStream {
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
@@ -22,14 +23,14 @@ func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.Parameters.N = int(request.N)
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.ChannelBaseUrl, taskID)
var aliResponse AliResponse
@@ -43,7 +44,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
logger.SysError("updateTask client.Do err: " + err.Error())
common.SysLog("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
@@ -53,7 +54,7 @@ func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
logger.SysError("updateTask NewDecoder err: " + err.Error())
common.SysLog("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}

View File

@@ -7,7 +7,6 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/helper"
"one-api/service"
"strings"
@@ -150,7 +149,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
@@ -163,7 +162,7 @@ func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})

View File

@@ -101,7 +101,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
default:
suffix += strings.ToLower(info.UpstreamModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.ChannelBaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {

View File

@@ -9,7 +9,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -119,7 +118,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
@@ -130,7 +129,7 @@ func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response)
if err != nil {
logger.SysError("error sending stream response: " + err.Error())
common.SysLog("error sending stream response: " + err.Error())
}
return true
})

View File

@@ -45,15 +45,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v2/chat/completions", info.ChannelBaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/v2/embeddings", info.BaseUrl), nil
return fmt.Sprintf("%s/v2/embeddings", info.ChannelBaseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/v2/images/generations", info.BaseUrl), nil
return fmt.Sprintf("%s/v2/images/generations", info.ChannelBaseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/v2/images/edits", info.BaseUrl), nil
return fmt.Sprintf("%s/v2/images/edits", info.ChannelBaseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/v2/rerank", info.BaseUrl), nil
return fmt.Sprintf("%s/v2/rerank", info.ChannelBaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

View File

@@ -53,9 +53,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/messages", info.ChannelBaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/complete", info.ChannelBaseUrl), nil
}
}

View File

@@ -376,7 +376,7 @@ func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.Cla
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
logger.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
common.SysLog("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
@@ -610,13 +610,13 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeError := claudeResponse.GetClaudeError(); claudeError != nil && claudeError.Type != "" {
return types.WithClaudeError(*claudeError, http.StatusInternalServerError)
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if info.RelayFormat == types.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
if requestMode == RequestModeCompletion {
@@ -629,7 +629,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
} else if info.RelayFormat == types.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
@@ -654,21 +654,20 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
logger.SysError("claude response usage is not complete, maybe upstream error")
common.SysLog("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
if info.RelayFormat == types.RelayFormatClaude {
//
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
} else if info.RelayFormat == types.RelayFormatOpenAI {
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
logger.SysError("send final response failed: " + err.Error())
common.SysLog("send final response failed: " + err.Error())
}
}
helper.Done(c)
@@ -722,14 +721,14 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
}
var responseData []byte
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
case types.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
responseData = data
}

View File

@@ -36,13 +36,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.ChannelBaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.ChannelBaseUrl, info.ApiVersion), nil
case constant.RelayModeResponses:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.BaseUrl, info.ApiVersion), nil
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/responses", info.ChannelBaseUrl, info.ApiVersion), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.ChannelBaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}

View File

@@ -43,9 +43,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat", info.ChannelBaseUrl), nil
}
}

View File

@@ -7,7 +7,6 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -119,7 +118,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
@@ -154,7 +153,7 @@ func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})

View File

@@ -122,7 +122,7 @@ func (a *Adaptor) GetModelList() []string {
// GetRequestURL implements channel.Adaptor.
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
return fmt.Sprintf("%s/v3/chat", info.ChannelBaseUrl), nil
}
// Init implements channel.Adaptor.

View File

@@ -9,7 +9,6 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -155,7 +154,7 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData)
if err != nil {
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -172,14 +171,14 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData)
if err != nil {
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
var content string
err = json.Unmarshal(messageData.Content, &content)
if err != nil {
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
@@ -204,16 +203,16 @@ func handleCozeEvent(c *gin.Context, event string, data string, responseText *st
var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData)
if err != nil {
logger.SysError("error_unmarshalling_stream_response: " + err.Error())
common.SysLog("error_unmarshalling_stream_response: " + err.Error())
return
}
logger.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
common.SysLog(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
}
}
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.ChannelBaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
// 将 conversationId和chatId作为参数发送get请求
@@ -259,7 +258,7 @@ func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo
}
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.ChannelBaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
req, err := http.NewRequest("GET", requestURL, nil)

View File

@@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.BaseUrl
if !strings.HasSuffix(info.BaseUrl, "/beta") {
fimBaseUrl := info.ChannelBaseUrl
if !strings.HasSuffix(info.ChannelBaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
}

View File

@@ -61,13 +61,13 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/workflows/run", info.ChannelBaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/completion-messages", info.ChannelBaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat-messages", info.ChannelBaseUrl), nil
}
}

View File

@@ -11,7 +11,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -23,7 +22,7 @@ import (
)
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.ChannelBaseUrl)
switch media.Type {
case dto.ContentTypeImageURL:
// Decode base64 data
@@ -37,14 +36,14 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
logger.SysError("failed to decode base64: " + err.Error())
common.SysLog("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
logger.SysError("failed to create temp file: " + err.Error())
common.SysLog("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
@@ -52,7 +51,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
logger.SysError("failed to write to temp file: " + err.Error())
common.SysLog("failed to write to temp file: " + err.Error())
return nil
}
@@ -62,7 +61,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Add user field
if err := writer.WriteField("user", user); err != nil {
logger.SysError("failed to add user field: " + err.Error())
common.SysLog("failed to add user field: " + err.Error())
return nil
}
@@ -75,13 +74,13 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
logger.SysError("failed to create form file: " + err.Error())
common.SysLog("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
logger.SysError("failed to copy file content: " + err.Error())
common.SysLog("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
@@ -89,7 +88,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
logger.SysError("failed to create request: " + err.Error())
common.SysLog("failed to create request: " + err.Error())
return nil
}
@@ -100,7 +99,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
client := service.GetHttpClient()
resp, err := client.Do(req)
if err != nil {
logger.SysError("failed to send request: " + err.Error())
common.SysLog("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
@@ -110,7 +109,7 @@ func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, me
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
logger.SysError("failed to decode response: " + err.Error())
common.SysLog("failed to decode response: " + err.Error())
return nil
}
@@ -220,7 +219,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResponse dto.ChatCompletionsStreamResponse
@@ -240,7 +239,7 @@ func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
logger.SysError(err.Error())
common.SysLog(err.Error())
}
return true
})

View File

@@ -108,7 +108,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
return fmt.Sprintf("%s/%s/models/%s:predict", info.ChannelBaseUrl, version, info.UpstreamModelName), nil
}
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
@@ -118,7 +118,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.IsGeminiBatchEmbedding {
action = "batchEmbedContents"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
}
action := "generateContent"
@@ -128,7 +128,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
info.DisablePing = true
}
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
return fmt.Sprintf("%s/%s/models/%s:%s", info.ChannelBaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -994,7 +994,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
response := helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := handleFinalStream(c, info, response)
if err != nil {
logger.SysError("send final response failed: " + err.Error())
common.SysLog("send final response failed: " + err.Error())
}
//if info.RelayFormat == relaycommon.RelayFormatOpenAI {
// helper.Done(c)
@@ -1042,19 +1042,19 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
fullTextResponse.Usage = usage
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
case types.RelayFormatOpenAI:
responseBody, err = common.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(fullTextResponse, info)
claudeRespStr, err := common.Marshal(claudeResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = claudeRespStr
case relaycommon.RelayFormatGemini:
case types.RelayFormatGemini:
break
}

View File

@@ -32,7 +32,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -45,9 +45,9 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
}
return "", errors.New("invalid relay mode")
}

View File

@@ -6,5 +6,5 @@ import (
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.ChannelBaseUrl), nil
}

View File

@@ -41,7 +41,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -54,7 +54,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings"
}
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
fullRequestURL := fmt.Sprintf("%s/%s", info.ChannelBaseUrl, suffix)
return fullRequestURL, nil
}

View File

@@ -44,19 +44,19 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayFormat {
case relaycommon.RelayFormatClaude:
return fmt.Sprintf("%s/anthropic/v1/messages", info.BaseUrl), nil
case types.RelayFormatClaude:
return fmt.Sprintf("%s/anthropic/v1/messages", info.ChannelBaseUrl), nil
default:
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
}
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
}
@@ -89,10 +89,10 @@ func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.Rela
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
case types.RelayFormatOpenAI:
adaptor := openai.Adaptor{}
return adaptor.DoResponse(c, resp, info)
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
if info.IsStream {
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
} else {

View File

@@ -48,14 +48,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return info.BaseUrl + "/v1/chat/completions", nil
if info.RelayFormat == types.RelayFormatClaude {
return info.ChannelBaseUrl + "/v1/chat/completions", nil
}
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embed", nil
return info.ChannelBaseUrl + "/api/embed", nil
default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}

View File

@@ -105,14 +105,14 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == relayconstant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
if strings.HasPrefix(info.ChannelBaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "https://")
baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl
} else if strings.HasPrefix(info.BaseUrl, "http://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
info.ChannelBaseUrl = baseUrl
} else if strings.HasPrefix(info.ChannelBaseUrl, "http://") {
baseUrl := strings.TrimPrefix(info.ChannelBaseUrl, "http://")
baseUrl = "ws://" + baseUrl
info.BaseUrl = baseUrl
info.ChannelBaseUrl = baseUrl
}
}
switch info.ChannelType {
@@ -126,7 +126,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
if info.RelayFormat == relaycommon.RelayFormatClaude {
if info.RelayFormat == types.RelayFormatClaude {
task = strings.TrimPrefix(task, "messages")
task = "chat/completions" + task
}
@@ -136,7 +136,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
responsesApiVersion := "preview"
subUrl := "/openai/v1/responses"
if strings.Contains(info.BaseUrl, "cognitiveservices.azure.com") {
if strings.Contains(info.ChannelBaseUrl, "cognitiveservices.azure.com") {
subUrl = "/openai/responses"
responsesApiVersion = apiVersion
}
@@ -146,7 +146,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
}
requestURL = fmt.Sprintf("%s?api-version=%s", subUrl, responsesApiVersion)
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
}
model_ := info.UpstreamModelName
@@ -159,18 +159,18 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == relayconstant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, requestURL, info.ChannelType), nil
case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
case constant.ChannelTypeCustom:
url := info.BaseUrl
url := info.ChannelBaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
default:
if info.RelayFormat == relaycommon.RelayFormatClaude || info.RelayFormat == relaycommon.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
if info.RelayFormat == types.RelayFormatClaude || info.RelayFormat == types.RelayFormatGemini {
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}

View File

@@ -12,6 +12,7 @@ import (
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
@@ -22,11 +23,11 @@ func HandleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
case types.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
case relaycommon.RelayFormatGemini:
case types.RelayFormatGemini:
return handleGeminiFormat(c, data, info)
}
return nil
@@ -111,14 +112,14 @@ func processChatCompletions(streamResp string, streamItems []string, responseTex
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
logger.SysError("error processing stream response: " + err.Error())
common.SysLog("error processing stream response: " + err.Error())
}
}
return nil
@@ -147,7 +148,7 @@ func processCompletions(streamResp string, streamItems []string, responseTextBui
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
@@ -202,7 +203,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
case types.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
@@ -210,11 +211,11 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return
}
@@ -225,10 +226,10 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
_ = helper.ClaudeData(c, *resp)
}
case relaycommon.RelayFormatGemini:
case types.RelayFormatGemini:
var streamResponse dto.ChatCompletionsStreamResponse
if err := common.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return
}
@@ -246,7 +247,7 @@ func HandleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStream
geminiResponseStr, err := common.Marshal(geminiResponse)
if err != nil {
logger.SysError("error marshalling gemini response: " + err.Error())
common.SysLog("error marshalling gemini response: " + err.Error())
return
}

View File

@@ -130,7 +130,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
if lastStreamData != "" {
err := HandleStreamFormat(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
if err != nil {
logger.SysError("error handling stream format: " + err.Error())
common.SysLog("error handling stream format: " + err.Error())
}
}
if len(data) > 0 {
@@ -147,7 +147,7 @@ func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
logger.LogError(c, fmt.Sprintf("error handling last response: %s, lastStreamData: [%s]", err.Error(), lastStreamData))
}
if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if info.RelayFormat == types.RelayFormatOpenAI {
if shouldSendLastResp {
_ = sendStreamData(c, info, lastStreamData, info.ChannelSetting.ForceFormat, info.ChannelSetting.ThinkingToContent)
}
@@ -211,7 +211,7 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
}
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
case types.RelayFormatOpenAI:
if forceFormat {
responseBody, err = common.Marshal(simpleResponse)
if err != nil {
@@ -220,14 +220,14 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} else {
break
}
case relaycommon.RelayFormatClaude:
case types.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := common.Marshal(claudeResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = claudeRespStr
case relaycommon.RelayFormatGemini:
case types.RelayFormatGemini:
geminiResp := service.ResponseOpenAI2Gemini(&simpleResponse, info)
geminiRespStr, err := common.Marshal(geminiResp)
if err != nil {

View File

@@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -7,7 +7,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -59,7 +58,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
logger.SysError("error reading stream response: " + err.Error())
common.SysLog("error reading stream response: " + err.Error())
stopChan <- true
return
}
@@ -67,7 +66,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
@@ -79,7 +78,7 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
common.SysLog("error marshalling stream response: " + err.Error())
stopChan <- true
return
}

View File

@@ -42,7 +42,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/chat/completions", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -43,15 +43,15 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/rerank", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/embeddings", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/completions", info.ChannelBaseUrl), nil
}
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/v1/chat/completions", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -76,7 +76,7 @@ type TaskAdaptor struct {
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
a.baseURL = info.ChannelBaseUrl
// apiKey format: "access_key|secret_key"
keyParts := strings.Split(info.ApiKey, "|")

View File

@@ -81,7 +81,7 @@ type TaskAdaptor struct {
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
a.baseURL = info.ChannelBaseUrl
a.apiKey = info.ApiKey
// apiKey format: "access_key|secret_key"

View File

@@ -11,7 +11,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
@@ -60,7 +59,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
baseURL := info.BaseUrl
baseURL := info.ChannelBaseUrl
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
return fullRequestURL, nil
}
@@ -140,7 +139,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil {
logger.SysError(fmt.Sprintf("Get Task error: %v", err))
common.SysLog(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()

View File

@@ -86,7 +86,7 @@ type TaskAdaptor struct {
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
a.baseURL = info.ChannelBaseUrl
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError {

View File

@@ -53,7 +53,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil
return fmt.Sprintf("%s/", info.ChannelBaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -13,7 +13,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -107,7 +106,7 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
continue
}
@@ -118,12 +117,12 @@ func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *htt
err = helper.ObjectData(c, response)
if err != nil {
logger.SysError(err.Error())
common.SysLog(err.Error())
}
}
if err := scanner.Err(); err != nil {
logger.SysError("error reading stream: " + err.Error())
common.SysLog("error reading stream: " + err.Error())
}
helper.Done(c)

View File

@@ -188,17 +188,17 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
case constant.RelayModeImagesEdits:
return fmt.Sprintf("%s/api/v3/images/edits", info.BaseUrl), nil
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
case constant.RelayModeRerank:
return fmt.Sprintf("%s/api/v3/rerank", info.BaseUrl), nil
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)

View File

@@ -39,7 +39,7 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
xaiRequest := ImageRequest{
Model: request.Model,
Prompt: request.Prompt,
N: request.N,
N: int(request.N),
ResponseFormat: request.ResponseFormat,
}
return xaiRequest, nil
@@ -49,7 +49,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
return relaycommon.GetFullRequestURL(info.ChannelBaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -6,7 +6,6 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/logger"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
@@ -48,7 +47,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
var xAIResp *dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &xAIResp)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
@@ -64,7 +63,7 @@ func xAIStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Re
_ = openai.ProcessStreamResponse(*openaiResponse, &responseTextBuilder, &toolCount)
err = helper.ObjectData(c, openaiResponse)
if err != nil {
logger.SysError(err.Error())
common.SysLog(err.Error())
}
return true
})

View File

@@ -11,7 +11,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/relay/helper"
"one-api/types"
"strings"
@@ -144,7 +143,7 @@ func xunfeiStreamHandler(c *gin.Context, textRequest dto.GeneralOpenAIRequest, a
response := streamResponseXunfei2OpenAI(&xunfeiResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -219,20 +218,20 @@ func xunfeiMakeRequest(textRequest dto.GeneralOpenAIRequest, domain, authUrl, ap
for {
_, msg, err := conn.ReadMessage()
if err != nil {
logger.SysError("error reading stream response: " + err.Error())
common.SysLog("error reading stream response: " + err.Error())
break
}
var response XunfeiChatResponse
err = json.Unmarshal(msg, &response)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
break
}
dataChan <- response
if response.Payload.Choices.Status == 2 {
err := conn.Close()
if err != nil {
logger.SysError("error closing websocket connection: " + err.Error())
common.SysLog("error closing websocket connection: " + err.Error())
}
break
}
@@ -283,6 +282,6 @@ func getAPIVersion(c *gin.Context, modelName string) string {
return apiVersion
}
apiVersion = "v1.1"
logger.SysLog("api_version not found, using default: " + apiVersion)
common.SysLog("api_version not found, using default: " + apiVersion)
return apiVersion
}

View File

@@ -45,7 +45,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.IsStream {
method = "sse-invoke"
}
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.BaseUrl, info.UpstreamModelName, method), nil
return fmt.Sprintf("%s/api/paas/v3/model-api/%s/%s", info.ChannelBaseUrl, info.UpstreamModelName, method), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {

View File

@@ -8,7 +8,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -40,7 +39,7 @@ func getZhipuToken(apikey string) string {
split := strings.Split(apikey, ".")
if len(split) != 2 {
logger.SysError("invalid zhipu key: " + apikey)
common.SysLog("invalid zhipu key: " + apikey)
return ""
}
@@ -188,7 +187,7 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
response := streamResponseZhipu2OpenAI(data)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
@@ -197,13 +196,13 @@ func zhipuStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.
var zhipuResponse ZhipuStreamMetaResponse
err := json.Unmarshal([]byte(data), &zhipuResponse)
if err != nil {
logger.SysError("error unmarshalling stream response: " + err.Error())
common.SysLog("error unmarshalling stream response: " + err.Error())
return true
}
response, zhipuUsage := streamMetaResponseZhipu2OpenAI(&zhipuResponse)
jsonResponse, err := json.Marshal(response)
if err != nil {
logger.SysError("error marshalling stream response: " + err.Error())
common.SysLog("error marshalling stream response: " + err.Error())
return true
}
usage = zhipuUsage

View File

@@ -43,7 +43,7 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.BaseUrl)
baseUrl := fmt.Sprintf("%s/api/paas/v4", info.ChannelBaseUrl)
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return fmt.Sprintf("%s/embeddings", baseUrl), nil

View File

@@ -66,6 +66,7 @@ type ChannelMeta struct {
ChannelOtherSettings dto.ChannelOtherSettings
UpstreamModelName string
IsModelMapped bool
SupportStreamOptions bool // 是否支持流式选项
}
type RelayInfo struct {
@@ -86,9 +87,9 @@ type RelayInfo struct {
RelayMode int
OriginModelName string
//RecodeModelName string
RequestURLPath string
PromptTokens int
SupportStreamOptions bool
RequestURLPath string
PromptTokens int
//SupportStreamOptions bool
ShouldIncludeUsage bool
DisablePing bool // 是否禁止向下游发送自定义 Ping
ClientWs *websocket.Conn
@@ -135,6 +136,7 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
ParamOverride: paramOverride,
UpstreamModelName: common.GetContextKeyString(c, constant.ContextKeyOriginalModel),
IsModelMapped: false,
SupportStreamOptions: false,
}
channelSetting, ok := common.GetContextKeyType[dto.ChannelSettings](c, constant.ContextKeyChannelSetting)
@@ -146,6 +148,10 @@ func (info *RelayInfo) InitChannelMeta(c *gin.Context) {
if ok {
channelMeta.ChannelOtherSettings = channelOtherSettings
}
if streamSupportedChannels[channelMeta.ChannelType] {
channelMeta.SupportStreamOptions = true
}
info.ChannelMeta = channelMeta
}
@@ -268,6 +274,12 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
startTime = time.Now()
}
isStream := false
if request != nil {
isStream = request.IsStream(c)
}
// firstResponseTime = time.Now() - 1 second
info := &RelayInfo{
@@ -289,7 +301,7 @@ func genBaseRelayInfo(c *gin.Context, request dto.Request) *RelayInfo {
isFirstResponse: true,
RelayMode: relayconstant.Path2RelayMode(c.Request.URL.Path),
RequestURLPath: c.Request.URL.String(),
IsStream: request.IsStream(c),
IsStream: isStream,
StartTime: startTime,
FirstResponseTime: startTime.Add(-time.Second),
@@ -339,6 +351,10 @@ func GenRelayInfo(c *gin.Context, relayFormat types.RelayFormat, request dto.Req
return GenRelayInfoResponses(c, request), nil
}
return nil, errors.New("request is not a OpenAIResponsesRequest")
case types.RelayFormatTask:
return genBaseRelayInfo(c, nil), nil
case types.RelayFormatMjProxy:
return genBaseRelayInfo(c, nil), nil
default:
return nil, errors.New("invalid relay format")
}
@@ -367,11 +383,15 @@ type TaskRelayInfo struct {
ConsumeQuota bool
}
func GenTaskRelayInfo(c *gin.Context) *TaskRelayInfo {
info := &TaskRelayInfo{
RelayInfo: GenRelayInfo(c),
func GenTaskRelayInfo(c *gin.Context) (*TaskRelayInfo, error) {
relayInfo, err := GenRelayInfo(c, types.RelayFormatTask, nil, nil)
if err != nil {
return nil, err
}
return info
info := &TaskRelayInfo{
RelayInfo: relayInfo,
}
return info, nil
}
type TaskSubmitReq struct {

View File

@@ -53,9 +53,9 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
var imageRatio float64
var cacheCreationRatio float64
if !usePrice {
preConsumedTokens := common.PreConsumedQuota
preConsumedTokens := common.Max(promptTokens, common.PreConsumedQuota)
if meta.MaxTokens != 0 {
preConsumedTokens = promptTokens + meta.MaxTokens
preConsumedTokens += meta.MaxTokens
}
var success bool
var matchName string
@@ -102,27 +102,27 @@ func ModelPriceHelper(c *gin.Context, info *relaycommon.RelayInfo, promptTokens
}
// ModelPriceHelperPerCall 按次计费的 PriceHelper (MJ、Task)
//func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
// groupRatioInfo := HandleGroupRatio(c, info)
//
// modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
// // 如果没有配置价格,则使用默认价格
// if !success {
// defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
// if !ok {
// modelPrice = 0.1
// } else {
// modelPrice = defaultPrice
// }
// }
// quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
// priceData := types.PerCallPriceData{
// ModelPrice: modelPrice,
// Quota: quota,
// GroupRatioInfo: groupRatioInfo,
// }
// return priceData
//}
func ModelPriceHelperPerCall(c *gin.Context, info *relaycommon.RelayInfo) types.PerCallPriceData {
groupRatioInfo := HandleGroupRatio(c, info)
modelPrice, success := ratio_setting.GetModelPrice(info.OriginModelName, true)
// 如果没有配置价格,则使用默认价格
if !success {
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[info.OriginModelName]
if !ok {
modelPrice = 0.1
} else {
modelPrice = defaultPrice
}
}
quota := int(modelPrice * common.QuotaPerUnit * groupRatioInfo.GroupRatio)
priceData := types.PerCallPriceData{
ModelPrice: modelPrice,
Quota: quota,
GroupRatioInfo: groupRatioInfo,
}
return priceData
}
func ContainPriceOrRatio(modelName string) bool {
_, ok := ratio_setting.GetModelPrice(modelName, false)

View File

@@ -36,7 +36,7 @@ func GetAndValidateRequest(c *gin.Context, format types.RelayFormat) (request dt
case types.RelayFormatOpenAIAudio:
request, err = GetAndValidAudioRequest(c, relayMode)
case types.RelayFormatOpenAIRealtime:
// nothing to do, no request body
request = &dto.BaseRequest{}
default:
return nil, fmt.Errorf("unsupported relay format: %s", format)
}

View File

@@ -10,7 +10,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -171,13 +170,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
return
}
func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
startTime := time.Now().UnixNano() / int64(time.Millisecond)
tokenId := c.GetInt("token_id")
userId := c.GetInt("id")
//group := c.GetString("group")
channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
var swapFaceRequest dto.SwapFaceRequest
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
if err != nil {
@@ -188,9 +181,9 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
priceData := helper.ModelPriceHelperPerCall(c, info)
userQuota, err := model.GetUserQuota(userId, false)
userQuota, err := model.GetUserQuota(info.UserId, false)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
@@ -213,32 +206,31 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
}
defer func() {
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
common.SysLog("error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: channelId,
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
ChannelId: info.ChannelId,
ModelName: modelName,
TokenName: tokenName,
Quota: priceData.Quota,
Content: logContent,
TokenId: tokenId,
UserQuota: userQuota,
Group: relayInfo.UsingGroup,
TokenId: info.TokenId,
Group: info.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
}
}()
midjResponse := &mjResp.Response
midjourneyTask := &model.Midjourney{
UserId: userId,
UserId: info.UserId,
Code: midjResponse.Code,
Action: constant.MjActionSwapFace,
MjId: midjResponse.Result,
@@ -246,7 +238,7 @@ func RelaySwapFace(c *gin.Context) *dto.MidjourneyResponse {
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: startTime,
SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
FinishTime: 0,
ImageUrl: "",
@@ -370,14 +362,7 @@ func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse
return nil
}
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
//tokenId := c.GetInt("token_id")
//channelType := c.GetInt("channel")
userId := c.GetInt("id")
group := c.GetString("group")
channelId := c.GetInt("channel_id")
relayInfo := relaycommon.GenRelayInfo(c)
func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
consumeQuota := true
var midjRequest dto.MidjourneyRequest
err := common.UnmarshalBodyReusable(c, &midjRequest)
@@ -385,35 +370,35 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
}
if relayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus需要从customId中获取任务信息
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus需要从customId中获取任务信息
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
if mjErr != nil {
return mjErr
}
relayMode = relayconstant.RelayModeMidjourneyChange
relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
}
if relayMode == relayconstant.RelayModeMidjourneyVideo {
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
midjRequest.Action = constant.MjActionVideo
}
if relayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
if midjRequest.Prompt == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
}
midjRequest.Action = constant.MjActionImagine
} else if relayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
midjRequest.Action = constant.MjActionDescribe
} else if relayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
midjRequest.Action = constant.MjActionEdits
} else if relayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务此类任务可重复plus only
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务此类任务可重复plus only
midjRequest.Action = constant.MjActionShorten
} else if relayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionBlend
} else if relayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
midjRequest.Action = constant.MjActionUpload
} else if midjRequest.TaskId != "" { //放大、变换任务此类任务如果重复且已有结果远端api会直接返回最终结果
mjId := ""
if relayMode == relayconstant.RelayModeMidjourneyChange {
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
} else if midjRequest.Action == "" {
@@ -423,7 +408,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
//action = midjRequest.Action
mjId = midjRequest.TaskId
} else if relayMode == relayconstant.RelayModeMidjourneySimpleChange {
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
if midjRequest.Content == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
}
@@ -433,13 +418,13 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
}
mjId = params.TaskId
midjRequest.Action = params.Action
} else if relayMode == relayconstant.RelayModeMidjourneyModal {
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
//if midjRequest.MaskBase64 == "" {
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
//}
mjId = midjRequest.TaskId
midjRequest.Action = constant.MjActionModal
} else if relayMode == relayconstant.RelayModeMidjourneyVideo {
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
midjRequest.Action = constant.MjActionVideo
if midjRequest.TaskId == "" {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
@@ -449,12 +434,12 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
mjId = midjRequest.TaskId
}
originTask := model.GetByMJId(userId, mjId)
originTask := model.GetByMJId(relayInfo.UserId, mjId)
if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
if setting.MjActionCheckSuccessEnabled {
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
}
}
@@ -497,7 +482,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
userQuota, err := model.GetUserQuota(userId, false)
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return &dto.MidjourneyResponse{
Code: 4,
@@ -522,24 +507,23 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
common.SysLog("error consuming token remain quota: " + err.Error())
}
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %sID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
other := service.GenerateMjOtherInfo(priceData)
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: channelId,
ChannelId: relayInfo.ChannelId,
ModelName: modelName,
TokenName: tokenName,
Quota: priceData.Quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
Group: group,
Group: relayInfo.UsingGroup,
Other: other,
})
model.UpdateUserUsedQuotaAndRequestCount(userId, priceData.Quota)
model.UpdateChannelUsedQuota(channelId, priceData.Quota)
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
}
}()
@@ -551,7 +535,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
// other: 提交错误description为错误描述
midjourneyTask := &model.Midjourney{
UserId: userId,
UserId: relayInfo.UserId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
@@ -573,7 +557,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
//无实例账号自动禁用渠道No available account instance
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
if err != nil {
logger.SysError("get_channel_null: " + err.Error())
common.SysLog("get_channel_null: " + err.Error())
}
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")

View File

@@ -44,6 +44,26 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
includeUsage := true
// 判断用户是否需要返回使用情况
if textRequest.StreamOptions != nil {
includeUsage = textRequest.StreamOptions.IncludeUsage
}
// 如果不支持StreamOptions将StreamOptions设置为nil
if !info.SupportStreamOptions || !textRequest.Stream {
textRequest.StreamOptions = nil
} else {
// 如果支持StreamOptions且请求中没有设置StreamOptions根据配置文件设置StreamOptions
if constant.ForceStreamOption {
textRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
}
info.ShouldIncludeUsage = includeUsage
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())

View File

@@ -10,7 +10,6 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
@@ -28,7 +27,11 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
if platform == "" {
platform = GetTaskPlatform(c)
}
relayInfo := relaycommon.GenTaskRelayInfo(c)
relayInfo, err := relaycommon.GenTaskRelayInfo(c)
if err != nil {
return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
}
adaptor := GetTaskAdaptor(platform)
if adaptor == nil {
@@ -98,7 +101,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
c.Set("channel_id", originTask.ChannelId)
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
relayInfo.BaseUrl = channel.GetBaseURL()
relayInfo.ChannelBaseUrl = channel.GetBaseURL()
relayInfo.ChannelId = originTask.ChannelId
}
}
@@ -128,7 +131,7 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
if err != nil {
logger.SysError("error consuming token remain quota: " + err.Error())
common.SysLog("error consuming token remain quota: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
@@ -150,7 +153,6 @@ func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
Group: relayInfo.UsingGroup,
Other: other,
})

View File

@@ -4,7 +4,6 @@ import (
"fmt"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
@@ -12,58 +11,35 @@ import (
"github.com/gorilla/websocket"
)
func WssHelper(c *gin.Context, ws *websocket.Conn) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoWs(c, ws)
func WssHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types.NewAPIError) {
info.InitChannelMeta(c)
err := helper.ModelMappedHelper(c, relayInfo, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError, types.ErrOptionWithSkipRetry())
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, 0, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError, types.ErrOptionWithSkipRetry())
}
// pre-consume quota 预消耗配额
preConsumedQuota, userQuota, newAPIError := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if newAPIError != nil {
return newAPIError
}
defer func() {
if newAPIError != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
adaptor := GetAdaptor(relayInfo.ApiType)
adaptor := GetAdaptor(info.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
return types.NewError(fmt.Errorf("invalid api type: %d", info.ApiType), types.ErrorCodeInvalidApiType, types.ErrOptionWithSkipRetry())
}
adaptor.Init(relayInfo)
adaptor.Init(info)
//var requestBody io.Reader
//firstWssRequest, _ := c.Get("first_wss_request")
//requestBody = bytes.NewBuffer(firstWssRequest.([]byte))
statusCodeMappingStr := c.GetString("status_code_mapping")
resp, err := adaptor.DoRequest(c, relayInfo, nil)
resp, err := adaptor.DoRequest(c, info, nil)
if err != nil {
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
if resp != nil {
relayInfo.TargetWs = resp.(*websocket.Conn)
defer relayInfo.TargetWs.Close()
info.TargetWs = resp.(*websocket.Conn)
defer info.TargetWs.Close()
}
usage, newAPIError := adaptor.DoResponse(c, nil, relayInfo)
usage, newAPIError := adaptor.DoResponse(c, nil, info)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
service.PostWssConsumeQuota(c, relayInfo, relayInfo.UpstreamModelName, usage.(*dto.RealtimeUsage), preConsumedQuota,
userQuota, priceData, "")
service.PostWssConsumeQuota(c, info, info.UpstreamModelName, usage.(*dto.RealtimeUsage), "")
return nil
}