This commit refactors the logging mechanism across the application, replacing direct logger calls with a centralized logging approach based on the `common` package. Key changes:

- Replaced `logger.SysLog` and `logger.FatalLog` calls with `common.SysLog` and `common.FatalLog` for consistent logging.
- Updated resource-initialization error handling to use the new logging structure, improving maintainability and readability.
- Made minor adjustments to improve code clarity and organization across several modules.

This change streamlines logging and improves the overall architecture of the codebase.
250 lines
7.7 KiB
Go
250 lines
7.7 KiB
Go
package volcengine
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"mime/multipart"
|
|
"net/http"
|
|
"net/textproto"
|
|
"one-api/dto"
|
|
"one-api/relay/channel"
|
|
"one-api/relay/channel/openai"
|
|
relaycommon "one-api/relay/common"
|
|
"one-api/relay/constant"
|
|
"one-api/types"
|
|
"path/filepath"
|
|
"strings"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// Adaptor is the relay adaptor for the Volcengine channel. It holds no
// state; Claude request conversion and response handling are delegated to
// the OpenAI adaptor.
type Adaptor struct{}
|
|
|
|
// ConvertGeminiRequest is not supported by this adaptor; it always fails
// with a "not implemented" error and ignores its arguments.
func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) {
	// TODO: implement Gemini-format request conversion for Volcengine.
	notImplemented := errors.New("not implemented")
	return nil, notImplemented
}
|
|
|
|
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, req *dto.ClaudeRequest) (any, error) {
|
|
adaptor := openai.Adaptor{}
|
|
return adaptor.ConvertClaudeRequest(c, info, req)
|
|
}
|
|
|
|
// ConvertAudioRequest is not supported by this adaptor; it always returns a
// nil reader and a "not implemented" error.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
	// TODO: implement audio request conversion for Volcengine.
	notImplemented := errors.New("not implemented")
	return nil, notImplemented
}
|
|
|
|
// ConvertImageRequest prepares an image request for the Volcengine upstream.
//
// For constant.RelayModeImagesEdits the inbound multipart form is re-encoded
// into a new multipart body (see buildImageEditBody), which is returned as an
// io.Reader. As a side effect, the inbound request's Content-Type header is
// overwritten with the new body's boundary. Every other relay mode returns
// request unchanged.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
	switch info.RelayMode {
	case constant.RelayModeImagesEdits:
		body, err := buildImageEditBody(c, request)
		if err != nil {
			// Return a literal nil so callers never see a typed-nil value.
			return nil, err
		}
		return body, nil
	default:
		return request, nil
	}
}

// buildImageEditBody rebuilds the image-edit multipart body:
//   - "model" is taken from request.Model;
//   - every other text field is copied verbatim from the inbound form;
//   - the image file(s) are re-attached as "image" (one file) or "image[]"
//     (several files);
//   - the first "mask" file, if any, is forwarded.
//
// On success the inbound request's Content-Type is set to the new
// multipart boundary.
func buildImageEditBody(c *gin.Context, request dto.ImageRequest) (io.Reader, error) {
	// Parse before reading PostForm: ParseMultipartForm is what copies the
	// multipart text fields into PostForm, and it returns nil without
	// re-reading the body if the form was already parsed.
	if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // up to 32 MiB of file data kept in memory
		return nil, fmt.Errorf("failed to parse multipart form: %w", err)
	}
	form := c.Request.MultipartForm
	if form == nil || form.File == nil {
		return nil, errors.New("no multipart form data found")
	}

	var body bytes.Buffer
	writer := multipart.NewWriter(&body)

	if err := writer.WriteField("model", request.Model); err != nil {
		return nil, fmt.Errorf("write form field %q: %w", "model", err)
	}
	// Forward every other text field verbatim; "model" was written above.
	for key, values := range c.Request.PostForm {
		if key == "model" {
			continue
		}
		for _, value := range values {
			if err := writer.WriteField(key, value); err != nil {
				return nil, fmt.Errorf("write form field %q: %w", key, err)
			}
		}
	}

	imageFiles, err := collectEditImageFiles(form)
	if err != nil {
		return nil, err
	}
	// A single upload is sent as "image"; several are sent as "image[]".
	fieldName := "image"
	if len(imageFiles) > 1 {
		fieldName = "image[]"
	}
	for i, fh := range imageFiles {
		if err := writeImageFilePart(writer, fieldName, fh); err != nil {
			return nil, fmt.Errorf("image %d: %w", i, err)
		}
	}

	if masks := form.File["mask"]; len(masks) > 0 {
		if err := writeImageFilePart(writer, "mask", masks[0]); err != nil {
			return nil, fmt.Errorf("mask: %w", err)
		}
	}

	// Close writes the terminating boundary; without it the body is incomplete.
	if err := writer.Close(); err != nil {
		return nil, fmt.Errorf("close multipart writer: %w", err)
	}
	c.Request.Header.Set("Content-Type", writer.FormDataContentType())
	return bytes.NewReader(body.Bytes()), nil
}

// collectEditImageFiles returns the uploaded image files. It checks the
// "image" field first, then "image[]", then every non-empty field whose name
// starts with "image[" (e.g. image[0], image[1]) in a deterministic order.
// It returns an "image is required" error if none of these holds a file.
func collectEditImageFiles(form *multipart.Form) ([]*multipart.FileHeader, error) {
	if files := form.File["image"]; len(files) > 0 {
		return files, nil
	}
	if files := form.File["image[]"]; len(files) > 0 {
		return files, nil
	}

	var names []string
	for name, files := range form.File {
		if strings.HasPrefix(name, "image[") && len(files) > 0 {
			names = append(names, name)
		}
	}
	if len(names) == 0 {
		return nil, errors.New("image is required")
	}

	// Map iteration order is random, so sort the (few) field names so that
	// images keep a stable order. Sorting shorter names first places image[2]
	// before image[10]. An in-place insertion sort is enough at this size.
	less := func(x, y string) bool {
		if len(x) != len(y) {
			return len(x) < len(y)
		}
		return x < y
	}
	for i := 1; i < len(names); i++ {
		for j := i; j > 0 && less(names[j], names[j-1]); j-- {
			names[j], names[j-1] = names[j-1], names[j]
		}
	}

	var imageFiles []*multipart.FileHeader
	for _, name := range names {
		imageFiles = append(imageFiles, form.File[name]...)
	}
	return imageFiles, nil
}

// writeImageFilePart copies one uploaded file into writer under fieldName.
// The part's Content-Type is guessed from the file extension. The source
// file is closed before the function returns.
func writeImageFilePart(writer *multipart.Writer, fieldName string, fh *multipart.FileHeader) error {
	src, err := fh.Open()
	if err != nil {
		return fmt.Errorf("open file: %w", err)
	}
	defer src.Close()

	// Escape backslashes and quotes the same way multipart.Writer's
	// CreateFormFile does, so a client-supplied filename cannot break out of
	// the quoted Content-Disposition parameter.
	escape := strings.NewReplacer(`\`, `\\`, `"`, `\"`)
	h := make(textproto.MIMEHeader)
	h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`,
		escape.Replace(fieldName), escape.Replace(fh.Filename)))
	h.Set("Content-Type", detectImageMimeType(fh.Filename))

	part, err := writer.CreatePart(h)
	if err != nil {
		return fmt.Errorf("create form part: %w", err)
	}
	if _, err := io.Copy(part, src); err != nil {
		return fmt.Errorf("copy file: %w", err)
	}
	return nil
}
|
|
|
|
// detectImageMimeType guesses an image MIME type from the file name's
// extension, compared case-insensitively.
//   - Any extension beginning with ".jp" (".jpg", ".jpeg", ".jpe", ...) maps
//     to "image/jpeg".
//   - ".webp" maps to "image/webp".
//   - Everything else, including ".png" and a missing extension, falls back
//     to "image/png".
func detectImageMimeType(filename string) string {
	lowerExt := strings.ToLower(filepath.Ext(filename))
	switch {
	case strings.HasPrefix(lowerExt, ".jp"):
		return "image/jpeg"
	case lowerExt == ".webp":
		return "image/webp"
	default:
		return "image/png"
	}
}
|
|
|
|
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
|
}
|
|
|
|
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
|
switch info.RelayMode {
|
|
case constant.RelayModeChatCompletions:
|
|
if strings.HasPrefix(info.UpstreamModelName, "bot") {
|
|
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.ChannelBaseUrl), nil
|
|
}
|
|
return fmt.Sprintf("%s/api/v3/chat/completions", info.ChannelBaseUrl), nil
|
|
case constant.RelayModeEmbeddings:
|
|
return fmt.Sprintf("%s/api/v3/embeddings", info.ChannelBaseUrl), nil
|
|
case constant.RelayModeImagesGenerations:
|
|
return fmt.Sprintf("%s/api/v3/images/generations", info.ChannelBaseUrl), nil
|
|
case constant.RelayModeImagesEdits:
|
|
return fmt.Sprintf("%s/api/v3/images/edits", info.ChannelBaseUrl), nil
|
|
case constant.RelayModeRerank:
|
|
return fmt.Sprintf("%s/api/v3/rerank", info.ChannelBaseUrl), nil
|
|
default:
|
|
}
|
|
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
|
|
}
|
|
|
|
// SetupRequestHeader applies the shared channel headers and then sets a
// bearer Authorization header from info.ApiKey. It always returns nil.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
	channel.SetupApiRequestHeader(info, c, req)
	// Header.Set replaces any existing value, so this runs after the shared
	// setup to guarantee the channel key is the one sent.
	bearerToken := "Bearer " + info.ApiKey
	req.Set("Authorization", bearerToken)
	return nil
}
|
|
|
|
// ConvertOpenAIRequest forwards the OpenAI-style request unchanged. A nil
// request is rejected with a "request is nil" error.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
	if request != nil {
		return request, nil
	}
	return nil, errors.New("request is nil")
}
|
|
|
|
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
|
return nil, nil
|
|
}
|
|
|
|
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
|
return request, nil
|
|
}
|
|
|
|
// ConvertOpenAIResponsesRequest is not supported by this adaptor; it always
// fails with a "not implemented" error.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
	// TODO: implement OpenAI Responses API conversion for Volcengine.
	notImplemented := errors.New("not implemented")
	return nil, notImplemented
}
|
|
|
|
// DoRequest sends requestBody upstream via the shared channel helper, which
// receives this adaptor so it can resolve the URL and headers.
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
	resp, err := channel.DoApiRequest(a, c, info, requestBody)
	return resp, err
}
|
|
|
|
// DoResponse delegates parsing of the upstream response, including usage
// extraction, to a fresh OpenAI adaptor and returns its results unchanged.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
	return (&openai.Adaptor{}).DoResponse(c, resp, info)
}
|
|
|
|
func (a *Adaptor) GetModelList() []string {
|
|
return ModelList
|
|
}
|
|
|
|
func (a *Adaptor) GetChannelName() string {
|
|
return ChannelName
|
|
}
|