diff --git a/controller/image.go b/controller/image.go new file mode 100644 index 00000000..d6e8806a --- /dev/null +++ b/controller/image.go @@ -0,0 +1,9 @@ +package controller + +import ( + "github.com/gin-gonic/gin" +) + +func GetImage(c *gin.Context) { + +} diff --git a/dto/openai_request.go b/dto/openai_request.go index 812e14a5..a0a05b34 100644 --- a/dto/openai_request.go +++ b/dto/openai_request.go @@ -113,9 +113,21 @@ type MediaContent struct { InputAudio any `json:"input_audio,omitempty"` } +func (m *MediaContent) GetImageMedia() *MessageImageUrl { + if m.ImageUrl != nil { + return m.ImageUrl.(*MessageImageUrl) + } + return nil +} + type MessageImageUrl struct { - Url string `json:"url"` - Detail string `json:"detail"` + Url string `json:"url"` + Detail string `json:"detail"` + MimeType string +} + +func (m *MessageImageUrl) IsRemoteImage() bool { + return strings.HasPrefix(m.Url, "http") } type MessageInputAudio struct { @@ -244,43 +256,39 @@ func (m *Message) ParseContent() []MediaContent { case ContentTypeImageURL: imageUrl := contentItem["image_url"] + temp := &MessageImageUrl{ + Detail: "high", + } switch v := imageUrl.(type) { case string: - contentList = append(contentList, MediaContent{ - Type: ContentTypeImageURL, - ImageUrl: MessageImageUrl{ - Url: v, - Detail: "high", - }, - }) + temp.Url = v case map[string]interface{}: url, ok1 := v["url"].(string) detail, ok2 := v["detail"].(string) - if !ok2 { - detail = "high" + if ok2 { + temp.Detail = detail } if ok1 { - contentList = append(contentList, MediaContent{ - Type: ContentTypeImageURL, - ImageUrl: MessageImageUrl{ - Url: url, - Detail: detail, - }, - }) + temp.Url = url } } + contentList = append(contentList, MediaContent{ + Type: ContentTypeImageURL, + ImageUrl: temp, + }) case ContentTypeInputAudio: if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok { data, ok1 := audioData["data"].(string) format, ok2 := audioData["format"].(string) if ok1 && ok2 { + temp := &MessageInputAudio{ + Data: data, + Format: format, + } contentList = append(contentList, MediaContent{ - Type: ContentTypeInputAudio, - InputAudio: MessageInputAudio{ - Data: data, - Format: format, - }, + Type: ContentTypeInputAudio, + InputAudio: temp, }) } } diff --git a/relay/channel/claude/relay-claude.go b/relay/channel/claude/relay-claude.go index 3dbca4a9..92f8bbbe 100644 --- a/relay/channel/claude/relay-claude.go +++ b/relay/channel/claude/relay-claude.go @@ -414,7 +414,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto // 加密的不管, 只输出明文的推理过程 thinkingContent = message.Thinking case "text": - responseText = *message.Text + responseText = message.GetText() } } } diff --git a/relay/channel/dify/adaptor.go b/relay/channel/dify/adaptor.go index 54d10c97..dddcb994 100644 --- a/relay/channel/dify/adaptor.go +++ b/relay/channel/dify/adaptor.go @@ -74,7 +74,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn if request == nil { return nil, errors.New("request is nil") } - return requestOpenAI2Dify(*request), nil + return requestOpenAI2Dify(c, info, *request), nil } func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { diff --git a/relay/channel/dify/dto.go b/relay/channel/dify/dto.go index bc641a22..7c6f39b6 100644 --- a/relay/channel/dify/dto.go +++ b/relay/channel/dify/dto.go @@ -8,6 +8,14 @@ type DifyChatRequest struct { ResponseMode string `json:"response_mode"` User string `json:"user"` AutoGenerateName bool `json:"auto_generate_name"` + Files []DifyFile `json:"files"` +} + +type DifyFile struct { + Type string `json:"type"` + TransferMode string `json:"transfer_mode"` + URL string `json:"url,omitempty"` + UploadFileId string `json:"upload_file_id,omitempty"` } type DifyMetaData struct { @@ -17,6 +25,8 @@ type DifyMetaData struct { type DifyData struct { WorkflowId string `json:"workflow_id"` NodeId string `json:"node_id"` + NodeType string `json:"node_type"` + Status string `json:"status"` } type DifyChatCompletionResponse struct { diff --git a/relay/channel/dify/relay-dify.go b/relay/channel/dify/relay-dify.go index 3e62d41c..a0fd1f07 100644 --- a/relay/channel/dify/relay-dify.go +++ b/relay/channel/dify/relay-dify.go @@ -2,9 +2,12 @@ package dify import ( "bufio" + "bytes" + "encoding/base64" "encoding/json" - "github.com/gin-gonic/gin" + "fmt" "io" + "mime/multipart" "net/http" "one-api/common" "one-api/constant" @@ -12,35 +15,163 @@ import ( relaycommon "one-api/relay/common" "one-api/relay/helper" "one-api/service" + "os" "strings" + + "github.com/gin-gonic/gin" ) -func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest { - content := "" - for _, message := range request.Messages { - if message.Role == "system" { - content += "SYSTEM: \n" + message.StringContent() + "\n" - } else if message.Role == "assistant" { - content += "ASSISTANT: \n" + message.StringContent() + "\n" - } else { - content += "USER: \n" + message.StringContent() + "\n" +func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile { + uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl) + switch media.Type { + case dto.ContentTypeImageURL: + // Decode base64 data + imageMedia := media.GetImageMedia() + base64Data := imageMedia.Url + // Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,") + if idx := strings.Index(base64Data, ","); idx != -1 { + base64Data = base64Data[idx+1:] + } + + // Decode base64 string + decodedData, err := base64.StdEncoding.DecodeString(base64Data) + if err != nil { + common.SysError("failed to decode base64: " + err.Error()) + return nil + } + + // Create temporary file + tempFile, err := os.CreateTemp("", "dify-upload-*") + if err != nil { + common.SysError("failed to create temp file: " + err.Error()) + return nil + } + defer tempFile.Close() + defer os.Remove(tempFile.Name()) + + // Write decoded data to temp file + if _, err := tempFile.Write(decodedData); err != nil { + common.SysError("failed to write to temp file: " + err.Error()) + return nil + } + + // Create multipart form + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + // Add user field + if err := writer.WriteField("user", user); err != nil { + common.SysError("failed to add user field: " + err.Error()) + return nil + } + + // Create form file with proper mime type + mimeType := imageMedia.MimeType + if mimeType == "" { + mimeType = "image/jpeg" // default mime type + } + + // Create form file + part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/"))) + if err != nil { + common.SysError("failed to create form file: " + err.Error()) + return nil + } + + // Copy file content to form + if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil { + common.SysError("failed to copy file content: " + err.Error()) + return nil + } + writer.Close() + + // Create HTTP request + req, err := http.NewRequest("POST", uploadUrl, body) + if err != nil { + common.SysError("failed to create request: " + err.Error()) + return nil + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey)) + + // Send request + client := service.GetImpatientHttpClient() + resp, err := client.Do(req) + if err != nil { + common.SysError("failed to send request: " + err.Error()) + return nil + } + defer resp.Body.Close() + + // Parse response + var result struct { + Id string `json:"id"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + common.SysError("failed to decode response: " + err.Error()) + return nil + } + + return &DifyFile{ + UploadFileId: result.Id, + Type: "image", + TransferMode: "local_file", } } + return nil +} + +func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest { + difyReq := DifyChatRequest{ + Inputs: make(map[string]interface{}), + AutoGenerateName: false, + } + + user := request.User + if user == "" { + user = helper.GetResponseID(c) + } + difyReq.User = user + + files := make([]DifyFile, 0) + var content strings.Builder + for _, message := range request.Messages { + if message.Role == "system" { + content.WriteString("SYSTEM: \n" + message.StringContent() + "\n") + } else if message.Role == "assistant" { + content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n") + } else { + parseContent := message.ParseContent() + for _, mediaContent := range parseContent { + switch mediaContent.Type { + case dto.ContentTypeText: + content.WriteString("USER: \n" + mediaContent.Text + "\n") + case dto.ContentTypeImageURL: + media := mediaContent.GetImageMedia() + var file *DifyFile + if media.IsRemoteImage() { + file.Type = media.MimeType + file.TransferMode = "remote_url" + file.URL = media.Url + } else { + file = uploadDifyFile(c, info, difyReq.User, mediaContent) + } + if file != nil { + files = append(files, *file) + } + } + } + } + } + difyReq.Query = content.String() + difyReq.Files = files mode := "blocking" if request.Stream { mode = "streaming" } - user := request.User - if user == "" { - user = "api-user" - } - return &DifyChatRequest{ - Inputs: make(map[string]interface{}), - Query: content, - ResponseMode: mode, - User: user, - AutoGenerateName: false, - } + difyReq.ResponseMode = mode + return &difyReq } func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse { @@ -50,10 +181,22 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt Model: "dify", } var choice dto.ChatCompletionsStreamResponseChoice - if constant.DifyDebug && difyResponse.Event == "workflow_started" { - choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n") - } else if constant.DifyDebug && difyResponse.Event == "node_started" { - choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n") + if strings.HasPrefix(difyResponse.Event, "workflow_") { + if constant.DifyDebug { + text := "Workflow: " + difyResponse.Data.WorkflowId + if difyResponse.Event == "workflow_finished" { + text += " " + difyResponse.Data.Status + } + choice.Delta.SetReasoningContent(text + "\n") + } + } else if strings.HasPrefix(difyResponse.Event, "node_") { + if constant.DifyDebug { + text := "Node: " + difyResponse.Data.NodeType + if difyResponse.Event == "node_finished" { + text += " " + difyResponse.Data.Status + } + choice.Delta.SetReasoningContent(text + "\n") + } } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" { choice.Delta.SetContentString(difyResponse.Answer) } @@ -66,38 +209,38 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re usage := &dto.Usage{} scanner := bufio.NewScanner(resp.Body) scanner.Split(bufio.ScanLines) + var nodeToken int helper.SetEventStreamHeaders(c) - for scanner.Scan() { - data := scanner.Text() - if len(data) < 5 || !strings.HasPrefix(data, "data:") { - continue - } - data = strings.TrimPrefix(data, "data:") + helper.StreamScannerHandler(c, resp, info, func(data string) bool { var difyResponse DifyChunkChatCompletionResponse err := json.Unmarshal([]byte(data), &difyResponse) if err != nil { common.SysError("error unmarshalling stream response: " + err.Error()) - continue + return true } var openaiResponse dto.ChatCompletionsStreamResponse if difyResponse.Event == "message_end" { usage = &difyResponse.MetaData.Usage - break + return false } else if difyResponse.Event == "error" { - break + return false } else { openaiResponse = *streamResponseDify2OpenAI(difyResponse) if len(openaiResponse.Choices) != 0 { responseText += openaiResponse.Choices[0].Delta.GetContentString() + if openaiResponse.Choices[0].Delta.ReasoningContent != nil { + nodeToken += 1 + } } } err = helper.ObjectData(c, openaiResponse) if err != nil { common.SysError(err.Error()) } - } + return true + }) if err := scanner.Err(); err != nil { common.SysError("error reading stream: " + err.Error()) } @@ -112,6 +255,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens } + usage.CompletionTokens += nodeToken return nil, usage } diff --git a/service/token_counter.go b/service/token_counter.go index 98386f85..99723c97 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -86,6 +86,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { } func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { + if imageUrl == nil { + return 0, fmt.Errorf("image_url_is_nil") + } baseTokens := 85 if model == "glm-4v" { return 1047, nil @@ -93,10 +96,10 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m if imageUrl.Detail == "low" { return baseTokens, nil } - // TODO: 非流模式下不计算图片token数量 if !constant.GetMediaTokenNotStream && !stream { - return 256, nil + return 3 * baseTokens, nil } + // 同步One API的图片计费逻辑 if imageUrl.Detail == "auto" || imageUrl.Detail == "" { imageUrl.Detail = "high" @@ -126,18 +129,11 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m if err != nil { return 0, err } + imageUrl.MimeType = format if config.Width == 0 || config.Height == 0 { return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) } - //// TODO: 适配官方auto计费 - //if config.Width < 512 && config.Height < 512 { - // if imageUrl.Detail == "auto" || imageUrl.Detail == "" { - // // 如果图片尺寸小于512,强制使用low - // imageUrl.Detail = "low" - // return 85, nil - // } - //} shortSide := config.Width otherSide := config.Height @@ -392,8 +388,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod arrayContent := message.ParseContent() for _, m := range arrayContent { if m.Type == dto.ContentTypeImageURL { - imageUrl := m.ImageUrl.(dto.MessageImageUrl) - imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) + imageUrl := m.GetImageMedia() + imageTokenNum, err := getImageToken(info, imageUrl, model, stream) if err != nil { return 0, err }