feat: support uploading image files to Dify

This commit is contained in:
1808837298@qq.com
2025-03-15 19:10:12 +08:00
parent 19bfa158cc
commit dd393cd0d9
7 changed files with 239 additions and 72 deletions

9
controller/image.go Normal file
View File

@@ -0,0 +1,9 @@
package controller
import (
"github.com/gin-gonic/gin"
)
// GetImage is a gin handler stub. Its body is empty, so it performs no work
// and writes nothing to the response on its own.
func GetImage(c *gin.Context) {
}

View File

@@ -113,9 +113,21 @@ type MediaContent struct {
InputAudio any `json:"input_audio,omitempty"` InputAudio any `json:"input_audio,omitempty"`
} }
// GetImageMedia returns the image payload stored in m.ImageUrl, or nil when
// the receiver is nil, the field is unset, or the field holds a value that is
// not an image URL.
//
// ImageUrl is an untyped (`any`) field, so its dynamic type is inspected
// rather than asserted blindly; an unchecked assertion panics on any other
// type. A *MessageImageUrl is returned as-is. A MessageImageUrl stored by
// value is copied and the copy's address is returned, so writes through the
// result do not reach the value held in m in that case.
func (m *MediaContent) GetImageMedia() *MessageImageUrl {
	if m == nil {
		return nil
	}
	switch v := m.ImageUrl.(type) {
	case *MessageImageUrl:
		return v
	case MessageImageUrl:
		return &v
	default:
		return nil
	}
}
type MessageImageUrl struct { type MessageImageUrl struct {
Url string `json:"url"` Url string `json:"url"`
Detail string `json:"detail"` Detail string `json:"detail"`
MimeType string
}
// IsRemoteImage reports whether m.Url starts with "http". The check is a plain
// prefix match, so both "http://" and "https://" URLs qualify, as does any
// other string that merely begins with "http".
func (m *MessageImageUrl) IsRemoteImage() bool {
return strings.HasPrefix(m.Url, "http")
} }
type MessageInputAudio struct { type MessageInputAudio struct {
@@ -244,43 +256,39 @@ func (m *Message) ParseContent() []MediaContent {
case ContentTypeImageURL: case ContentTypeImageURL:
imageUrl := contentItem["image_url"] imageUrl := contentItem["image_url"]
temp := &MessageImageUrl{
Detail: "high",
}
switch v := imageUrl.(type) { switch v := imageUrl.(type) {
case string: case string:
contentList = append(contentList, MediaContent{ temp.Url = v
Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{
Url: v,
Detail: "high",
},
})
case map[string]interface{}: case map[string]interface{}:
url, ok1 := v["url"].(string) url, ok1 := v["url"].(string)
detail, ok2 := v["detail"].(string) detail, ok2 := v["detail"].(string)
if !ok2 { if ok2 {
detail = "high" temp.Detail = detail
} }
if ok1 { if ok1 {
temp.Url = url
}
}
contentList = append(contentList, MediaContent{ contentList = append(contentList, MediaContent{
Type: ContentTypeImageURL, Type: ContentTypeImageURL,
ImageUrl: MessageImageUrl{ ImageUrl: temp,
Url: url,
Detail: detail,
},
}) })
}
}
case ContentTypeInputAudio: case ContentTypeInputAudio:
if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok { if audioData, ok := contentItem["input_audio"].(map[string]interface{}); ok {
data, ok1 := audioData["data"].(string) data, ok1 := audioData["data"].(string)
format, ok2 := audioData["format"].(string) format, ok2 := audioData["format"].(string)
if ok1 && ok2 { if ok1 && ok2 {
contentList = append(contentList, MediaContent{ temp := &MessageInputAudio{
Type: ContentTypeInputAudio,
InputAudio: MessageInputAudio{
Data: data, Data: data,
Format: format, Format: format,
}, }
contentList = append(contentList, MediaContent{
Type: ContentTypeInputAudio,
InputAudio: temp,
}) })
} }
} }

View File

@@ -414,7 +414,7 @@ func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto
// 加密的不管, 只输出明文的推理过程 // 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking thinkingContent = message.Thinking
case "text": case "text":
responseText = *message.Text responseText = message.GetText()
} }
} }
} }

View File

@@ -74,7 +74,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
if request == nil { if request == nil {
return nil, errors.New("request is nil") return nil, errors.New("request is nil")
} }
return requestOpenAI2Dify(*request), nil return requestOpenAI2Dify(c, info, *request), nil
} }
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {

View File

@@ -8,6 +8,14 @@ type DifyChatRequest struct {
ResponseMode string `json:"response_mode"` ResponseMode string `json:"response_mode"`
User string `json:"user"` User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"` AutoGenerateName bool `json:"auto_generate_name"`
Files []DifyFile `json:"files"`
}
// DifyFile describes one entry of DifyChatRequest.Files.
// URL and UploadFileId are tagged omitempty, so whichever of the two is left
// empty is dropped from the encoded JSON.
type DifyFile struct {
// Type is the file category sent to Dify (the upload path in this package uses "image").
Type string `json:"type"`
// TransferMode tells how the file is referenced; this package uses
// "local_file" together with UploadFileId and "remote_url" together with URL.
TransferMode string `json:"transfer_mode"`
// URL is the remote location of the file.
URL string `json:"url,omitempty"`
// UploadFileId is the id taken from the file-upload response.
UploadFileId string `json:"upload_file_id,omitempty"`
} }
type DifyMetaData struct { type DifyMetaData struct {
@@ -17,6 +25,8 @@ type DifyMetaData struct {
type DifyData struct { type DifyData struct {
WorkflowId string `json:"workflow_id"` WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"` NodeId string `json:"node_id"`
NodeType string `json:"node_type"`
Status string `json:"status"`
} }
type DifyChatCompletionResponse struct { type DifyChatCompletionResponse struct {

View File

@@ -2,9 +2,12 @@ package dify
import ( import (
"bufio" "bufio"
"bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin" "fmt"
"io" "io"
"mime/multipart"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/constant"
@@ -12,35 +15,163 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/relay/helper" "one-api/relay/helper"
"one-api/service" "one-api/service"
"os"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
func requestOpenAI2Dify(request dto.GeneralOpenAIRequest) *DifyChatRequest { func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
content := "" uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
switch media.Type {
case dto.ContentTypeImageURL:
// Decode base64 data
imageMedia := media.GetImageMedia()
base64Data := imageMedia.Url
// Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
if idx := strings.Index(base64Data, ","); idx != -1 {
base64Data = base64Data[idx+1:]
}
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
common.SysError("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
common.SysError("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
defer os.Remove(tempFile.Name())
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
common.SysError("failed to write to temp file: " + err.Error())
return nil
}
// Create multipart form
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// Add user field
if err := writer.WriteField("user", user); err != nil {
common.SysError("failed to add user field: " + err.Error())
return nil
}
// Create form file with proper mime type
mimeType := imageMedia.MimeType
if mimeType == "" {
mimeType = "image/jpeg" // default mime type
}
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
common.SysError("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
common.SysError("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
common.SysError("failed to create request: " + err.Error())
return nil
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request
client := service.GetImpatientHttpClient()
resp, err := client.Do(req)
if err != nil {
common.SysError("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
// Parse response
var result struct {
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
common.SysError("failed to decode response: " + err.Error())
return nil
}
return &DifyFile{
UploadFileId: result.Id,
Type: "image",
TransferMode: "local_file",
}
}
return nil
}
func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
difyReq := DifyChatRequest{
Inputs: make(map[string]interface{}),
AutoGenerateName: false,
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
difyReq.User = user
files := make([]DifyFile, 0)
var content strings.Builder
for _, message := range request.Messages { for _, message := range request.Messages {
if message.Role == "system" { if message.Role == "system" {
content += "SYSTEM: \n" + message.StringContent() + "\n" content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
} else if message.Role == "assistant" { } else if message.Role == "assistant" {
content += "ASSISTANT: \n" + message.StringContent() + "\n" content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
} else { } else {
content += "USER: \n" + message.StringContent() + "\n" parseContent := message.ParseContent()
for _, mediaContent := range parseContent {
switch mediaContent.Type {
case dto.ContentTypeText:
content.WriteString("USER: \n" + mediaContent.Text + "\n")
case dto.ContentTypeImageURL:
// Resolve the image into a Dify file reference. file stays nil when there is
// nothing usable to attach, and the nil check that follows then skips it.
media := mediaContent.GetImageMedia()
var file *DifyFile
switch {
case media == nil:
	// No image payload on this content part.
case media.IsRemoteImage():
	// The DifyFile has to be allocated before its fields are set: assigning
	// through the nil *DifyFile declared above would panic.
	file = &DifyFile{
		// "image" matches the Type uploadDifyFile reports for uploaded
		// images. ParseContent, as visible in this file, never fills
		// MimeType, so it cannot serve as the type here.
		// NOTE(review): confirm "image" against the Dify files[].type contract.
		Type:         "image",
		TransferMode: "remote_url",
		URL:          media.Url,
	}
default:
	// A URL that does not start with "http" is treated as inline data and
	// uploaded to Dify first; uploadDifyFile returns nil on failure.
	file = uploadDifyFile(c, info, difyReq.User, mediaContent)
}
if file != nil {
files = append(files, *file)
} }
} }
}
}
}
difyReq.Query = content.String()
difyReq.Files = files
mode := "blocking" mode := "blocking"
if request.Stream { if request.Stream {
mode = "streaming" mode = "streaming"
} }
user := request.User difyReq.ResponseMode = mode
if user == "" { return &difyReq
user = "api-user"
}
return &DifyChatRequest{
Inputs: make(map[string]interface{}),
Query: content,
ResponseMode: mode,
User: user,
AutoGenerateName: false,
}
} }
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse { func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
@@ -50,10 +181,22 @@ func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dt
Model: "dify", Model: "dify",
} }
var choice dto.ChatCompletionsStreamResponseChoice var choice dto.ChatCompletionsStreamResponseChoice
if constant.DifyDebug && difyResponse.Event == "workflow_started" { if strings.HasPrefix(difyResponse.Event, "workflow_") {
choice.Delta.SetContentString("Workflow: " + difyResponse.Data.WorkflowId + "\n") if constant.DifyDebug {
} else if constant.DifyDebug && difyResponse.Event == "node_started" { text := "Workflow: " + difyResponse.Data.WorkflowId
choice.Delta.SetContentString("Node: " + difyResponse.Data.NodeId + "\n") if difyResponse.Event == "workflow_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if strings.HasPrefix(difyResponse.Event, "node_") {
if constant.DifyDebug {
text := "Node: " + difyResponse.Data.NodeType
if difyResponse.Event == "node_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" { } else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
choice.Delta.SetContentString(difyResponse.Answer) choice.Delta.SetContentString(difyResponse.Answer)
} }
@@ -66,38 +209,38 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage := &dto.Usage{} usage := &dto.Usage{}
scanner := bufio.NewScanner(resp.Body) scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines) scanner.Split(bufio.ScanLines)
var nodeToken int
helper.SetEventStreamHeaders(c) helper.SetEventStreamHeaders(c)
for scanner.Scan() { helper.StreamScannerHandler(c, resp, info, func(data string) bool {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var difyResponse DifyChunkChatCompletionResponse var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse) err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil { if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error()) common.SysError("error unmarshalling stream response: " + err.Error())
continue return true
} }
var openaiResponse dto.ChatCompletionsStreamResponse var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" { if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage usage = &difyResponse.MetaData.Usage
break return false
} else if difyResponse.Event == "error" { } else if difyResponse.Event == "error" {
break return false
} else { } else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse) openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 { if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString() responseText += openaiResponse.Choices[0].Delta.GetContentString()
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
nodeToken += 1
}
} }
} }
err = helper.ObjectData(c, openaiResponse) err = helper.ObjectData(c, openaiResponse)
if err != nil { if err != nil {
common.SysError(err.Error()) common.SysError(err.Error())
} }
} return true
})
if err := scanner.Err(); err != nil { if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error()) common.SysError("error reading stream: " + err.Error())
} }
@@ -112,6 +255,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText) usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
} }
usage.CompletionTokens += nodeToken
return nil, usage return nil, usage
} }

View File

@@ -86,6 +86,9 @@ func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int {
} }
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) { func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
if imageUrl == nil {
return 0, fmt.Errorf("image_url_is_nil")
}
baseTokens := 85 baseTokens := 85
if model == "glm-4v" { if model == "glm-4v" {
return 1047, nil return 1047, nil
@@ -93,10 +96,10 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
if imageUrl.Detail == "low" { if imageUrl.Detail == "low" {
return baseTokens, nil return baseTokens, nil
} }
// TODO: 非流模式下不计算图片token数量
if !constant.GetMediaTokenNotStream && !stream { if !constant.GetMediaTokenNotStream && !stream {
return 256, nil return 3 * baseTokens, nil
} }
// 同步One API的图片计费逻辑 // 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" { if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high" imageUrl.Detail = "high"
@@ -126,18 +129,11 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
if err != nil { if err != nil {
return 0, err return 0, err
} }
imageUrl.MimeType = format
if config.Width == 0 || config.Height == 0 { if config.Width == 0 || config.Height == 0 {
return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url)) return 0, errors.New(fmt.Sprintf("fail to decode image config: %s", imageUrl.Url))
} }
//// TODO: 适配官方auto计费
//if config.Width < 512 && config.Height < 512 {
// if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
// // 如果图片尺寸小于512强制使用low
// imageUrl.Detail = "low"
// return 85, nil
// }
//}
shortSide := config.Width shortSide := config.Width
otherSide := config.Height otherSide := config.Height
@@ -392,8 +388,8 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
arrayContent := message.ParseContent() arrayContent := message.ParseContent()
for _, m := range arrayContent { for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL { if m.Type == dto.ContentTypeImageURL {
imageUrl := m.ImageUrl.(dto.MessageImageUrl) imageUrl := m.GetImageMedia()
imageTokenNum, err := getImageToken(info, &imageUrl, model, stream) imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
if err != nil { if err != nil {
return 0, err return 0, err
} }