添加完整项目文件

包含Go API项目的所有源代码、配置文件、Docker配置、文档和前端资源

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
nosqli
2025-08-21 23:33:10 +08:00
parent 69af9723d9
commit fb44c8bf03
513 changed files with 90387 additions and 0 deletions

134
relay/audio_handler.go Normal file
View File

@@ -0,0 +1,134 @@
package relay
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
audioRequest := &dto.AudioRequest{}
err := common.UnmarshalBodyReusable(c, audioRequest)
if err != nil {
return nil, err
}
switch info.RelayMode {
case relayconstant.RelayModeAudioSpeech:
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
if setting.ShouldCheckPromptSensitive() {
words, err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(words, ",")))
return nil, err
}
}
default:
err = c.Request.ParseForm()
if err != nil {
return nil, err
}
formData := c.Request.PostForm
if audioRequest.Model == "" {
audioRequest.Model = formData.Get("model")
}
if audioRequest.Model == "" {
return nil, errors.New("model is required")
}
audioRequest.ResponseFormat = formData.Get("response_format")
if audioRequest.ResponseFormat == "" {
audioRequest.ResponseFormat = "json"
}
}
return audioRequest, nil
}
func AudioHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
relayInfo := relaycommon.GenRelayInfoOpenAIAudio(c)
audioRequest, err := getAndValidAudioRequest(c, relayInfo)
if err != nil {
common.LogError(c, fmt.Sprintf("getAndValidAudioRequest failed: %s", err.Error()))
return types.NewError(err, types.ErrorCodeInvalidRequest)
}
promptTokens := 0
preConsumedTokens := common.PreConsumedQuota
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
preConsumedTokens = promptTokens
relayInfo.PromptTokens = promptTokens
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, preConsumedTokens, 0)
if err != nil {
return types.NewError(err, types.ErrorCodeModelPriceError)
}
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
if openaiErr != nil {
return openaiErr
}
defer func() {
if openaiErr != nil {
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
}
}()
err = helper.ModelMappedHelper(c, relayInfo, audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelModelMappedError)
}
adaptor := GetAdaptor(relayInfo.ApiType)
if adaptor == nil {
return types.NewError(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), types.ErrorCodeInvalidApiType)
}
adaptor.Init(relayInfo)
ioReader, err := adaptor.ConvertAudioRequest(c, relayInfo, *audioRequest)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}
resp, err := adaptor.DoRequest(c, relayInfo, ioReader)
if err != nil {
return types.NewError(err, types.ErrorCodeDoRequestFailed)
}
statusCodeMappingStr := c.GetString("status_code_mapping")
var httpResp *http.Response
if resp != nil {
httpResp = resp.(*http.Response)
if httpResp.StatusCode != http.StatusOK {
newAPIError = service.RelayErrorHandler(httpResp, false)
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
}
usage, newAPIError := adaptor.DoResponse(c, httpResp, relayInfo)
if newAPIError != nil {
// reset status code 重置状态码
service.ResetStatusCode(newAPIError, statusCodeMappingStr)
return newAPIError
}
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
return nil
}

50
relay/channel/adapter.go Normal file
View File

@@ -0,0 +1,50 @@
package channel
import (
"io"
"net/http"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor interface {
// Init IsStream bool
Init(info *relaycommon.RelayInfo)
GetRequestURL(info *relaycommon.RelayInfo) (string, error)
SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error
ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error)
ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error)
ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error)
ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error)
ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error)
ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error)
DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError)
GetModelList() []string
GetChannelName() string
ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error)
}
type TaskAdaptor interface {
Init(info *relaycommon.TaskRelayInfo)
ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) *dto.TaskError
BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error)
BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error
BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error)
DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error)
DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, err *dto.TaskError)
GetModelList() []string
GetChannelName() string
// FetchTask
FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error)
ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error)
}

View File

@@ -0,0 +1,14 @@
package ai360
var ModelList = []string{
"360gpt-turbo",
"360gpt-turbo-responsibility-8k",
"360gpt-pro",
"360gpt2-pro",
"360GPT_S2_V9",
"embedding-bert-512-v1",
"embedding_s1_v1",
"semantic_similarity_s1_v1",
}
var ChannelName = "ai360"

View File

@@ -0,0 +1,127 @@
package ali
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
var fullRequestURL string
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/embeddings", info.BaseUrl)
case constant.RelayModeRerank:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/rerank/text-rerank/text-rerank", info.BaseUrl)
case constant.RelayModeImagesGenerations:
fullRequestURL = fmt.Sprintf("%s/api/v1/services/aigc/text2image/image-synthesis", info.BaseUrl)
case constant.RelayModeCompletions:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/completions", info.BaseUrl)
default:
fullRequestURL = fmt.Sprintf("%s/compatible-mode/v1/chat/completions", info.BaseUrl)
}
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
if info.IsStream {
req.Set("X-DashScope-SSE", "enable")
}
if c.GetString("plugin") != "" {
req.Set("X-DashScope-Plugin", c.GetString("plugin"))
}
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
// fix: ali parameter.enable_thinking must be set to false for non-streaming calls
if !info.IsStream {
request.EnableThinking = false
}
switch info.RelayMode {
default:
aliReq := requestOpenAI2Ali(*request)
return aliReq, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
aliRequest := oaiImage2Ali(request)
return aliRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return ConvertRerankRequest(request), nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeImagesGenerations:
err, usage = aliImageHandler(c, resp, info)
case constant.RelayModeEmbeddings:
err, usage = aliEmbeddingHandler(c, resp)
case constant.RelayModeRerank:
err, usage = RerankHandler(c, resp, info)
default:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,14 @@
package ali
var ModelList = []string{
"qwen-turbo",
"qwen-plus",
"qwen-max",
"qwen-max-longcontext",
"qwq-32b",
"qwen3-235b-a22b",
"text-embedding-v1",
"gte-rerank-v2",
}
var ChannelName = "ali"

126
relay/channel/ali/dto.go Normal file
View File

@@ -0,0 +1,126 @@
package ali
import "one-api/dto"
type AliMessage struct {
Content string `json:"content"`
Role string `json:"role"`
}
type AliInput struct {
Prompt string `json:"prompt,omitempty"`
//History []AliMessage `json:"history,omitempty"`
Messages []AliMessage `json:"messages"`
}
type AliParameters struct {
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
}
type AliChatRequest struct {
Model string `json:"model"`
Input AliInput `json:"input,omitempty"`
Parameters AliParameters `json:"parameters,omitempty"`
}
type AliEmbeddingRequest struct {
Model string `json:"model"`
Input struct {
Texts []string `json:"texts"`
} `json:"input"`
Parameters *struct {
TextType string `json:"text_type,omitempty"`
} `json:"parameters,omitempty"`
}
type AliEmbedding struct {
Embedding []float64 `json:"embedding"`
TextIndex int `json:"text_index"`
}
type AliEmbeddingResponse struct {
Output struct {
Embeddings []AliEmbedding `json:"embeddings"`
} `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliError struct {
Code string `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
}
type AliUsage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
}
type TaskResult struct {
B64Image string `json:"b64_image,omitempty"`
Url string `json:"url,omitempty"`
Code string `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
type AliOutput struct {
TaskId string `json:"task_id,omitempty"`
TaskStatus string `json:"task_status,omitempty"`
Text string `json:"text"`
FinishReason string `json:"finish_reason"`
Message string `json:"message,omitempty"`
Code string `json:"code,omitempty"`
Results []TaskResult `json:"results,omitempty"`
}
type AliResponse struct {
Output AliOutput `json:"output"`
Usage AliUsage `json:"usage"`
AliError
}
type AliImageRequest struct {
Model string `json:"model"`
Input struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
} `json:"input"`
Parameters struct {
Size string `json:"size,omitempty"`
N int `json:"n,omitempty"`
Steps string `json:"steps,omitempty"`
Scale string `json:"scale,omitempty"`
} `json:"parameters,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type AliRerankParameters struct {
TopN *int `json:"top_n,omitempty"`
ReturnDocuments *bool `json:"return_documents,omitempty"`
}
type AliRerankInput struct {
Query string `json:"query"`
Documents []any `json:"documents"`
}
type AliRerankRequest struct {
Model string `json:"model"`
Input AliRerankInput `json:"input"`
Parameters AliRerankParameters `json:"parameters,omitempty"`
}
type AliRerankResponse struct {
Output struct {
Results []dto.RerankResponseResult `json:"results"`
} `json:"output"`
Usage AliUsage `json:"usage"`
RequestId string `json:"request_id"`
AliError
}

171
relay/channel/ali/image.go Normal file
View File

@@ -0,0 +1,171 @@
package ali
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func oaiImage2Ali(request dto.ImageRequest) *AliImageRequest {
var imageRequest AliImageRequest
imageRequest.Input.Prompt = request.Prompt
imageRequest.Model = request.Model
imageRequest.Parameters.Size = strings.Replace(request.Size, "x", "*", -1)
imageRequest.Parameters.N = request.N
imageRequest.ResponseFormat = request.ResponseFormat
return &imageRequest
}
func updateTask(info *relaycommon.RelayInfo, taskID string) (*AliResponse, error, []byte) {
url := fmt.Sprintf("%s/api/v1/tasks/%s", info.BaseUrl, taskID)
var aliResponse AliResponse
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return &aliResponse, err, nil
}
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
common.SysError("updateTask client.Do err: " + err.Error())
return &aliResponse, err, nil
}
defer resp.Body.Close()
responseBody, err := io.ReadAll(resp.Body)
var response AliResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
common.SysError("updateTask NewDecoder err: " + err.Error())
return &aliResponse, err, nil
}
return &response, nil, responseBody
}
func asyncTaskWait(info *relaycommon.RelayInfo, taskID string) (*AliResponse, []byte, error) {
waitSeconds := 3
step := 0
maxStep := 20
var taskResponse AliResponse
var responseBody []byte
for {
step++
rsp, err, body := updateTask(info, taskID)
responseBody = body
if err != nil {
return &taskResponse, responseBody, err
}
if rsp.Output.TaskStatus == "" {
return &taskResponse, responseBody, nil
}
switch rsp.Output.TaskStatus {
case "FAILED":
fallthrough
case "CANCELED":
fallthrough
case "SUCCEEDED":
fallthrough
case "UNKNOWN":
return rsp, responseBody, nil
}
if step >= maxStep {
break
}
time.Sleep(time.Duration(waitSeconds) * time.Second)
}
return nil, nil, fmt.Errorf("aliAsyncTaskWait timeout")
}
func responseAli2OpenAIImage(c *gin.Context, response *AliResponse, info *relaycommon.RelayInfo, responseFormat string) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, data := range response.Output.Results {
var b64Json string
if responseFormat == "b64_json" {
_, b64, err := service.GetImageFromUrl(data.Url)
if err != nil {
common.LogError(c, "get_image_data_failed: "+err.Error())
continue
}
b64Json = b64
} else {
b64Json = data.B64Image
}
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: data.Url,
B64Json: b64Json,
RevisedPrompt: "",
})
}
return &imageResponse
}
func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseFormat := c.GetString("response_format")
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if aliTaskResponse.Message != "" {
common.LogError(c, "ali_async_task_failed: "+aliTaskResponse.Message)
return types.NewError(errors.New(aliTaskResponse.Message), types.ErrorCodeBadResponse), nil
}
aliResponse, _, err := asyncTaskWait(info, aliTaskResponse.Output.TaskId)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponse), nil
}
if aliResponse.Output.TaskStatus != "SUCCEEDED" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Output.Message,
Type: "ali_error",
Param: "",
Code: aliResponse.Output.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAIImage(c, aliResponse, info, responseFormat)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return nil, &dto.Usage{}
}

View File

@@ -0,0 +1,74 @@
package ali
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
returnDocuments := request.ReturnDocuments
if returnDocuments == nil {
t := true
returnDocuments = &t
}
return &AliRerankRequest{
Model: request.Model,
Input: AliRerankInput{
Query: request.Query,
Documents: request.Documents,
},
Parameters: AliRerankParameters{
TopN: &request.TopN,
ReturnDocuments: returnDocuments,
},
}
}
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Message,
Type: aliResponse.Code,
Param: aliResponse.RequestId,
Code: aliResponse.Code,
}, resp.StatusCode), nil
}
usage := dto.Usage{
PromptTokens: aliResponse.Usage.TotalTokens,
CompletionTokens: 0,
TotalTokens: aliResponse.Usage.TotalTokens,
}
rerankResponse := dto.RerankResponse{
Results: aliResponse.Output.Results,
Usage: usage,
}
jsonResponse, err := json.Marshal(rerankResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return nil, &usage
}

206
relay/channel/ali/text.go Normal file
View File

@@ -0,0 +1,206 @@
package ali
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/helper"
"strings"
"one-api/types"
"github.com/gin-gonic/gin"
)
// https://help.aliyun.com/document_detail/613695.html?spm=a2c4g.2399480.0.0.1adb778fAdzP9w#341800c0f8w0r
const EnableSearchModelSuffix = "-internet"
func requestOpenAI2Ali(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
if request.TopP >= 1 {
request.TopP = 0.999
} else if request.TopP <= 0 {
request.TopP = 0.001
}
return &request
}
func embeddingRequestOpenAI2Ali(request dto.EmbeddingRequest) *AliEmbeddingRequest {
return &AliEmbeddingRequest{
Model: request.Model,
Input: struct {
Texts []string `json:"texts"`
}{
Texts: request.ParseInput(),
},
}
}
func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
model := c.GetString("model")
if model == "" {
model = "text-embedding-v4"
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func embeddingResponseAli2OpenAI(response *AliEmbeddingResponse, model string) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Output.Embeddings)),
Model: model,
Usage: dto.Usage{TotalTokens: response.Usage.TotalTokens},
}
for _, item := range response.Output.Embeddings {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: `embedding`,
Index: item.TextIndex,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func responseAli2OpenAI(response *AliResponse) *dto.OpenAITextResponse {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: response.Output.Text,
},
FinishReason: response.Output.FinishReason,
}
fullTextResponse := dto.OpenAITextResponse{
Id: response.RequestId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: []dto.OpenAITextResponseChoice{choice},
Usage: dto.Usage{
PromptTokens: response.Usage.InputTokens,
CompletionTokens: response.Usage.OutputTokens,
TotalTokens: response.Usage.InputTokens + response.Usage.OutputTokens,
},
}
return &fullTextResponse
}
func streamResponseAli2OpenAI(aliResponse *AliResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(aliResponse.Output.Text)
if aliResponse.Output.FinishReason != "null" {
finishReason := aliResponse.Output.FinishReason
choice.FinishReason = &finishReason
}
response := dto.ChatCompletionsStreamResponse{
Id: aliResponse.RequestId,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "ernie-bot",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func aliStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var usage dto.Usage
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 { // ignore blank line or wrong format
continue
}
if data[:5] != "data:" {
continue
}
data = data[5:]
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
lastResponseText := ""
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
var aliResponse AliResponse
err := json.Unmarshal([]byte(data), &aliResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if aliResponse.Usage.OutputTokens != 0 {
usage.PromptTokens = aliResponse.Usage.InputTokens
usage.CompletionTokens = aliResponse.Usage.OutputTokens
usage.TotalTokens = aliResponse.Usage.InputTokens + aliResponse.Usage.OutputTokens
}
response := streamResponseAli2OpenAI(&aliResponse)
response.Choices[0].Delta.SetContentString(strings.TrimPrefix(response.Choices[0].Delta.GetContentString(), lastResponseText))
lastResponseText = aliResponse.Output.Text
jsonResponse, err := json.Marshal(response)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
common.CloseResponseBodyGracefully(resp)
return nil, &usage
}
func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{
Message: aliResponse.Message,
Type: "ali_error",
Param: aliResponse.RequestId,
Code: aliResponse.Code,
}, resp.StatusCode), nil
}
fullTextResponse := responseAli2OpenAI(&aliResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}

View File

@@ -0,0 +1,277 @@
package channel
import (
"context"
"errors"
"fmt"
"io"
"net/http"
common2 "one-api/common"
"one-api/relay/common"
"one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/operation_setting"
"sync"
"time"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
)
func SetupApiRequestHeader(info *common.RelayInfo, c *gin.Context, req *http.Header) {
if info.RelayMode == constant.RelayModeAudioTranscription || info.RelayMode == constant.RelayModeAudioTranslation {
// multipart/form-data
} else if info.RelayMode == constant.RelayModeRealtime {
// websocket
} else {
req.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Set("Accept", c.Request.Header.Get("Accept"))
if info.IsStream && c.Request.Header.Get("Accept") == "" {
req.Set("Accept", "text/event-stream")
}
}
}
func DoApiRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func DoFormRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
if common2.DebugEnabled {
println("fullRequestURL:", fullRequestURL)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
// set form data
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func DoWssRequest(a Adaptor, c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (*websocket.Conn, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
targetHeader := http.Header{}
err = a.SetupRequestHeader(c, &targetHeader, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
targetHeader.Set("Content-Type", c.Request.Header.Get("Content-Type"))
targetConn, _, err := websocket.DefaultDialer.Dial(fullRequestURL, targetHeader)
if err != nil {
return nil, fmt.Errorf("dial failed to %s: %w", fullRequestURL, err)
}
// send request body
//all, err := io.ReadAll(requestBody)
//err = service.WssString(c, targetConn, string(all))
return targetConn, nil
}
func startPingKeepAlive(c *gin.Context, pingInterval time.Duration) context.CancelFunc {
pingerCtx, stopPinger := context.WithCancel(context.Background())
gopool.Go(func() {
defer func() {
// 增加panic恢复处理
if r := recover(); r != nil {
if common2.DebugEnabled {
println("SSE ping goroutine panic recovered:", fmt.Sprintf("%v", r))
}
}
if common2.DebugEnabled {
println("SSE ping goroutine stopped.")
}
}()
if pingInterval <= 0 {
pingInterval = helper.DefaultPingInterval
}
ticker := time.NewTicker(pingInterval)
// 确保在任何情况下都清理ticker
defer func() {
ticker.Stop()
if common2.DebugEnabled {
println("SSE ping ticker stopped")
}
}()
var pingMutex sync.Mutex
if common2.DebugEnabled {
println("SSE ping goroutine started")
}
// 增加超时控制防止goroutine长时间运行
maxPingDuration := 120 * time.Minute // 最大ping持续时间
pingTimeout := time.NewTimer(maxPingDuration)
defer pingTimeout.Stop()
for {
select {
// 发送 ping 数据
case <-ticker.C:
if err := sendPingData(c, &pingMutex); err != nil {
if common2.DebugEnabled {
println("SSE ping error, stopping goroutine:", err.Error())
}
return
}
// 收到退出信号
case <-pingerCtx.Done():
return
// request 结束
case <-c.Request.Context().Done():
return
// 超时保护防止goroutine无限运行
case <-pingTimeout.C:
if common2.DebugEnabled {
println("SSE ping goroutine timeout, stopping")
}
return
}
}
})
return stopPinger
}
func sendPingData(c *gin.Context, mutex *sync.Mutex) error {
// 增加超时控制,防止锁死等待
done := make(chan error, 1)
go func() {
mutex.Lock()
defer mutex.Unlock()
err := helper.PingData(c)
if err != nil {
common2.LogError(c, "SSE ping error: "+err.Error())
done <- err
return
}
if common2.DebugEnabled {
println("SSE ping data sent.")
}
done <- nil
}()
// 设置发送ping数据的超时时间
select {
case err := <-done:
return err
case <-time.After(10 * time.Second):
return errors.New("SSE ping data send timeout")
case <-c.Request.Context().Done():
return errors.New("request context cancelled during ping")
}
}
func DoRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
return doRequest(c, req, info)
}
func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
var stopPinger context.CancelFunc
if info.IsStream {
helper.SetEventStreamHeaders(c)
// 处理流式请求的 ping 保活
generalSettings := operation_setting.GetGeneralSetting()
if generalSettings.PingIntervalEnabled {
pingInterval := time.Duration(generalSettings.PingIntervalSeconds) * time.Second
stopPinger = startPingKeepAlive(c, pingInterval)
// 使用defer确保在任何情况下都能停止ping goroutine
defer func() {
if stopPinger != nil {
stopPinger()
if common2.DebugEnabled {
println("SSE ping goroutine stopped by defer")
}
}
}()
}
}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
if resp == nil {
return nil, errors.New("resp is nil")
}
_ = req.Body.Close()
_ = c.Request.Body.Close()
return resp, nil
}
func DoTaskApiRequest(a TaskAdaptor, c *gin.Context, info *common.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
fullRequestURL, err := a.BuildRequestURL(info)
if err != nil {
return nil, err
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
req.GetBody = func() (io.ReadCloser, error) {
return io.NopCloser(requestBody), nil
}
err = a.BuildRequestHeader(c, req, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(c, req, info.RelayInfo)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}

View File

@@ -0,0 +1,107 @@
package aws
import (
"errors"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"one-api/types"
"github.com/gin-gonic/gin"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
c.Set("request_model", request.Model)
c.Set("converted_request", request)
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.RequestMode = RequestModeMessage
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return "", nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
var claudeReq *dto.ClaudeRequest
var err error
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
c.Set("request_model", claudeReq.Model)
c.Set("converted_request", claudeReq)
return claudeReq, err
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return nil, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = awsStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = awsHandler(c, info, a.RequestMode)
}
return
}
func (a *Adaptor) GetModelList() (models []string) {
for n := range awsModelIDMap {
models = append(models, n)
}
return
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,65 @@
package aws
var awsModelIDMap = map[string]string{
"claude-instant-1.2": "anthropic.claude-instant-v1",
"claude-2.0": "anthropic.claude-v2",
"claude-2.1": "anthropic.claude-v2:1",
"claude-3-sonnet-20240229": "anthropic.claude-3-sonnet-20240229-v1:0",
"claude-3-opus-20240229": "anthropic.claude-3-opus-20240229-v1:0",
"claude-3-haiku-20240307": "anthropic.claude-3-haiku-20240307-v1:0",
"claude-3-5-sonnet-20240620": "anthropic.claude-3-5-sonnet-20240620-v1:0",
"claude-3-5-sonnet-20241022": "anthropic.claude-3-5-sonnet-20241022-v2:0",
"claude-3-5-haiku-20241022": "anthropic.claude-3-5-haiku-20241022-v1:0",
"claude-3-7-sonnet-20250219": "anthropic.claude-3-7-sonnet-20250219-v1:0",
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
}
var awsModelCanCrossRegionMap = map[string]map[string]bool{
"anthropic.claude-3-sonnet-20240229-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-opus-20240229-v1:0": {
"us": true,
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20240620-v1:0": {
"us": true,
"eu": true,
"ap": true,
},
"anthropic.claude-3-5-sonnet-20241022-v2:0": {
"us": true,
"ap": true,
},
"anthropic.claude-3-5-haiku-20241022-v1:0": {
"us": true,
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-sonnet-4-20250514-v1:0": {
"us": true,
"ap": true,
"eu": true,
},
"anthropic.claude-opus-4-20250514-v1:0": {
"us": true,
},
}
var awsRegionCrossModelPrefixMap = map[string]string{
"us": "us",
"eu": "eu",
"ap": "apac",
}
var ChannelName = "aws"

36
relay/channel/aws/dto.go Normal file
View File

@@ -0,0 +1,36 @@
package aws
import (
"one-api/dto"
)
type AwsClaudeRequest struct {
// AnthropicVersion should be "bedrock-2023-05-31"
AnthropicVersion string `json:"anthropic_version"`
System any `json:"system,omitempty"`
Messages []dto.ClaudeMessage `json:"messages"`
MaxTokens uint `json:"max_tokens,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
return &AwsClaudeRequest{
AnthropicVersion: "bedrock-2023-05-31",
System: req.System,
Messages: req.Messages,
MaxTokens: req.MaxTokens,
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}

View File

@@ -0,0 +1,196 @@
package aws
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/claude"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
)
func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
awsSecret := strings.Split(info.ApiKey, "|")
if len(awsSecret) != 3 {
return nil, errors.New("invalid aws secret key")
}
ak := awsSecret[0]
sk := awsSecret[1]
region := awsSecret[2]
client := bedrockruntime.New(bedrockruntime.Options{
Region: region,
Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
})
return client, nil
}
func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
return &dto.OpenAIErrorWithStatusCode{
StatusCode: http.StatusInternalServerError,
Error: dto.OpenAIError{
Message: fmt.Sprintf("%s", err.Error()),
},
}
}
func awsRegionPrefix(awsRegionId string) string {
parts := strings.Split(awsRegionId, "-")
regionPrefix := ""
if len(parts) > 0 {
regionPrefix = parts[0]
}
return regionPrefix
}
func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
return exists && regionSet[awsRegionPrefix]
}
func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
if !find {
return awsModelId
}
return modelPrefix + "." + awsModelId
}
func awsModelID(requestModel string) string {
if awsModelID, ok := awsModelIDMap[requestModel]; ok {
return awsModelID
}
return requestModel
}
func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
if err != nil {
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
}
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
if handlerErr != nil {
return handlerErr, nil
}
return nil, claudeInfo.Usage
}
func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
awsCli, err := newAwsClient(c, info)
if err != nil {
return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
}
awsModelId := awsModelID(c.GetString("request_model"))
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
if canCrossRegion {
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
}
awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
ModelId: aws.String(awsModelId),
Accept: aws.String("application/json"),
ContentType: aws.String("application/json"),
}
claudeReq_, ok := c.Get("converted_request")
if !ok {
return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
}
claudeReq := claudeReq_.(*dto.ClaudeRequest)
awsClaudeReq := copyRequest(claudeReq)
awsReq.Body, err = json.Marshal(awsClaudeReq)
if err != nil {
return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
}
awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
if err != nil {
return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
}
stream := awsResp.GetStream()
defer stream.Close()
claudeInfo := &claude.ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
for event := range stream.Events() {
switch v := event.(type) {
case *bedrockruntimeTypes.ResponseStreamMemberChunk:
info.SetFirstResponseTime()
respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
if respErr != nil {
return respErr, nil
}
case *bedrockruntimeTypes.UnknownUnionMember:
fmt.Println("unknown tag:", v.Tag)
return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
default:
fmt.Println("union is nil or unknown type")
return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
}
}
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
return nil, claudeInfo.Usage
}

View File

@@ -0,0 +1,164 @@
package baidu
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "Embedding") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "bge-large") {
suffix = "embeddings/"
}
if strings.HasPrefix(info.UpstreamModelName, "tao-8k") {
suffix = "embeddings/"
}
switch info.UpstreamModelName {
case "ERNIE-4.0":
suffix += "completions_pro"
case "ERNIE-Bot-4":
suffix += "completions_pro"
case "ERNIE-Bot":
suffix += "completions"
case "ERNIE-Bot-turbo":
suffix += "eb-instant"
case "ERNIE-Speed":
suffix += "ernie_speed"
case "ERNIE-4.0-8K":
suffix += "completions_pro"
case "ERNIE-3.5-8K":
suffix += "completions"
case "ERNIE-3.5-8K-0205":
suffix += "ernie-3.5-8k-0205"
case "ERNIE-3.5-8K-1222":
suffix += "ernie-3.5-8k-1222"
case "ERNIE-Bot-8K":
suffix += "ernie_bot_8k"
case "ERNIE-3.5-4K-0205":
suffix += "ernie-3.5-4k-0205"
case "ERNIE-Speed-8K":
suffix += "ernie_speed"
case "ERNIE-Speed-128K":
suffix += "ernie-speed-128k"
case "ERNIE-Lite-8K-0922":
suffix += "eb-instant"
case "ERNIE-Lite-8K-0308":
suffix += "ernie-lite-8k"
case "ERNIE-Tiny-8K":
suffix += "ernie-tiny-8k"
case "BLOOMZ-7B":
suffix += "bloomz_7b1"
case "Embedding-V1":
suffix += "embedding-v1"
case "bge-large-zh":
suffix += "bge_large_zh"
case "bge-large-en":
suffix += "bge_large_en"
case "tao-8k":
suffix += "tao_8k"
default:
suffix += strings.ToLower(info.UpstreamModelName)
}
fullRequestURL := fmt.Sprintf("%s/rpc/2.0/ai_custom/v1/wenxinworkshop/%s", info.BaseUrl, suffix)
var accessToken string
var err error
if accessToken, err = getBaiduAccessToken(info.ApiKey); err != nil {
return "", err
}
fullRequestURL += "?access_token=" + accessToken
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
default:
baiduRequest := requestOpenAI2Baidu(*request)
return baiduRequest, nil
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
baiduEmbeddingRequest := embeddingRequestOpenAI2Baidu(request)
return baiduEmbeddingRequest, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = baiduStreamHandler(c, info, resp)
} else {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
err, usage = baiduEmbeddingHandler(c, info, resp)
default:
err, usage = baiduHandler(c, info, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,22 @@
package baidu
var ModelList = []string{
"ERNIE-4.0-8K",
"ERNIE-3.5-8K",
"ERNIE-3.5-8K-0205",
"ERNIE-3.5-8K-1222",
"ERNIE-Bot-8K",
"ERNIE-3.5-4K-0205",
"ERNIE-Speed-8K",
"ERNIE-Speed-128K",
"ERNIE-Lite-8K-0922",
"ERNIE-Lite-8K-0308",
"ERNIE-Tiny-8K",
"BLOOMZ-7B",
"Embedding-V1",
"bge-large-zh",
"bge-large-en",
"tao-8k",
}
var ChannelName = "baidu"

View File

@@ -0,0 +1,78 @@
package baidu
import (
"one-api/dto"
"time"
)
type BaiduMessage struct {
Role string `json:"role"`
Content string `json:"content"`
}
type BaiduChatRequest struct {
Messages []BaiduMessage `json:"messages"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
EnableCitation bool `json:"enable_citation,omitempty"`
MaxOutputTokens *int `json:"max_output_tokens,omitempty"`
UserId string `json:"user_id,omitempty"`
}
type Error struct {
ErrorCode int `json:"error_code"`
ErrorMsg string `json:"error_msg"`
}
type BaiduChatResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Result string `json:"result"`
IsTruncated bool `json:"is_truncated"`
NeedClearHistory bool `json:"need_clear_history"`
Usage dto.Usage `json:"usage"`
Error
}
type BaiduChatStreamResponse struct {
BaiduChatResponse
SentenceId int `json:"sentence_id"`
IsEnd bool `json:"is_end"`
}
type BaiduEmbeddingRequest struct {
Input []string `json:"input"`
}
type BaiduEmbeddingData struct {
Object string `json:"object"`
Embedding []float64 `json:"embedding"`
Index int `json:"index"`
}
type BaiduEmbeddingResponse struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Data []BaiduEmbeddingData `json:"data"`
Usage dto.Usage `json:"usage"`
Error
}
type BaiduAccessToken struct {
AccessToken string `json:"access_token"`
Error string `json:"error,omitempty"`
ErrorDescription string `json:"error_description,omitempty"`
ExpiresIn int64 `json:"expires_in,omitempty"`
ExpiresAt time.Time `json:"-"`
}
type BaiduTokenResponse struct {
ExpiresIn int `json:"expires_in"`
AccessToken string `json:"access_token"`
}

View File

@@ -0,0 +1,245 @@
package baidu
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
)
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/flfmc9do2
var baiduTokenStore sync.Map
func requestOpenAI2Baidu(request dto.GeneralOpenAIRequest) *BaiduChatRequest {
baiduRequest := BaiduChatRequest{
Temperature: request.Temperature,
TopP: request.TopP,
PenaltyScore: request.FrequencyPenalty,
Stream: request.Stream,
DisableSearch: false,
EnableCitation: false,
UserId: request.User,
}
if request.MaxTokens != 0 {
maxTokens := int(request.MaxTokens)
if request.MaxTokens == 1 {
maxTokens = 2
}
baiduRequest.MaxOutputTokens = &maxTokens
}
for _, message := range request.Messages {
if message.Role == "system" {
baiduRequest.System = message.StringContent()
} else {
baiduRequest.Messages = append(baiduRequest.Messages, BaiduMessage{
Role: message.Role,
Content: message.StringContent(),
})
}
}
return &baiduRequest
}
func responseBaidu2OpenAI(response *BaiduChatResponse) *dto.OpenAITextResponse {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: response.Result,
},
FinishReason: "stop",
}
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: response.Created,
Choices: []dto.OpenAITextResponseChoice{choice},
Usage: response.Usage,
}
return &fullTextResponse
}
func streamResponseBaidu2OpenAI(baiduResponse *BaiduChatStreamResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(baiduResponse.Result)
if baiduResponse.IsEnd {
choice.FinishReason = &constant.FinishReasonStop
}
response := dto.ChatCompletionsStreamResponse{
Id: baiduResponse.Id,
Object: "chat.completion.chunk",
Created: baiduResponse.Created,
Model: "ernie-bot",
Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
}
return &response
}
func embeddingRequestOpenAI2Baidu(request dto.EmbeddingRequest) *BaiduEmbeddingRequest {
return &BaiduEmbeddingRequest{
Input: request.ParseInput(),
}
}
func embeddingResponseBaidu2OpenAI(response *BaiduEmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func baiduStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
usage := &dto.Usage{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var baiduResponse BaiduChatStreamResponse
err := common.Unmarshal([]byte(data), &baiduResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
if baiduResponse.Usage.TotalTokens != 0 {
usage.TotalTokens = baiduResponse.Usage.TotalTokens
usage.PromptTokens = baiduResponse.Usage.PromptTokens
usage.CompletionTokens = baiduResponse.Usage.TotalTokens - baiduResponse.Usage.PromptTokens
}
response := streamResponseBaidu2OpenAI(&baiduResponse)
err = helper.ObjectData(c, response)
if err != nil {
common.SysError("error sending stream response: " + err.Error())
}
return true
})
common.CloseResponseBodyGracefully(resp)
return nil, usage
}
func baiduHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var baiduResponse BaiduChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := responseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func baiduEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var baiduResponse BaiduEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if baiduResponse.ErrorMsg != "" {
return types.NewError(fmt.Errorf(baiduResponse.ErrorMsg), types.ErrorCodeBadResponseBody), nil
}
fullTextResponse := embeddingResponseBaidu2OpenAI(&baiduResponse)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return nil, &fullTextResponse.Usage
}
func getBaiduAccessToken(apiKey string) (string, error) {
if val, ok := baiduTokenStore.Load(apiKey); ok {
var accessToken BaiduAccessToken
if accessToken, ok = val.(BaiduAccessToken); ok {
// soon this will expire
if time.Now().Add(time.Hour).After(accessToken.ExpiresAt) {
go func() {
_, _ = getBaiduAccessTokenHelper(apiKey)
}()
}
return accessToken.AccessToken, nil
}
}
accessToken, err := getBaiduAccessTokenHelper(apiKey)
if err != nil {
return "", err
}
if accessToken == nil {
return "", errors.New("getBaiduAccessToken return a nil token")
}
return (*accessToken).AccessToken, nil
}
func getBaiduAccessTokenHelper(apiKey string) (*BaiduAccessToken, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return nil, errors.New("invalid baidu apikey")
}
req, err := http.NewRequest("POST", fmt.Sprintf("https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id=%s&client_secret=%s",
parts[0], parts[1]), nil)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", "application/json")
req.Header.Add("Accept", "application/json")
res, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
defer res.Body.Close()
var accessToken BaiduAccessToken
err = json.NewDecoder(res.Body).Decode(&accessToken)
if err != nil {
return nil, err
}
if accessToken.Error != "" {
return nil, errors.New(accessToken.Error + ": " + accessToken.ErrorDescription)
}
if accessToken.AccessToken == "" {
return nil, errors.New("getBaiduAccessTokenHelper get empty access token")
}
accessToken.ExpiresAt = time.Now().Add(time.Duration(accessToken.ExpiresIn) * time.Second)
baiduTokenStore.Store(apiKey, accessToken)
return &accessToken, nil
}

View File

@@ -0,0 +1,111 @@
package baidu_v2
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v2/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 0 || keyParts[0] == "" {
return errors.New("invalid API key: authorization token is required")
}
if len(keyParts) > 1 {
if keyParts[1] != "" {
req.Set("appid", keyParts[1])
}
}
req.Set("Authorization", "Bearer "+keyParts[0])
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if strings.HasSuffix(info.UpstreamModelName, "-search") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-search")
request.Model = info.UpstreamModelName
toMap := request.ToMap()
toMap["web_search"] = map[string]any{
"enable": true,
"enable_citation": true,
"enable_trace": true,
"enable_status": false,
}
return toMap, nil
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,29 @@
package baidu_v2
var ModelList = []string{
"ernie-4.0-8k-latest",
"ernie-4.0-8k-preview",
"ernie-4.0-8k",
"ernie-4.0-turbo-8k-latest",
"ernie-4.0-turbo-8k-preview",
"ernie-4.0-turbo-8k",
"ernie-4.0-turbo-128k",
"ernie-3.5-8k-preview",
"ernie-3.5-8k",
"ernie-3.5-128k",
"ernie-speed-8k",
"ernie-speed-128k",
"ernie-speed-pro-128k",
"ernie-lite-8k",
"ernie-lite-pro-128k",
"ernie-tiny-8k",
"ernie-char-8k",
"ernie-char-fiction-8k",
"ernie-novel-8k",
"deepseek-v3",
"deepseek-r1",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-14b",
}
var ChannelName = "volcengine"

View File

@@ -0,0 +1,113 @@
package claude
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
const (
RequestModeCompletion = 1
RequestModeMessage = 2
)
type Adaptor struct {
RequestMode int
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude-2") || strings.HasPrefix(info.UpstreamModelName, "claude-instant") {
a.RequestMode = RequestModeCompletion
} else {
a.RequestMode = RequestModeMessage
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if a.RequestMode == RequestModeMessage {
return fmt.Sprintf("%s/v1/messages", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/complete", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("x-api-key", info.ApiKey)
anthropicVersion := c.Request.Header.Get("anthropic-version")
if anthropicVersion == "" {
anthropicVersion = "2023-06-01"
}
req.Set("anthropic-version", anthropicVersion)
model_setting.GetClaudeSettings().WriteHeaders(info.OriginModelName, req)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeCompletion {
return RequestOpenAI2ClaudeComplete(*request), nil
} else {
return RequestOpenAI2ClaudeMessage(*request)
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
err, usage = ClaudeStreamHandler(c, resp, info, a.RequestMode)
} else {
err, usage = ClaudeHandler(c, resp, a.RequestMode, info)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,22 @@
package claude
var ModelList = []string{
"claude-instant-1.2",
"claude-2",
"claude-2.0",
"claude-2.1",
"claude-3-sonnet-20240229",
"claude-3-opus-20240229",
"claude-3-haiku-20240307",
"claude-3-5-haiku-20241022",
"claude-3-5-sonnet-20240620",
"claude-3-5-sonnet-20241022",
"claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219-thinking",
"claude-sonnet-4-20250514",
"claude-sonnet-4-20250514-thinking",
"claude-opus-4-20250514",
"claude-opus-4-20250514-thinking",
}
var ChannelName = "claude"

View File

@@ -0,0 +1,95 @@
package claude
//
//type ClaudeMetadata struct {
// UserId string `json:"user_id"`
//}
//
//type ClaudeMediaMessage struct {
// Type string `json:"type"`
// Text string `json:"text,omitempty"`
// Source *ClaudeMessageSource `json:"source,omitempty"`
// Usage *ClaudeUsage `json:"usage,omitempty"`
// StopReason *string `json:"stop_reason,omitempty"`
// PartialJson string `json:"partial_json,omitempty"`
// Thinking string `json:"thinking,omitempty"`
// Signature string `json:"signature,omitempty"`
// Delta string `json:"delta,omitempty"`
// // tool_calls
// Id string `json:"id,omitempty"`
// Name string `json:"name,omitempty"`
// Input any `json:"input,omitempty"`
// Content string `json:"content,omitempty"`
// ToolUseId string `json:"tool_use_id,omitempty"`
//}
//
//type ClaudeMessageSource struct {
// Type string `json:"type"`
// MediaType string `json:"media_type"`
// Data string `json:"data"`
//}
//
//type ClaudeMessage struct {
// Role string `json:"role"`
// Content any `json:"content"`
//}
//
//type Tool struct {
// Name string `json:"name"`
// Description string `json:"description,omitempty"`
// InputSchema map[string]interface{} `json:"input_schema"`
//}
//
//type InputSchema struct {
// Type string `json:"type"`
// Properties any `json:"properties,omitempty"`
// Required any `json:"required,omitempty"`
//}
//
//type ClaudeRequest struct {
// Model string `json:"model"`
// Prompt string `json:"prompt,omitempty"`
// System string `json:"system,omitempty"`
// Messages []ClaudeMessage `json:"messages,omitempty"`
// MaxTokens uint `json:"max_tokens,omitempty"`
// MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
// StopSequences []string `json:"stop_sequences,omitempty"`
// Temperature *float64 `json:"temperature,omitempty"`
// TopP float64 `json:"top_p,omitempty"`
// TopK int `json:"top_k,omitempty"`
// //ClaudeMetadata `json:"metadata,omitempty"`
// Stream bool `json:"stream,omitempty"`
// Tools any `json:"tools,omitempty"`
// ToolChoice any `json:"tool_choice,omitempty"`
// Thinking *Thinking `json:"thinking,omitempty"`
//}
//
//type Thinking struct {
// Type string `json:"type"`
// BudgetTokens int `json:"budget_tokens"`
//}
//
//type ClaudeError struct {
// Type string `json:"type"`
// Message string `json:"message"`
//}
//
//type ClaudeResponse struct {
// Id string `json:"id"`
// Type string `json:"type"`
// Content []ClaudeMediaMessage `json:"content"`
// Completion string `json:"completion"`
// StopReason string `json:"stop_reason"`
// Model string `json:"model"`
// Error ClaudeError `json:"error"`
// Usage ClaudeUsage `json:"usage"`
// Index int `json:"index"` // stream only
// ContentBlock *ClaudeMediaMessage `json:"content_block"`
// Delta *ClaudeMediaMessage `json:"delta"` // stream only
// Message *ClaudeResponse `json:"message"` // stream only: message_start
//}
//
//type ClaudeUsage struct {
// InputTokens int `json:"input_tokens"`
// OutputTokens int `json:"output_tokens"`
//}

View File

@@ -0,0 +1,813 @@
package claude
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/openrouter"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
const (
WebSearchMaxUsesLow = 1
WebSearchMaxUsesMedium = 5
WebSearchMaxUsesHigh = 10
)
func stopReasonClaude2OpenAI(reason string) string {
switch reason {
case "stop_sequence":
return "stop"
case "end_turn":
return "stop"
case "max_tokens":
return "max_tokens"
case "tool_use":
return "tool_calls"
default:
return reason
}
}
func RequestOpenAI2ClaudeComplete(textRequest dto.GeneralOpenAIRequest) *dto.ClaudeRequest {
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
Prompt: "",
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
}
if claudeRequest.MaxTokensToSample == 0 {
claudeRequest.MaxTokensToSample = 4096
}
prompt := ""
for _, message := range textRequest.Messages {
if message.Role == "user" {
prompt += fmt.Sprintf("\n\nHuman: %s", message.StringContent())
} else if message.Role == "assistant" {
prompt += fmt.Sprintf("\n\nAssistant: %s", message.StringContent())
} else if message.Role == "system" {
if prompt == "" {
prompt = message.StringContent()
}
}
}
prompt += "\n\nAssistant:"
claudeRequest.Prompt = prompt
return &claudeRequest
}
func RequestOpenAI2ClaudeMessage(textRequest dto.GeneralOpenAIRequest) (*dto.ClaudeRequest, error) {
claudeTools := make([]any, 0, len(textRequest.Tools))
for _, tool := range textRequest.Tools {
if params, ok := tool.Function.Parameters.(map[string]any); ok {
claudeTool := dto.Tool{
Name: tool.Function.Name,
Description: tool.Function.Description,
}
claudeTool.InputSchema = make(map[string]interface{})
if params["type"] != nil {
claudeTool.InputSchema["type"] = params["type"].(string)
}
claudeTool.InputSchema["properties"] = params["properties"]
claudeTool.InputSchema["required"] = params["required"]
for s, a := range params {
if s == "type" || s == "properties" || s == "required" {
continue
}
claudeTool.InputSchema[s] = a
}
claudeTools = append(claudeTools, &claudeTool)
}
}
// Web search tool
// https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/web-search-tool
if textRequest.WebSearchOptions != nil {
webSearchTool := dto.ClaudeWebSearchTool{
Type: "web_search_20250305",
Name: "web_search",
}
// 处理 user_location
if textRequest.WebSearchOptions.UserLocation != nil {
anthropicUserLocation := &dto.ClaudeWebSearchUserLocation{
Type: "approximate", // 固定为 "approximate"
}
// 解析 UserLocation JSON
var userLocationMap map[string]interface{}
if err := json.Unmarshal(textRequest.WebSearchOptions.UserLocation, &userLocationMap); err == nil {
// 检查是否有 approximate 字段
if approximateData, ok := userLocationMap["approximate"].(map[string]interface{}); ok {
if timezone, ok := approximateData["timezone"].(string); ok && timezone != "" {
anthropicUserLocation.Timezone = timezone
}
if country, ok := approximateData["country"].(string); ok && country != "" {
anthropicUserLocation.Country = country
}
if region, ok := approximateData["region"].(string); ok && region != "" {
anthropicUserLocation.Region = region
}
if city, ok := approximateData["city"].(string); ok && city != "" {
anthropicUserLocation.City = city
}
}
}
webSearchTool.UserLocation = anthropicUserLocation
}
// 处理 search_context_size 转换为 max_uses
if textRequest.WebSearchOptions.SearchContextSize != "" {
switch textRequest.WebSearchOptions.SearchContextSize {
case "low":
webSearchTool.MaxUses = WebSearchMaxUsesLow
case "medium":
webSearchTool.MaxUses = WebSearchMaxUsesMedium
case "high":
webSearchTool.MaxUses = WebSearchMaxUsesHigh
}
}
claudeTools = append(claudeTools, &webSearchTool)
}
claudeRequest := dto.ClaudeRequest{
Model: textRequest.Model,
MaxTokens: textRequest.MaxTokens,
StopSequences: nil,
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
TopK: textRequest.TopK,
Stream: textRequest.Stream,
Tools: claudeTools,
}
// 处理 tool_choice 和 parallel_tool_calls
if textRequest.ToolChoice != nil || textRequest.ParallelTooCalls != nil {
claudeToolChoice := mapToolChoice(textRequest.ToolChoice, textRequest.ParallelTooCalls)
if claudeToolChoice != nil {
claudeRequest.ToolChoice = claudeToolChoice
}
}
if claudeRequest.MaxTokens == 0 {
claudeRequest.MaxTokens = uint(model_setting.GetClaudeSettings().GetDefaultMaxTokens(textRequest.Model))
}
if model_setting.GetClaudeSettings().ThinkingAdapterEnabled &&
strings.HasSuffix(textRequest.Model, "-thinking") {
// 因为BudgetTokens 必须大于1024
if claudeRequest.MaxTokens < 1280 {
claudeRequest.MaxTokens = 1280
}
// BudgetTokens 为 max_tokens 的 80%
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](int(float64(claudeRequest.MaxTokens) * model_setting.GetClaudeSettings().ThinkingAdapterBudgetTokensPercentage)),
}
// TODO: 临时处理
// https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations-when-using-extended-thinking
claudeRequest.TopP = 0
claudeRequest.Temperature = common.GetPointer[float64](1.0)
claudeRequest.Model = strings.TrimSuffix(textRequest.Model, "-thinking")
}
if textRequest.ReasoningEffort != "" {
switch textRequest.ReasoningEffort {
case "low":
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](1280),
}
case "medium":
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](2048),
}
case "high":
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: common.GetPointer[int](4096),
}
}
}
// 指定了 reasoning 参数,覆盖 budgetTokens
if textRequest.Reasoning != nil {
var reasoning openrouter.RequestReasoning
if err := common.Unmarshal(textRequest.Reasoning, &reasoning); err != nil {
return nil, err
}
budgetTokens := reasoning.MaxTokens
if budgetTokens > 0 {
claudeRequest.Thinking = &dto.Thinking{
Type: "enabled",
BudgetTokens: &budgetTokens,
}
}
}
if textRequest.Stop != nil {
// stop maybe string/array string, convert to array string
switch textRequest.Stop.(type) {
case string:
claudeRequest.StopSequences = []string{textRequest.Stop.(string)}
case []interface{}:
stopSequences := make([]string, 0)
for _, stop := range textRequest.Stop.([]interface{}) {
stopSequences = append(stopSequences, stop.(string))
}
claudeRequest.StopSequences = stopSequences
}
}
formatMessages := make([]dto.Message, 0)
lastMessage := dto.Message{
Role: "tool",
}
for i, message := range textRequest.Messages {
if message.Role == "" {
textRequest.Messages[i].Role = "user"
}
fmtMessage := dto.Message{
Role: message.Role,
Content: message.Content,
}
if message.Role == "tool" {
fmtMessage.ToolCallId = message.ToolCallId
}
if message.Role == "assistant" && message.ToolCalls != nil {
fmtMessage.ToolCalls = message.ToolCalls
}
if lastMessage.Role == message.Role && lastMessage.Role != "tool" {
if lastMessage.IsStringContent() && message.IsStringContent() {
fmtMessage.SetStringContent(strings.Trim(fmt.Sprintf("%s %s", lastMessage.StringContent(), message.StringContent()), "\""))
// delete last message
formatMessages = formatMessages[:len(formatMessages)-1]
}
}
if fmtMessage.Content == nil {
fmtMessage.SetStringContent("...")
}
formatMessages = append(formatMessages, fmtMessage)
lastMessage = fmtMessage
}
claudeMessages := make([]dto.ClaudeMessage, 0)
isFirstMessage := true
for _, message := range formatMessages {
if message.Role == "system" {
if message.IsStringContent() {
claudeRequest.System = message.StringContent()
} else {
contents := message.ParseContent()
content := ""
for _, ctx := range contents {
if ctx.Type == "text" {
content += ctx.Text
}
}
claudeRequest.System = content
}
} else {
if isFirstMessage {
isFirstMessage = false
if message.Role != "user" {
// fix: first message is assistant, add user message
claudeMessage := dto.ClaudeMessage{
Role: "user",
Content: []dto.ClaudeMediaMessage{
{
Type: "text",
Text: common.GetPointer[string]("..."),
},
},
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeMessage := dto.ClaudeMessage{
Role: message.Role,
}
if message.Role == "tool" {
if len(claudeMessages) > 0 && claudeMessages[len(claudeMessages)-1].Role == "user" {
lastMessage := claudeMessages[len(claudeMessages)-1]
if content, ok := lastMessage.Content.(string); ok {
lastMessage.Content = []dto.ClaudeMediaMessage{
{
Type: "text",
Text: common.GetPointer[string](content),
},
}
}
lastMessage.Content = append(lastMessage.Content.([]dto.ClaudeMediaMessage), dto.ClaudeMediaMessage{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.Content,
})
claudeMessages[len(claudeMessages)-1] = lastMessage
continue
} else {
claudeMessage.Role = "user"
claudeMessage.Content = []dto.ClaudeMediaMessage{
{
Type: "tool_result",
ToolUseId: message.ToolCallId,
Content: message.Content,
},
}
}
} else if message.IsStringContent() && message.ToolCalls == nil {
claudeMessage.Content = message.StringContent()
} else {
claudeMediaMessages := make([]dto.ClaudeMediaMessage, 0)
for _, mediaMessage := range message.ParseContent() {
claudeMediaMessage := dto.ClaudeMediaMessage{
Type: mediaMessage.Type,
}
if mediaMessage.Type == "text" {
claudeMediaMessage.Text = common.GetPointer[string](mediaMessage.Text)
} else {
imageUrl := mediaMessage.GetImageMedia()
claudeMediaMessage.Type = "image"
claudeMediaMessage.Source = &dto.ClaudeMessageSource{
Type: "base64",
}
// 判断是否是url
if strings.HasPrefix(imageUrl.Url, "http") {
// 是url获取图片的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error())
}
claudeMediaMessage.Source.MediaType = fileData.MimeType
claudeMediaMessage.Source.Data = fileData.Base64Data
} else {
_, format, base64String, err := service.DecodeBase64ImageData(imageUrl.Url)
if err != nil {
return nil, err
}
claudeMediaMessage.Source.MediaType = "image/" + format
claudeMediaMessage.Source.Data = base64String
}
}
claudeMediaMessages = append(claudeMediaMessages, claudeMediaMessage)
}
if message.ToolCalls != nil {
for _, toolCall := range message.ParseToolCalls() {
inputObj := make(map[string]any)
if err := json.Unmarshal([]byte(toolCall.Function.Arguments), &inputObj); err != nil {
common.SysError("tool call function arguments is not a map[string]any: " + fmt.Sprintf("%v", toolCall.Function.Arguments))
continue
}
claudeMediaMessages = append(claudeMediaMessages, dto.ClaudeMediaMessage{
Type: "tool_use",
Id: toolCall.ID,
Name: toolCall.Function.Name,
Input: inputObj,
})
}
}
claudeMessage.Content = claudeMediaMessages
}
claudeMessages = append(claudeMessages, claudeMessage)
}
}
claudeRequest.Prompt = ""
claudeRequest.Messages = claudeMessages
return &claudeRequest, nil
}
func StreamResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.ChatCompletionsStreamResponse {
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = claudeResponse.Model
response.Choices = make([]dto.ChatCompletionsStreamResponseChoice, 0)
tools := make([]dto.ToolCallResponse, 0)
fcIdx := 0
if claudeResponse.Index != nil {
fcIdx = *claudeResponse.Index - 1
if fcIdx < 0 {
fcIdx = 0
}
}
var choice dto.ChatCompletionsStreamResponseChoice
if reqMode == RequestModeCompletion {
choice.Delta.SetContentString(claudeResponse.Completion)
finishReason := stopReasonClaude2OpenAI(claudeResponse.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
} else {
if claudeResponse.Type == "message_start" {
response.Id = claudeResponse.Message.Id
response.Model = claudeResponse.Message.Model
//claudeUsage = &claudeResponse.Message.Usage
choice.Delta.SetContentString("")
choice.Delta.Role = "assistant"
} else if claudeResponse.Type == "content_block_start" {
if claudeResponse.ContentBlock != nil {
//choice.Delta.SetContentString(claudeResponse.ContentBlock.Text)
if claudeResponse.ContentBlock.Type == "tool_use" {
tools = append(tools, dto.ToolCallResponse{
Index: common.GetPointer(fcIdx),
ID: claudeResponse.ContentBlock.Id,
Type: "function",
Function: dto.FunctionResponse{
Name: claudeResponse.ContentBlock.Name,
Arguments: "",
},
})
}
} else {
return nil
}
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta != nil {
choice.Delta.Content = claudeResponse.Delta.Text
switch claudeResponse.Delta.Type {
case "input_json_delta":
tools = append(tools, dto.ToolCallResponse{
Type: "function",
Index: common.GetPointer(fcIdx),
Function: dto.FunctionResponse{
Arguments: *claudeResponse.Delta.PartialJson,
},
})
case "signature_delta":
// 加密的不处理
signatureContent := "\n"
choice.Delta.ReasoningContent = &signatureContent
case "thinking_delta":
thinkingContent := claudeResponse.Delta.Thinking
choice.Delta.ReasoningContent = &thinkingContent
}
}
} else if claudeResponse.Type == "message_delta" {
finishReason := stopReasonClaude2OpenAI(*claudeResponse.Delta.StopReason)
if finishReason != "null" {
choice.FinishReason = &finishReason
}
//claudeUsage = &claudeResponse.Usage
} else if claudeResponse.Type == "message_stop" {
return nil
} else {
return nil
}
}
if len(tools) > 0 {
choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ...
choice.Delta.ToolCalls = tools
}
response.Choices = append(response.Choices, choice)
return &response
}
func ResponseClaude2OpenAI(reqMode int, claudeResponse *dto.ClaudeResponse) *dto.OpenAITextResponse {
choices := make([]dto.OpenAITextResponseChoice, 0)
fullTextResponse := dto.OpenAITextResponse{
Id: fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
Object: "chat.completion",
Created: common.GetTimestamp(),
}
var responseText string
var responseThinking string
if len(claudeResponse.Content) > 0 {
responseText = claudeResponse.Content[0].GetText()
responseThinking = claudeResponse.Content[0].Thinking
}
tools := make([]dto.ToolCallResponse, 0)
thinkingContent := ""
if reqMode == RequestModeCompletion {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: strings.TrimPrefix(claudeResponse.Completion, " "),
Name: nil,
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choices = append(choices, choice)
} else {
fullTextResponse.Id = claudeResponse.Id
for _, message := range claudeResponse.Content {
switch message.Type {
case "tool_use":
args, _ := json.Marshal(message.Input)
tools = append(tools, dto.ToolCallResponse{
ID: message.Id,
Type: "function", // compatible with other OpenAI derivative applications
Function: dto.FunctionResponse{
Name: message.Name,
Arguments: string(args),
},
})
case "thinking":
// 加密的不管, 只输出明文的推理过程
thinkingContent = message.Thinking
case "text":
responseText = message.GetText()
}
}
}
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
},
FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason),
}
choice.SetStringContent(responseText)
if len(responseThinking) > 0 {
choice.ReasoningContent = responseThinking
}
if len(tools) > 0 {
choice.Message.SetToolCalls(tools)
}
choice.Message.ReasoningContent = thinkingContent
fullTextResponse.Model = claudeResponse.Model
choices = append(choices, choice)
fullTextResponse.Choices = choices
return &fullTextResponse
}
type ClaudeResponseInfo struct {
ResponseId string
Created int64
Model string
ResponseText strings.Builder
Usage *dto.Usage
Done bool
}
func FormatClaudeResponseInfo(requestMode int, claudeResponse *dto.ClaudeResponse, oaiResponse *dto.ChatCompletionsStreamResponse, claudeInfo *ClaudeResponseInfo) bool {
if requestMode == RequestModeCompletion {
claudeInfo.ResponseText.WriteString(claudeResponse.Completion)
} else {
if claudeResponse.Type == "message_start" {
claudeInfo.ResponseId = claudeResponse.Message.Id
claudeInfo.Model = claudeResponse.Message.Model
// message_start, 获取usage
claudeInfo.Usage.PromptTokens = claudeResponse.Message.Usage.InputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Message.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Message.Usage.CacheCreationInputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Message.Usage.OutputTokens
} else if claudeResponse.Type == "content_block_delta" {
if claudeResponse.Delta.Text != nil {
claudeInfo.ResponseText.WriteString(*claudeResponse.Delta.Text)
}
if claudeResponse.Delta.Thinking != "" {
claudeInfo.ResponseText.WriteString(claudeResponse.Delta.Thinking)
}
} else if claudeResponse.Type == "message_delta" {
// 最终的usage获取
if claudeResponse.Usage.InputTokens > 0 {
// 不叠加,只取最新的
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
}
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeInfo.Usage.PromptTokens + claudeInfo.Usage.CompletionTokens
// 判断是否完整
claudeInfo.Done = true
} else if claudeResponse.Type == "content_block_start" {
} else {
return false
}
}
if oaiResponse != nil {
oaiResponse.Id = claudeInfo.ResponseId
oaiResponse.Created = claudeInfo.Created
oaiResponse.Model = claudeInfo.Model
}
return true
}
func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data string, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.UnmarshalJsonStr(data, &claudeResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
FormatClaudeResponseInfo(requestMode, &claudeResponse, nil, claudeInfo)
if requestMode == RequestModeCompletion {
} else {
if claudeResponse.Type == "message_start" {
// message_start, 获取usage
info.UpstreamModelName = claudeResponse.Message.Model
} else if claudeResponse.Type == "content_block_delta" {
} else if claudeResponse.Type == "message_delta" {
}
}
helper.ClaudeChunkData(c, claudeResponse, data)
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
response := StreamResponseClaude2OpenAI(requestMode, &claudeResponse)
if !FormatClaudeResponseInfo(requestMode, &claudeResponse, response, claudeInfo) {
return nil
}
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "send_stream_response_failed: "+err.Error())
}
}
return nil
}
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
if requestMode == RequestModeCompletion {
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
if claudeInfo.Usage.PromptTokens == 0 {
//上游出错
}
if claudeInfo.Usage.CompletionTokens == 0 || !claudeInfo.Done {
if common.DebugEnabled {
common.SysError("claude response usage is not complete, maybe upstream error")
}
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
}
}
if info.RelayFormat == relaycommon.RelayFormatClaude {
//
} else if info.RelayFormat == relaycommon.RelayFormatOpenAI {
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(claudeInfo.ResponseId, claudeInfo.Created, info.UpstreamModelName, *claudeInfo.Usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
}
}
func ClaudeStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
var err *types.NewAPIError
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
err = HandleStreamResponseData(c, info, claudeInfo, data, requestMode)
if err != nil {
return false
}
return true
})
if err != nil {
return err, nil
}
HandleStreamFinalResponse(c, info, claudeInfo, requestMode)
return nil, claudeInfo.Usage
}
func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, data []byte, requestMode int) *types.NewAPIError {
var claudeResponse dto.ClaudeResponse
err := common.Unmarshal(data, &claudeResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
if claudeResponse.Error != nil && claudeResponse.Error.Type != "" {
return types.WithClaudeError(*claudeResponse.Error, http.StatusInternalServerError)
}
if requestMode == RequestModeCompletion {
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
claudeInfo.Usage.PromptTokens = info.PromptTokens
claudeInfo.Usage.CompletionTokens = completionTokens
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
} else {
claudeInfo.Usage.PromptTokens = claudeResponse.Usage.InputTokens
claudeInfo.Usage.CompletionTokens = claudeResponse.Usage.OutputTokens
claudeInfo.Usage.TotalTokens = claudeResponse.Usage.InputTokens + claudeResponse.Usage.OutputTokens
claudeInfo.Usage.PromptTokensDetails.CachedTokens = claudeResponse.Usage.CacheReadInputTokens
claudeInfo.Usage.PromptTokensDetails.CachedCreationTokens = claudeResponse.Usage.CacheCreationInputTokens
}
var responseData []byte
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
openaiResponse := ResponseClaude2OpenAI(requestMode, &claudeResponse)
openaiResponse.Usage = *claudeInfo.Usage
responseData, err = json.Marshal(openaiResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody)
}
case relaycommon.RelayFormatClaude:
responseData = data
}
if claudeResponse.Usage.ServerToolUse != nil && claudeResponse.Usage.ServerToolUse.WebSearchRequests > 0 {
c.Set("claude_web_search_requests", claudeResponse.Usage.ServerToolUse.WebSearchRequests)
}
common.IOCopyBytesGracefully(c, nil, responseData)
return nil
}
func ClaudeHandler(c *gin.Context, resp *http.Response, requestMode int, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
claudeInfo := &ClaudeResponseInfo{
ResponseId: helper.GetResponseID(c),
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
ResponseText: strings.Builder{},
Usage: &dto.Usage{},
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
if common.DebugEnabled {
println("responseBody: ", string(responseBody))
}
handleErr := HandleClaudeResponseData(c, info, claudeInfo, responseBody, requestMode)
if handleErr != nil {
return handleErr, nil
}
return nil, claudeInfo.Usage
}
func mapToolChoice(toolChoice any, parallelToolCalls *bool) *dto.ClaudeToolChoice {
var claudeToolChoice *dto.ClaudeToolChoice
// 处理 tool_choice 字符串值
if toolChoiceStr, ok := toolChoice.(string); ok {
switch toolChoiceStr {
case "auto":
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "auto",
}
case "required":
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "any",
}
case "none":
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "none",
}
}
} else if toolChoiceMap, ok := toolChoice.(map[string]interface{}); ok {
// 处理 tool_choice 对象值
if function, ok := toolChoiceMap["function"].(map[string]interface{}); ok {
if toolName, ok := function["name"].(string); ok {
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "tool",
Name: toolName,
}
}
}
}
// 处理 parallel_tool_calls
if parallelToolCalls != nil {
if claudeToolChoice == nil {
// 如果没有 tool_choice但有 parallel_tool_calls创建默认的 auto 类型
claudeToolChoice = &dto.ClaudeToolChoice{
Type: "auto",
}
}
// 设置 disable_parallel_tool_use
// 如果 parallel_tool_calls 为 true则 disable_parallel_tool_use 为 false
claudeToolChoice.DisableParallelToolUse = !*parallelToolCalls
}
return claudeToolChoice
}

View File

@@ -0,0 +1,122 @@
package cloudflare
import (
"bytes"
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/chat/completions", info.BaseUrl, info.ApiVersion), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/v1/embeddings", info.BaseUrl, info.ApiVersion), nil
default:
return fmt.Sprintf("%s/client/v4/accounts/%s/ai/run/%s", info.BaseUrl, info.ApiVersion, info.UpstreamModelName), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return convertCf2CompletionsRequest(*request), nil
default:
return request, nil
}
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
// 添加文件字段
file, _, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
// 打开临时文件用于保存上传的文件内容
requestBody := &bytes.Buffer{}
// 将上传的文件内容复制到临时文件
if _, err := io.Copy(requestBody, file); err != nil {
return nil, err
}
return requestBody, nil
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
err, usage = cfStreamHandler(c, info, resp)
} else {
err, usage = cfHandler(c, info, resp)
}
case constant.RelayModeAudioTranslation:
fallthrough
case constant.RelayModeAudioTranscription:
err, usage = cfSTTHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,39 @@
package cloudflare
var ModelList = []string{
"@cf/meta/llama-3.1-8b-instruct",
"@cf/meta/llama-2-7b-chat-fp16",
"@cf/meta/llama-2-7b-chat-int8",
"@cf/mistral/mistral-7b-instruct-v0.1",
"@hf/thebloke/deepseek-coder-6.7b-base-awq",
"@hf/thebloke/deepseek-coder-6.7b-instruct-awq",
"@cf/deepseek-ai/deepseek-math-7b-base",
"@cf/deepseek-ai/deepseek-math-7b-instruct",
"@cf/thebloke/discolm-german-7b-v1-awq",
"@cf/tiiuae/falcon-7b-instruct",
"@cf/google/gemma-2b-it-lora",
"@hf/google/gemma-7b-it",
"@cf/google/gemma-7b-it-lora",
"@hf/nousresearch/hermes-2-pro-mistral-7b",
"@hf/thebloke/llama-2-13b-chat-awq",
"@cf/meta-llama/llama-2-7b-chat-hf-lora",
"@cf/meta/llama-3-8b-instruct",
"@hf/thebloke/llamaguard-7b-awq",
"@hf/thebloke/mistral-7b-instruct-v0.1-awq",
"@hf/mistralai/mistral-7b-instruct-v0.2",
"@cf/mistral/mistral-7b-instruct-v0.2-lora",
"@hf/thebloke/neural-chat-7b-v3-1-awq",
"@cf/openchat/openchat-3.5-0106",
"@hf/thebloke/openhermes-2.5-mistral-7b-awq",
"@cf/microsoft/phi-2",
"@cf/qwen/qwen1.5-0.5b-chat",
"@cf/qwen/qwen1.5-1.8b-chat",
"@cf/qwen/qwen1.5-14b-chat-awq",
"@cf/qwen/qwen1.5-7b-chat-awq",
"@cf/defog/sqlcoder-7b-2",
"@hf/nexusflow/starling-lm-7b-beta",
"@cf/tinyllama/tinyllama-1.1b-chat-v1.0",
"@hf/thebloke/zephyr-7b-beta-awq",
}
var ChannelName = "cloudflare"

View File

@@ -0,0 +1,21 @@
package cloudflare
import "one-api/dto"
type CfRequest struct {
Messages []dto.Message `json:"messages,omitempty"`
Lora string `json:"lora,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
}
type CfAudioResponse struct {
Result CfSTTResult `json:"result"`
}
type CfSTTResult struct {
Text string `json:"text"`
}

View File

@@ -0,0 +1,150 @@
package cloudflare
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func convertCf2CompletionsRequest(textRequest dto.GeneralOpenAIRequest) *CfRequest {
p, _ := textRequest.Prompt.(string)
return &CfRequest{
Prompt: p,
MaxTokens: textRequest.GetMaxTokens(),
Stream: textRequest.Stream,
Temperature: textRequest.Temperature,
}
}
func cfStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
var responseText string
isFirst := true
for scanner.Scan() {
data := scanner.Text()
if len(data) < len("data: ") {
continue
}
data = strings.TrimPrefix(data, "data: ")
data = strings.TrimSuffix(data, "\r")
if data == "[DONE]" {
break
}
var response dto.ChatCompletionsStreamResponse
err := json.Unmarshal([]byte(data), &response)
if err != nil {
common.LogError(c, "error_unmarshalling_stream_response: "+err.Error())
continue
}
for _, choice := range response.Choices {
choice.Delta.Role = "assistant"
responseText += choice.Delta.GetContentString()
}
response.Id = id
response.Model = info.UpstreamModelName
err = helper.ObjectData(c, response)
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
if err != nil {
common.LogError(c, "error_rendering_stream_response: "+err.Error())
}
}
if err := scanner.Err(); err != nil {
common.LogError(c, "error_scanning_stream_response: "+err.Error())
}
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
if info.ShouldIncludeUsage {
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.LogError(c, "error_rendering_final_usage_response: "+err.Error())
}
}
helper.Done(c)
common.CloseResponseBodyGracefully(resp)
return nil, usage
}
func cfHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
var response dto.TextResponse
err = json.Unmarshal(responseBody, &response)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
response.Model = info.UpstreamModelName
var responseText string
for _, choice := range response.Choices {
responseText += choice.Message.StringContent()
}
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
response.Usage = *usage
response.Id = helper.GetResponseID(c)
jsonResponse, err := json.Marshal(response)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return nil, usage
}
func cfSTTHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*types.NewAPIError, *dto.Usage) {
var cfResp CfAudioResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &cfResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
audioResp := &dto.AudioResponse{
Text: cfResp.Result.Text,
}
jsonResponse, err := json.Marshal(audioResp)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}

View File

@@ -0,0 +1,94 @@
package cohere
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else {
return fmt.Sprintf("%s/v1/chat", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return requestOpenAI2Cohere(*request), nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return requestConvertRerank2Cohere(request), nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
usage, err = cohereRerankHandler(c, resp, info)
} else {
if info.IsStream {
usage, err = cohereStreamHandler(c, info, resp) // TODO: fix this
} else {
usage, err = cohereHandler(c, info, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,12 @@
package cohere
var ModelList = []string{
"command-a-03-2025",
"command-r", "command-r-plus",
"command-r-08-2024", "command-r-plus-08-2024",
"c4ai-aya-23-35b", "c4ai-aya-23-8b",
"command-light", "command-light-nightly", "command", "command-nightly",
"rerank-english-v3.0", "rerank-multilingual-v3.0", "rerank-english-v2.0", "rerank-multilingual-v2.0",
}
var ChannelName = "cohere"

View File

@@ -0,0 +1,60 @@
package cohere
import "one-api/dto"
type CohereRequest struct {
Model string `json:"model"`
ChatHistory []ChatHistory `json:"chat_history"`
Message string `json:"message"`
Stream bool `json:"stream"`
MaxTokens int `json:"max_tokens"`
SafetyMode string `json:"safety_mode,omitempty"`
}
type ChatHistory struct {
Role string `json:"role"`
Message string `json:"message"`
}
type CohereResponse struct {
IsFinished bool `json:"is_finished"`
EventType string `json:"event_type"`
Text string `json:"text,omitempty"`
FinishReason string `json:"finish_reason,omitempty"`
Response *CohereResponseResult `json:"response"`
}
type CohereResponseResult struct {
ResponseId string `json:"response_id"`
FinishReason string `json:"finish_reason,omitempty"`
Text string `json:"text"`
Meta CohereMeta `json:"meta"`
}
type CohereRerankRequest struct {
Documents []any `json:"documents"`
Query string `json:"query"`
Model string `json:"model"`
TopN int `json:"top_n"`
ReturnDocuments bool `json:"return_documents"`
}
type CohereRerankResponseResult struct {
Results []dto.RerankResponseResult `json:"results"`
Meta CohereMeta `json:"meta"`
}
type CohereMeta struct {
//Tokens CohereTokens `json:"tokens"`
BilledUnits CohereBilledUnits `json:"billed_units"`
}
type CohereBilledUnits struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type CohereTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}

View File

@@ -0,0 +1,248 @@
package cohere
import (
"bufio"
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"time"
"github.com/gin-gonic/gin"
)
func requestOpenAI2Cohere(textRequest dto.GeneralOpenAIRequest) *CohereRequest {
cohereReq := CohereRequest{
Model: textRequest.Model,
ChatHistory: []ChatHistory{},
Message: "",
Stream: textRequest.Stream,
MaxTokens: textRequest.GetMaxTokens(),
}
if common.CohereSafetySetting != "NONE" {
cohereReq.SafetyMode = common.CohereSafetySetting
}
if cohereReq.MaxTokens == 0 {
cohereReq.MaxTokens = 4000
}
for _, msg := range textRequest.Messages {
if msg.Role == "user" {
cohereReq.Message = msg.StringContent()
} else {
var role string
if msg.Role == "assistant" {
role = "CHATBOT"
} else if msg.Role == "system" {
role = "SYSTEM"
} else {
role = "USER"
}
cohereReq.ChatHistory = append(cohereReq.ChatHistory, ChatHistory{
Role: role,
Message: msg.StringContent(),
})
}
}
return &cohereReq
}
func requestConvertRerank2Cohere(rerankRequest dto.RerankRequest) *CohereRerankRequest {
if rerankRequest.TopN == 0 {
rerankRequest.TopN = 1
}
cohereReq := CohereRerankRequest{
Query: rerankRequest.Query,
Documents: rerankRequest.Documents,
Model: rerankRequest.Model,
TopN: rerankRequest.TopN,
ReturnDocuments: true,
}
return &cohereReq
}
func stopReasonCohere2OpenAI(reason string) string {
switch reason {
case "COMPLETE":
return "stop"
case "MAX_TOKENS":
return "max_tokens"
default:
return reason
}
}
func cohereStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
usage := &dto.Usage{}
responseText := ""
scanner := bufio.NewScanner(resp.Body)
scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) {
if atEOF && len(data) == 0 {
return 0, nil, nil
}
if i := strings.Index(string(data), "\n"); i >= 0 {
return i + 1, data[0:i], nil
}
if atEOF {
return len(data), data, nil
}
return 0, nil, nil
})
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
for scanner.Scan() {
data := scanner.Text()
dataChan <- data
}
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
isFirst := true
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
if isFirst {
isFirst = false
info.FirstResponseTime = time.Now()
}
data = strings.TrimSuffix(data, "\r")
var cohereResp CohereResponse
err := json.Unmarshal([]byte(data), &cohereResp)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResp dto.ChatCompletionsStreamResponse
openaiResp.Id = responseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion.chunk"
openaiResp.Model = info.UpstreamModelName
if cohereResp.IsFinished {
finishReason := stopReasonCohere2OpenAI(cohereResp.FinishReason)
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{},
Index: 0,
FinishReason: &finishReason,
},
}
if cohereResp.Response != nil {
usage.PromptTokens = cohereResp.Response.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Response.Meta.BilledUnits.OutputTokens
}
} else {
openaiResp.Choices = []dto.ChatCompletionsStreamResponseChoice{
{
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
Content: &cohereResp.Text,
},
Index: 0,
},
}
responseText += cohereResp.Text
}
jsonStr, err := json.Marshal(openaiResp)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
return true
}
c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonStr)})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
if usage.PromptTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
return usage, nil
}
func cohereHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
createdTime := common.GetTimestamp()
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
var openaiResp dto.TextResponse
openaiResp.Id = cohereResp.ResponseId
openaiResp.Created = createdTime
openaiResp.Object = "chat.completion"
openaiResp.Model = info.UpstreamModelName
openaiResp.Usage = usage
openaiResp.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
Message: dto.Message{Content: cohereResp.Text, Role: "assistant"},
FinishReason: stopReasonCohere2OpenAI(cohereResp.FinishReason),
},
}
jsonResponse, err := json.Marshal(openaiResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return &usage, nil
}
func cohereRerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
var cohereResp CohereRerankResponseResult
err = json.Unmarshal(responseBody, &cohereResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := dto.Usage{}
if cohereResp.Meta.BilledUnits.InputTokens == 0 {
usage.PromptTokens = info.PromptTokens
usage.CompletionTokens = 0
usage.TotalTokens = info.PromptTokens
} else {
usage.PromptTokens = cohereResp.Meta.BilledUnits.InputTokens
usage.CompletionTokens = cohereResp.Meta.BilledUnits.OutputTokens
usage.TotalTokens = cohereResp.Meta.BilledUnits.InputTokens + cohereResp.Meta.BilledUnits.OutputTokens
}
var rerankResp dto.RerankResponse
rerankResp.Results = cohereResp.Results
rerankResp.Usage = usage
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
return &usage, nil
}

View File

@@ -0,0 +1,133 @@
package coze
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/common"
"one-api/types"
"time"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
// ConvertAudioRequest implements channel.Adaptor.
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *common.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
// ConvertClaudeRequest implements channel.Adaptor.
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *common.RelayInfo, request *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertEmbeddingRequest implements channel.Adaptor.
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *common.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertImageRequest implements channel.Adaptor.
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *common.RelayInfo, request dto.ImageRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertOpenAIRequest implements channel.Adaptor.
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *common.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return convertCozeChatRequest(c, *request), nil
}
// ConvertOpenAIResponsesRequest implements channel.Adaptor.
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *common.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
// ConvertRerankRequest implements channel.Adaptor.
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
// DoRequest implements channel.Adaptor.
func (a *Adaptor) DoRequest(c *gin.Context, info *common.RelayInfo, requestBody io.Reader) (any, error) {
if info.IsStream {
return channel.DoApiRequest(a, c, info, requestBody)
}
// 首先发送创建消息请求,成功后再发送获取消息请求
// 发送创建消息请求
resp, err := channel.DoApiRequest(a, c, info, requestBody)
if err != nil {
return nil, err
}
// 解析 resp
var cozeResponse CozeChatResponse
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
err = json.Unmarshal(respBody, &cozeResponse)
if cozeResponse.Code != 0 {
return nil, errors.New(cozeResponse.Msg)
}
c.Set("coze_conversation_id", cozeResponse.Data.ConversationId)
c.Set("coze_chat_id", cozeResponse.Data.Id)
// 轮询检查消息是否完成
for {
err, isComplete := checkIfChatComplete(a, c, info)
if err != nil {
return nil, err
} else {
if isComplete {
break
}
}
time.Sleep(time.Second * 1)
}
// 发送获取消息请求
return getChatDetail(a, c, info)
}
// DoResponse implements channel.Adaptor.
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *common.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = cozeChatStreamHandler(c, info, resp)
} else {
usage, err = cozeChatHandler(c, info, resp)
}
return
}
// GetChannelName implements channel.Adaptor.
func (a *Adaptor) GetChannelName() string {
return ChannelName
}
// GetModelList implements channel.Adaptor.
func (a *Adaptor) GetModelList() []string {
return ModelList
}
// GetRequestURL implements channel.Adaptor.
func (a *Adaptor) GetRequestURL(info *common.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v3/chat", info.BaseUrl), nil
}
// Init implements channel.Adaptor.
func (a *Adaptor) Init(info *common.RelayInfo) {
}
// SetupRequestHeader implements channel.Adaptor.
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *common.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}

View File

@@ -0,0 +1,30 @@
package coze
var ModelList = []string{
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
"Baichuan4",
"abab6.5s-chat-pro",
"glm-4-0520",
"qwen-max",
"deepseek-r1",
"deepseek-v3",
"deepseek-r1-distill-qwen-32b",
"deepseek-r1-distill-qwen-7b",
"step-1v-8k",
"step-1.5v-mini",
"Doubao-pro-32k",
"Doubao-pro-256k",
"Doubao-lite-128k",
"Doubao-lite-32k",
"Doubao-vision-lite-32k",
"Doubao-vision-pro-32k",
"Doubao-1.5-pro-vision-32k",
"Doubao-1.5-lite-32k",
"Doubao-1.5-pro-32k",
"Doubao-1.5-thinking-pro",
"Doubao-1.5-pro-256k",
}
var ChannelName = "coze"

78
relay/channel/coze/dto.go Normal file
View File

@@ -0,0 +1,78 @@
package coze
import "encoding/json"
type CozeError struct {
Code int `json:"code"`
Message string `json:"message"`
}
type CozeEnterMessage struct {
Role string `json:"role"`
Type string `json:"type,omitempty"`
Content any `json:"content,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ContentType string `json:"content_type,omitempty"`
}
type CozeChatRequest struct {
BotId string `json:"bot_id"`
UserId string `json:"user_id"`
AdditionalMessages []CozeEnterMessage `json:"additional_messages,omitempty"`
Stream bool `json:"stream,omitempty"`
CustomVariables json.RawMessage `json:"custom_variables,omitempty"`
AutoSaveHistory bool `json:"auto_save_history,omitempty"`
MetaData json.RawMessage `json:"meta_data,omitempty"`
ExtraParams json.RawMessage `json:"extra_params,omitempty"`
ShortcutCommand json.RawMessage `json:"shortcut_command,omitempty"`
Parameters json.RawMessage `json:"parameters,omitempty"`
}
type CozeChatResponse struct {
Code int `json:"code"`
Msg string `json:"msg"`
Data CozeChatResponseData `json:"data"`
}
type CozeChatResponseData struct {
Id string `json:"id"`
ConversationId string `json:"conversation_id"`
BotId string `json:"bot_id"`
CreatedAt int64 `json:"created_at"`
LastError CozeError `json:"last_error"`
Status string `json:"status"`
Usage CozeChatUsage `json:"usage"`
}
type CozeChatUsage struct {
TokenCount int `json:"token_count"`
OutputCount int `json:"output_count"`
InputCount int `json:"input_count"`
}
type CozeChatDetailResponse struct {
Data []CozeChatV3MessageDetail `json:"data"`
Code int `json:"code"`
Msg string `json:"msg"`
Detail CozeResponseDetail `json:"detail"`
}
type CozeChatV3MessageDetail struct {
Id string `json:"id"`
Role string `json:"role"`
Type string `json:"type"`
BotId string `json:"bot_id"`
ChatId string `json:"chat_id"`
Content json.RawMessage `json:"content"`
MetaData json.RawMessage `json:"meta_data"`
CreatedAt int64 `json:"created_at"`
SectionId string `json:"section_id"`
UpdatedAt int64 `json:"updated_at"`
ContentType string `json:"content_type"`
ConversationId string `json:"conversation_id"`
ReasoningContent string `json:"reasoning_content"`
}
type CozeResponseDetail struct {
Logid string `json:"logid"`
}

View File

@@ -0,0 +1,296 @@
package coze
import (
"bufio"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func convertCozeChatRequest(c *gin.Context, request dto.GeneralOpenAIRequest) *CozeChatRequest {
var messages []CozeEnterMessage
// 将 request的messages的role为user的content转换为CozeMessage
for _, message := range request.Messages {
if message.Role == "user" {
messages = append(messages, CozeEnterMessage{
Role: "user",
Content: message.Content,
// TODO: support more content type
ContentType: "text",
})
}
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
cozeRequest := &CozeChatRequest{
BotId: c.GetString("bot_id"),
UserId: user,
AdditionalMessages: messages,
Stream: request.Stream,
}
return cozeRequest
}
func cozeChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
// convert coze response to openai response
var response dto.TextResponse
var cozeResponse CozeChatDetailResponse
response.Model = info.UpstreamModelName
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if cozeResponse.Code != 0 {
return nil, types.NewError(errors.New(cozeResponse.Msg), types.ErrorCodeBadResponseBody)
}
// 从上下文获取 usage
var usage dto.Usage
usage.PromptTokens = c.GetInt("coze_input_count")
usage.CompletionTokens = c.GetInt("coze_output_count")
usage.TotalTokens = c.GetInt("coze_token_count")
response.Usage = usage
response.Id = helper.GetResponseID(c)
var responseContent json.RawMessage
for _, data := range cozeResponse.Data {
if data.Type == "answer" {
responseContent = data.Content
response.Created = data.CreatedAt
}
}
// 添加 response.Choices
response.Choices = []dto.OpenAITextResponseChoice{
{
Index: 0,
Message: dto.Message{Role: "assistant", Content: responseContent},
FinishReason: "stop",
},
}
jsonResponse, err := json.Marshal(response)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
return &usage, nil
}
func cozeChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
id := helper.GetResponseID(c)
var responseText string
var currentEvent string
var currentData string
var usage = &dto.Usage{}
for scanner.Scan() {
line := scanner.Text()
if line == "" {
if currentEvent != "" && currentData != "" {
// handle last event
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
currentEvent = ""
currentData = ""
}
continue
}
if strings.HasPrefix(line, "event:") {
currentEvent = strings.TrimSpace(line[6:])
continue
}
if strings.HasPrefix(line, "data:") {
currentData = strings.TrimSpace(line[5:])
continue
}
}
// Last event
if currentEvent != "" && currentData != "" {
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
}
if err := scanner.Err(); err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
helper.Done(c)
if usage.TotalTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
}
return usage, nil
}
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
switch event {
case "conversation.chat.completed":
// 将 data 解析为 CozeChatResponseData
var chatData CozeChatResponseData
err := json.Unmarshal([]byte(data), &chatData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
usage.PromptTokens = chatData.Usage.InputCount
usage.CompletionTokens = chatData.Usage.OutputCount
usage.TotalTokens = chatData.Usage.TokenCount
finishReason := "stop"
stopResponse := helper.GenerateStopResponse(id, common.GetTimestamp(), info.UpstreamModelName, finishReason)
helper.ObjectData(c, stopResponse)
case "conversation.message.delta":
// 将 data 解析为 CozeChatV3MessageDetail
var messageData CozeChatV3MessageDetail
err := json.Unmarshal([]byte(data), &messageData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
var content string
err = json.Unmarshal(messageData.Content, &content)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
*responseText += content
openaiResponse := dto.ChatCompletionsStreamResponse{
Id: id,
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: info.UpstreamModelName,
}
choice := dto.ChatCompletionsStreamResponseChoice{
Index: 0,
}
choice.Delta.SetContentString(content)
openaiResponse.Choices = append(openaiResponse.Choices, choice)
helper.ObjectData(c, openaiResponse)
case "error":
var errorData CozeError
err := json.Unmarshal([]byte(data), &errorData)
if err != nil {
common.SysError("error_unmarshalling_stream_response: " + err.Error())
return
}
common.SysError(fmt.Sprintf("stream event error: ", errorData.Code, errorData.Message))
}
}
func checkIfChatComplete(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (error, bool) {
requestURL := fmt.Sprintf("%s/v3/chat/retrieve", info.BaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
// 将 conversationId和chatId作为参数发送get请求
req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return err, false
}
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return err, false
}
resp, err := doRequest(req, info) // 调用 doRequest
if err != nil {
return err, false
}
if resp == nil { // 确保在 doRequest 失败时 resp 不为 nil 导致 panic
return fmt.Errorf("resp is nil"), false
}
defer resp.Body.Close() // 确保响应体被关闭
// 解析 resp 到 CozeChatResponse
var cozeResponse CozeChatResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("read response body failed: %w", err), false
}
err = json.Unmarshal(responseBody, &cozeResponse)
if err != nil {
return fmt.Errorf("unmarshal response body failed: %w", err), false
}
if cozeResponse.Data.Status == "completed" {
// 在上下文设置 usage
c.Set("coze_token_count", cozeResponse.Data.Usage.TokenCount)
c.Set("coze_output_count", cozeResponse.Data.Usage.OutputCount)
c.Set("coze_input_count", cozeResponse.Data.Usage.InputCount)
return nil, true
} else if cozeResponse.Data.Status == "failed" || cozeResponse.Data.Status == "canceled" || cozeResponse.Data.Status == "requires_action" {
return fmt.Errorf("chat status: %s", cozeResponse.Data.Status), false
} else {
return nil, false
}
}
func getChatDetail(a *Adaptor, c *gin.Context, info *relaycommon.RelayInfo) (*http.Response, error) {
requestURL := fmt.Sprintf("%s/v3/chat/message/list", info.BaseUrl)
requestURL = requestURL + "?conversation_id=" + c.GetString("coze_conversation_id") + "&chat_id=" + c.GetString("coze_chat_id")
req, err := http.NewRequest("GET", requestURL, nil)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = a.SetupRequestHeader(c, &req.Header, info)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := doRequest(req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func doRequest(req *http.Request, info *relaycommon.RelayInfo) (*http.Response, error) {
var client *http.Client
var err error // 声明 err 变量
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return nil, fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
resp, err := client.Do(req)
if err != nil { // 增加对 client.Do(req) 返回错误的检查
return nil, fmt.Errorf("client.Do failed: %w", err)
}
// _ = resp.Body.Close()
return resp, nil
}

View File

@@ -0,0 +1,100 @@
package deepseek
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
fimBaseUrl := info.BaseUrl
if !strings.HasSuffix(info.BaseUrl, "/beta") {
fimBaseUrl += "/beta"
}
switch info.RelayMode {
case constant.RelayModeCompletions:
return fmt.Sprintf("%s/completions", fimBaseUrl), nil
default:
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package deepseek
var ModelList = []string{
"deepseek-chat", "deepseek-reasoner",
}
var ChannelName = "deepseek"

View File

@@ -0,0 +1,115 @@
package dify
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
const (
BotTypeChatFlow = 1 // chatflow default
BotTypeAgent = 2
BotTypeWorkFlow = 3
BotTypeCompletion = 4
)
type Adaptor struct {
BotType int
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
//if strings.HasPrefix(info.UpstreamModelName, "agent") {
// a.BotType = BotTypeAgent
//} else if strings.HasPrefix(info.UpstreamModelName, "workflow") {
// a.BotType = BotTypeWorkFlow
//} else if strings.HasPrefix(info.UpstreamModelName, "chat") {
// a.BotType = BotTypeCompletion
//} else {
//}
a.BotType = BotTypeChatFlow
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch a.BotType {
case BotTypeWorkFlow:
return fmt.Sprintf("%s/v1/workflows/run", info.BaseUrl), nil
case BotTypeCompletion:
return fmt.Sprintf("%s/v1/completion-messages", info.BaseUrl), nil
case BotTypeAgent:
fallthrough
default:
return fmt.Sprintf("%s/v1/chat-messages", info.BaseUrl), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Dify(c, info, *request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
return difyStreamHandler(c, info, resp)
} else {
return difyHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,5 @@
package dify
var ModelList []string
var ChannelName = "dify"

45
relay/channel/dify/dto.go Normal file
View File

@@ -0,0 +1,45 @@
package dify
import "one-api/dto"
type DifyChatRequest struct {
Inputs map[string]interface{} `json:"inputs"`
Query string `json:"query"`
ResponseMode string `json:"response_mode"`
User string `json:"user"`
AutoGenerateName bool `json:"auto_generate_name"`
Files []DifyFile `json:"files"`
}
type DifyFile struct {
Type string `json:"type"`
TransferMode string `json:"transfer_mode"`
URL string `json:"url,omitempty"`
UploadFileId string `json:"upload_file_id,omitempty"`
}
type DifyMetaData struct {
Usage dto.Usage `json:"usage"`
}
type DifyData struct {
WorkflowId string `json:"workflow_id"`
NodeId string `json:"node_id"`
NodeType string `json:"node_type"`
Status string `json:"status"`
}
type DifyChatCompletionResponse struct {
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
CreateAt int64 `json:"create_at"`
MetaData DifyMetaData `json:"metadata"`
}
type DifyChunkChatCompletionResponse struct {
Event string `json:"event"`
ConversationId string `json:"conversation_id"`
Answer string `json:"answer"`
Data DifyData `json:"data"`
MetaData DifyMetaData `json:"metadata"`
}

View File

@@ -0,0 +1,289 @@
package dify
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"os"
"strings"
"github.com/gin-gonic/gin"
)
func uploadDifyFile(c *gin.Context, info *relaycommon.RelayInfo, user string, media dto.MediaContent) *DifyFile {
uploadUrl := fmt.Sprintf("%s/v1/files/upload", info.BaseUrl)
switch media.Type {
case dto.ContentTypeImageURL:
// Decode base64 data
imageMedia := media.GetImageMedia()
base64Data := imageMedia.Url
// Remove base64 prefix if exists (e.g., "data:image/jpeg;base64,")
if idx := strings.Index(base64Data, ","); idx != -1 {
base64Data = base64Data[idx+1:]
}
// Decode base64 string
decodedData, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
common.SysError("failed to decode base64: " + err.Error())
return nil
}
// Create temporary file
tempFile, err := os.CreateTemp("", "dify-upload-*")
if err != nil {
common.SysError("failed to create temp file: " + err.Error())
return nil
}
defer tempFile.Close()
defer os.Remove(tempFile.Name())
// Write decoded data to temp file
if _, err := tempFile.Write(decodedData); err != nil {
common.SysError("failed to write to temp file: " + err.Error())
return nil
}
// Create multipart form
body := &bytes.Buffer{}
writer := multipart.NewWriter(body)
// Add user field
if err := writer.WriteField("user", user); err != nil {
common.SysError("failed to add user field: " + err.Error())
return nil
}
// Create form file with proper mime type
mimeType := imageMedia.MimeType
if mimeType == "" {
mimeType = "image/jpeg" // default mime type
}
// Create form file
part, err := writer.CreateFormFile("file", fmt.Sprintf("image.%s", strings.TrimPrefix(mimeType, "image/")))
if err != nil {
common.SysError("failed to create form file: " + err.Error())
return nil
}
// Copy file content to form
if _, err = io.Copy(part, bytes.NewReader(decodedData)); err != nil {
common.SysError("failed to copy file content: " + err.Error())
return nil
}
writer.Close()
// Create HTTP request
req, err := http.NewRequest("POST", uploadUrl, body)
if err != nil {
common.SysError("failed to create request: " + err.Error())
return nil
}
req.Header.Set("Content-Type", writer.FormDataContentType())
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
// Send request
client := service.GetHttpClient()
resp, err := client.Do(req)
if err != nil {
common.SysError("failed to send request: " + err.Error())
return nil
}
defer resp.Body.Close()
// Parse response
var result struct {
Id string `json:"id"`
}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
common.SysError("failed to decode response: " + err.Error())
return nil
}
return &DifyFile{
UploadFileId: result.Id,
Type: "image",
TransferMode: "local_file",
}
}
return nil
}
func requestOpenAI2Dify(c *gin.Context, info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) *DifyChatRequest {
difyReq := DifyChatRequest{
Inputs: make(map[string]interface{}),
AutoGenerateName: false,
}
user := request.User
if user == "" {
user = helper.GetResponseID(c)
}
difyReq.User = user
files := make([]DifyFile, 0)
var content strings.Builder
for _, message := range request.Messages {
if message.Role == "system" {
content.WriteString("SYSTEM: \n" + message.StringContent() + "\n")
} else if message.Role == "assistant" {
content.WriteString("ASSISTANT: \n" + message.StringContent() + "\n")
} else {
parseContent := message.ParseContent()
for _, mediaContent := range parseContent {
switch mediaContent.Type {
case dto.ContentTypeText:
content.WriteString("USER: \n" + mediaContent.Text + "\n")
case dto.ContentTypeImageURL:
media := mediaContent.GetImageMedia()
var file *DifyFile
if media.IsRemoteImage() {
file.Type = media.MimeType
file.TransferMode = "remote_url"
file.URL = media.Url
} else {
file = uploadDifyFile(c, info, difyReq.User, mediaContent)
}
if file != nil {
files = append(files, *file)
}
}
}
}
}
difyReq.Query = content.String()
difyReq.Files = files
mode := "blocking"
if request.Stream {
mode = "streaming"
}
difyReq.ResponseMode = mode
return &difyReq
}
func streamResponseDify2OpenAI(difyResponse DifyChunkChatCompletionResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "dify",
}
var choice dto.ChatCompletionsStreamResponseChoice
if strings.HasPrefix(difyResponse.Event, "workflow_") {
if constant.DifyDebug {
text := "Workflow: " + difyResponse.Data.WorkflowId
if difyResponse.Event == "workflow_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if strings.HasPrefix(difyResponse.Event, "node_") {
if constant.DifyDebug {
text := "Node: " + difyResponse.Data.NodeType
if difyResponse.Event == "node_finished" {
text += " " + difyResponse.Data.Status
}
choice.Delta.SetReasoningContent(text + "\n")
}
} else if difyResponse.Event == "message" || difyResponse.Event == "agent_message" {
if difyResponse.Answer == "<details style=\"color:gray;background-color: #f8f8f8;padding: 8px;border-radius: 4px;\" open> <summary> Thinking... </summary>\n" {
difyResponse.Answer = "<think>"
} else if difyResponse.Answer == "</details>" {
difyResponse.Answer = "</think>"
}
choice.Delta.SetContentString(difyResponse.Answer)
}
response.Choices = append(response.Choices, choice)
return &response
}
func difyStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
usage := &dto.Usage{}
var nodeToken int
helper.SetEventStreamHeaders(c)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var difyResponse DifyChunkChatCompletionResponse
err := json.Unmarshal([]byte(data), &difyResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return true
}
var openaiResponse dto.ChatCompletionsStreamResponse
if difyResponse.Event == "message_end" {
usage = &difyResponse.MetaData.Usage
return false
} else if difyResponse.Event == "error" {
return false
} else {
openaiResponse = *streamResponseDify2OpenAI(difyResponse)
if len(openaiResponse.Choices) != 0 {
responseText += openaiResponse.Choices[0].Delta.GetContentString()
if openaiResponse.Choices[0].Delta.ReasoningContent != nil {
nodeToken += 1
}
}
}
err = helper.ObjectData(c, openaiResponse)
if err != nil {
common.SysError(err.Error())
}
return true
})
helper.Done(c)
if usage.TotalTokens == 0 {
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
}
usage.CompletionTokens += nodeToken
return usage, nil
}
func difyHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var difyResponse DifyChatCompletionResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &difyResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
fullTextResponse := dto.OpenAITextResponse{
Id: difyResponse.ConversationId,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: difyResponse.MetaData.Usage,
}
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: difyResponse.Answer,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return &difyResponse.MetaData.Usage, nil
}

View File

@@ -0,0 +1,271 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
if !strings.HasPrefix(info.UpstreamModelName, "imagen") {
return nil, errors.New("not supported model for image generation")
}
// convert size to aspect ratio
aspectRatio := "1:1" // default aspect ratio
switch request.Size {
case "1024x1024":
aspectRatio = "1:1"
case "1024x1792":
aspectRatio = "9:16"
case "1792x1024":
aspectRatio = "16:9"
}
// build gemini imagen request
geminiRequest := GeminiImageRequest{
Instances: []GeminiImageInstance{
{
Prompt: request.Prompt,
},
},
Parameters: GeminiImageParameters{
SampleCount: request.N,
AspectRatio: aspectRatio,
PersonGeneration: "allow_adult", // default allow adult
},
}
return geminiRequest, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
version := model_setting.GetGeminiVersionSetting(info.UpstreamModelName)
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return fmt.Sprintf("%s/%s/models/%s:predict", info.BaseUrl, version, info.UpstreamModelName), nil
}
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return fmt.Sprintf("%s/%s/models/%s:embedContent", info.BaseUrl, version, info.UpstreamModelName), nil
}
action := "generateContent"
if info.IsStream {
action = "streamGenerateContent?alt=sse"
}
return fmt.Sprintf("%s/%s/models/%s:%s", info.BaseUrl, version, info.UpstreamModelName, action), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("x-goog-api-key", info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
geminiRequest, err := CovertGemini2OpenAI(*request, info)
if err != nil {
return nil, err
}
return geminiRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
if request.Input == nil {
return nil, errors.New("input is required")
}
inputs := request.ParseInput()
if len(inputs) == 0 {
return nil, errors.New("input is empty")
}
// only process the first input
geminiRequest := GeminiEmbeddingRequest{
Content: GeminiChatContent{
Parts: []GeminiPart{
{
Text: inputs[0],
},
},
},
}
// set specific parameters for different models
// https://ai.google.dev/api/embeddings?hl=zh-cn#method:-models.embedcontent
switch info.UpstreamModelName {
case "text-embedding-004":
// except embedding-001 supports setting `OutputDimensionality`
if request.Dimensions > 0 {
geminiRequest.OutputDimensionality = request.Dimensions
}
}
return geminiRequest, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeGemini {
if info.IsStream {
return GeminiTextGenerationStreamHandler(c, info, resp)
} else {
return GeminiTextGenerationHandler(c, info, resp)
}
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return GeminiImageHandler(c, info, resp)
}
// check if the model is an embedding model
if strings.HasPrefix(info.UpstreamModelName, "text-embedding") ||
strings.HasPrefix(info.UpstreamModelName, "embedding") ||
strings.HasPrefix(info.UpstreamModelName, "gemini-embedding") {
return GeminiEmbeddingHandler(c, info, resp)
}
if info.IsStream {
return GeminiChatStreamHandler(c, info, resp)
} else {
return GeminiChatHandler(c, info, resp)
}
//if usage.(*dto.Usage).CompletionTokenDetails.ReasoningTokens > 100 {
// // 没有请求-thinking的情况下产生思考token则按照思考模型计费
// if !strings.HasSuffix(info.OriginModelName, "-thinking") &&
// !strings.HasSuffix(info.OriginModelName, "-nothinking") {
// thinkingModelName := info.OriginModelName + "-thinking"
// if operation_setting.SelfUseModeEnabled || helper.ContainPriceOrRatio(thinkingModelName) {
// info.OriginModelName = thinkingModelName
// }
// }
//}
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,37 @@
package gemini
var ModelList = []string{
// stable version
"gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.5-flash-8b",
"gemini-2.0-flash",
// latest version
"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
// preview version
"gemini-2.0-flash-lite-preview",
// gemini exp
"gemini-exp-1206",
// flash exp
"gemini-2.0-flash-exp",
// pro exp
"gemini-2.0-pro-exp",
// thinking exp
"gemini-2.0-flash-thinking-exp",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-03-25",
// imagen models
"imagen-3.0-generate-002",
// embedding models
"gemini-embedding-exp-03-07",
"text-embedding-004",
"embedding-001",
}
var SafetySettingList = []string{
"HARM_CATEGORY_HARASSMENT",
"HARM_CATEGORY_HATE_SPEECH",
"HARM_CATEGORY_SEXUALLY_EXPLICIT",
"HARM_CATEGORY_DANGEROUS_CONTENT",
"HARM_CATEGORY_CIVIC_INTEGRITY",
}
var ChannelName = "google gemini"

222
relay/channel/gemini/dto.go Normal file
View File

@@ -0,0 +1,222 @@
package gemini
import "encoding/json"
type GeminiChatRequest struct {
Contents []GeminiChatContent `json:"contents"`
SafetySettings []GeminiChatSafetySettings `json:"safetySettings,omitempty"`
GenerationConfig GeminiChatGenerationConfig `json:"generationConfig,omitempty"`
Tools []GeminiChatTool `json:"tools,omitempty"`
SystemInstructions *GeminiChatContent `json:"systemInstruction,omitempty"`
}
type GeminiThinkingConfig struct {
IncludeThoughts bool `json:"includeThoughts,omitempty"`
ThinkingBudget *int `json:"thinkingBudget,omitempty"`
}
func (c *GeminiThinkingConfig) SetThinkingBudget(budget int) {
c.ThinkingBudget = &budget
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
}
// UnmarshalJSON custom unmarshaler for GeminiInlineData to support snake_case and camelCase for MimeType
func (g *GeminiInlineData) UnmarshalJSON(data []byte) error {
type Alias GeminiInlineData // Use type alias to avoid recursion
var aux struct {
Alias
MimeTypeSnake string `json:"mime_type"`
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
*g = GeminiInlineData(aux.Alias) // Copy other fields if any in future
// Prioritize snake_case if present
if aux.MimeTypeSnake != "" {
g.MimeType = aux.MimeTypeSnake
} else if aux.MimeType != "" { // Fallback to camelCase from Alias
g.MimeType = aux.MimeType
}
// g.Data would be populated by aux.Alias.Data
return nil
}
type FunctionCall struct {
FunctionName string `json:"name"`
Arguments any `json:"args"`
}
type FunctionResponse struct {
Name string `json:"name"`
Response map[string]interface{} `json:"response"`
}
type GeminiPartExecutableCode struct {
Language string `json:"language,omitempty"`
Code string `json:"code,omitempty"`
}
type GeminiPartCodeExecutionResult struct {
Outcome string `json:"outcome,omitempty"`
Output string `json:"output,omitempty"`
}
type GeminiFileData struct {
MimeType string `json:"mimeType,omitempty"`
FileUri string `json:"fileUri,omitempty"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
Thought bool `json:"thought,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
FileData *GeminiFileData `json:"fileData,omitempty"`
ExecutableCode *GeminiPartExecutableCode `json:"executableCode,omitempty"`
CodeExecutionResult *GeminiPartCodeExecutionResult `json:"codeExecutionResult,omitempty"`
}
// UnmarshalJSON custom unmarshaler for GeminiPart to support snake_case and camelCase for InlineData
func (p *GeminiPart) UnmarshalJSON(data []byte) error {
// Alias to avoid recursion during unmarshalling
type Alias GeminiPart
var aux struct {
Alias
InlineDataSnake *GeminiInlineData `json:"inline_data,omitempty"` // snake_case variant
}
if err := json.Unmarshal(data, &aux); err != nil {
return err
}
// Assign fields from alias
*p = GeminiPart(aux.Alias)
// Prioritize snake_case for InlineData if present
if aux.InlineDataSnake != nil {
p.InlineData = aux.InlineDataSnake
} else if aux.InlineData != nil { // Fallback to camelCase from Alias
p.InlineData = aux.InlineData
}
// Other fields like Text, FunctionCall etc. are already populated via aux.Alias
return nil
}
type GeminiChatContent struct {
Role string `json:"role,omitempty"`
Parts []GeminiPart `json:"parts"`
}
type GeminiChatSafetySettings struct {
Category string `json:"category"`
Threshold string `json:"threshold"`
}
type GeminiChatTool struct {
GoogleSearch any `json:"googleSearch,omitempty"`
GoogleSearchRetrieval any `json:"googleSearchRetrieval,omitempty"`
CodeExecution any `json:"codeExecution,omitempty"`
FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}
type GeminiChatGenerationConfig struct {
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
StopSequences []string `json:"stopSequences,omitempty"`
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Seed int64 `json:"seed,omitempty"`
ResponseModalities []string `json:"responseModalities,omitempty"`
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
}
type GeminiChatCandidate struct {
Content GeminiChatContent `json:"content"`
FinishReason *string `json:"finishReason"`
Index int64 `json:"index"`
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatSafetyRating struct {
Category string `json:"category"`
Probability string `json:"probability"`
}
type GeminiChatPromptFeedback struct {
SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}
type GeminiChatResponse struct {
Candidates []GeminiChatCandidate `json:"candidates"`
PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
}
type GeminiUsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
ThoughtsTokenCount int `json:"thoughtsTokenCount"`
PromptTokensDetails []GeminiPromptTokensDetails `json:"promptTokensDetails"`
}
type GeminiPromptTokensDetails struct {
Modality string `json:"modality"`
TokenCount int `json:"tokenCount"`
}
// Imagen related structs
type GeminiImageRequest struct {
Instances []GeminiImageInstance `json:"instances"`
Parameters GeminiImageParameters `json:"parameters"`
}
type GeminiImageInstance struct {
Prompt string `json:"prompt"`
}
type GeminiImageParameters struct {
SampleCount int `json:"sampleCount,omitempty"`
AspectRatio string `json:"aspectRatio,omitempty"`
PersonGeneration string `json:"personGeneration,omitempty"`
}
type GeminiImageResponse struct {
Predictions []GeminiImagePrediction `json:"predictions"`
}
type GeminiImagePrediction struct {
MimeType string `json:"mimeType"`
BytesBase64Encoded string `json:"bytesBase64Encoded"`
RaiFilteredReason string `json:"raiFilteredReason,omitempty"`
SafetyAttributes any `json:"safetyAttributes,omitempty"`
}
// Embedding related structs
type GeminiEmbeddingRequest struct {
Content GeminiChatContent `json:"content"`
TaskType string `json:"taskType,omitempty"`
Title string `json:"title,omitempty"`
OutputDimensionality int `json:"outputDimensionality,omitempty"`
}
type GeminiEmbeddingResponse struct {
Embedding ContentEmbedding `json:"embedding"`
}
type ContentEmbedding struct {
Values []float64 `json:"values"`
}

View File

@@ -0,0 +1,138 @@
package gemini
import (
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if common.DebugEnabled {
println(string(responseBody))
}
// 解析为 Gemini 原生响应格式
var geminiResponse GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// 计算使用量(基于 UsageMetadata
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
// 直接返回 Gemini 原生格式的 JSON 响应
jsonResponse, err := common.Marshal(geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil
}
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var usage = &dto.Usage{}
var imageCount int
helper.SetEventStreamHeaders(c)
responseText := strings.Builder{}
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
// 统计图片数量
for _, candidate := range geminiResponse.Candidates {
for _, part := range candidate.Content.Parts {
if part.InlineData != nil && part.InlineData.MimeType != "" {
imageCount++
}
if part.Text != "" {
responseText.WriteString(part.Text)
}
}
}
// 更新使用量统计
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
}
// 直接发送 GeminiChatResponse 响应
err = helper.StringData(c, data)
if err != nil {
common.LogError(c, err.Error())
}
return true
})
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
// 如果usage.CompletionTokens为0则使用本地统计的completion tokens
if usage.CompletionTokens == 0 {
str := responseText.String()
if len(str) > 0 {
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
} else {
// 空补全,不需要使用量
usage = &dto.Usage{}
}
}
// 移除流式响应结尾的[Done]因为Gemini API没有发送Done的行为
//helper.Done(c)
return usage, nil
}

View File

@@ -0,0 +1,958 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/setting/model_setting"
"one-api/types"
"strconv"
"strings"
"unicode/utf8"
"github.com/gin-gonic/gin"
)
var geminiSupportedMimeTypes = map[string]bool{
"application/pdf": true,
"audio/mpeg": true,
"audio/mp3": true,
"audio/wav": true,
"image/png": true,
"image/jpeg": true,
"text/plain": true,
"video/mov": true,
"video/mpeg": true,
"video/mp4": true,
"video/mpg": true,
"video/avi": true,
"video/wmv": true,
"video/mpegps": true,
"video/flv": true,
}
// Gemini 允许的思考预算范围
const (
pro25MinBudget = 128
pro25MaxBudget = 32768
flash25MaxBudget = 24576
flash25LiteMinBudget = 512
flash25LiteMaxBudget = 24576
)
// clampThinkingBudget 根据模型名称将预算限制在允许的范围内
func clampThinkingBudget(modelName string, budget int) int {
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
is25FlashLite := strings.HasPrefix(modelName, "gemini-2.5-flash-lite")
if is25FlashLite {
if budget < flash25LiteMinBudget {
return flash25LiteMinBudget
}
if budget > flash25LiteMaxBudget {
return flash25LiteMaxBudget
}
} else if isNew25Pro {
if budget < pro25MinBudget {
return pro25MinBudget
}
if budget > pro25MaxBudget {
return pro25MaxBudget
}
} else { // 其他模型
if budget < 0 {
return 0
}
if budget > flash25MaxBudget {
return flash25MaxBudget
}
}
return budget
}
func ThinkingAdaptor(geminiRequest *GeminiChatRequest, info *relaycommon.RelayInfo) {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
modelName := info.UpstreamModelName
isNew25Pro := strings.HasPrefix(modelName, "gemini-2.5-pro") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-05-06") &&
!strings.HasPrefix(modelName, "gemini-2.5-pro-preview-03-25")
if strings.Contains(modelName, "-thinking-") {
parts := strings.SplitN(modelName, "-thinking-", 2)
if len(parts) == 2 && parts[1] != "" {
if budgetTokens, err := strconv.Atoi(parts[1]); err == nil {
clampedBudget := clampThinkingBudget(modelName, budgetTokens)
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(clampedBudget),
IncludeThoughts: true,
}
}
}
} else if strings.HasSuffix(modelName, "-thinking") {
unsupportedModels := []string{
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-pro-preview-03-25",
}
isUnsupported := false
for _, unsupportedModel := range unsupportedModels {
if strings.HasPrefix(modelName, unsupportedModel) {
isUnsupported = true
break
}
}
if isUnsupported {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
IncludeThoughts: true,
}
} else {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
IncludeThoughts: true,
}
if geminiRequest.GenerationConfig.MaxOutputTokens > 0 {
budgetTokens := model_setting.GetGeminiSettings().ThinkingAdapterBudgetTokensPercentage * float64(geminiRequest.GenerationConfig.MaxOutputTokens)
clampedBudget := clampThinkingBudget(modelName, int(budgetTokens))
geminiRequest.GenerationConfig.ThinkingConfig.ThinkingBudget = common.GetPointer(clampedBudget)
}
}
} else if strings.HasSuffix(modelName, "-nothinking") {
if !isNew25Pro {
geminiRequest.GenerationConfig.ThinkingConfig = &GeminiThinkingConfig{
ThinkingBudget: common.GetPointer(0),
}
}
}
}
}
// Setting safety to the lowest possible values since Gemini is already powerless enough
func CovertGemini2OpenAI(textRequest dto.GeneralOpenAIRequest, info *relaycommon.RelayInfo) (*GeminiChatRequest, error) {
geminiRequest := GeminiChatRequest{
Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
GenerationConfig: GeminiChatGenerationConfig{
Temperature: textRequest.Temperature,
TopP: textRequest.TopP,
MaxOutputTokens: textRequest.MaxTokens,
Seed: int64(textRequest.Seed),
},
}
if model_setting.IsGeminiModelSupportImagine(info.UpstreamModelName) {
geminiRequest.GenerationConfig.ResponseModalities = []string{
"TEXT",
"IMAGE",
}
}
ThinkingAdaptor(&geminiRequest, info)
safetySettings := make([]GeminiChatSafetySettings, 0, len(SafetySettingList))
for _, category := range SafetySettingList {
safetySettings = append(safetySettings, GeminiChatSafetySettings{
Category: category,
Threshold: model_setting.GetGeminiSafetySetting(category),
})
}
geminiRequest.SafetySettings = safetySettings
// openaiContent.FuncToToolCalls()
if textRequest.Tools != nil {
functions := make([]dto.FunctionRequest, 0, len(textRequest.Tools))
googleSearch := false
codeExecution := false
for _, tool := range textRequest.Tools {
if tool.Function.Name == "googleSearch" {
googleSearch = true
continue
}
if tool.Function.Name == "codeExecution" {
codeExecution = true
continue
}
if tool.Function.Parameters != nil {
params, ok := tool.Function.Parameters.(map[string]interface{})
if ok {
if props, hasProps := params["properties"].(map[string]interface{}); hasProps {
if len(props) == 0 {
tool.Function.Parameters = nil
}
}
}
}
// Clean the parameters before appending
cleanedParams := cleanFunctionParameters(tool.Function.Parameters)
tool.Function.Parameters = cleanedParams
functions = append(functions, tool.Function)
}
if codeExecution {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
CodeExecution: make(map[string]string),
})
}
if googleSearch {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
GoogleSearch: make(map[string]string),
})
}
if len(functions) > 0 {
geminiRequest.Tools = append(geminiRequest.Tools, GeminiChatTool{
FunctionDeclarations: functions,
})
}
// common.SysLog("tools: " + fmt.Sprintf("%+v", geminiRequest.Tools))
// json_data, _ := json.Marshal(geminiRequest.Tools)
// common.SysLog("tools_json: " + string(json_data))
}
if textRequest.ResponseFormat != nil && (textRequest.ResponseFormat.Type == "json_schema" || textRequest.ResponseFormat.Type == "json_object") {
geminiRequest.GenerationConfig.ResponseMimeType = "application/json"
if textRequest.ResponseFormat.JsonSchema != nil && textRequest.ResponseFormat.JsonSchema.Schema != nil {
cleanedSchema := removeAdditionalPropertiesWithDepth(textRequest.ResponseFormat.JsonSchema.Schema, 0)
geminiRequest.GenerationConfig.ResponseSchema = cleanedSchema
}
}
tool_call_ids := make(map[string]string)
var system_content []string
//shouldAddDummyModelMessage := false
for _, message := range textRequest.Messages {
if message.Role == "system" {
system_content = append(system_content, message.StringContent())
continue
} else if message.Role == "tool" || message.Role == "function" {
if len(geminiRequest.Contents) == 0 || geminiRequest.Contents[len(geminiRequest.Contents)-1].Role == "model" {
geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
Role: "user",
})
}
var parts = &geminiRequest.Contents[len(geminiRequest.Contents)-1].Parts
name := ""
if message.Name != nil {
name = *message.Name
} else if val, exists := tool_call_ids[message.ToolCallId]; exists {
name = val
}
var contentMap map[string]interface{}
contentStr := message.StringContent()
// 1. 尝试解析为 JSON 对象
if err := json.Unmarshal([]byte(contentStr), &contentMap); err != nil {
// 2. 如果失败,尝试解析为 JSON 数组
var contentSlice []interface{}
if err := json.Unmarshal([]byte(contentStr), &contentSlice); err == nil {
// 如果是数组,包装成对象
contentMap = map[string]interface{}{"result": contentSlice}
} else {
// 3. 如果再次失败,作为纯文本处理
contentMap = map[string]interface{}{"content": contentStr}
}
}
functionResp := &FunctionResponse{
Name: name,
Response: contentMap,
}
*parts = append(*parts, GeminiPart{
FunctionResponse: functionResp,
})
continue
}
var parts []GeminiPart
content := GeminiChatContent{
Role: message.Role,
}
// isToolCall := false
if message.ToolCalls != nil {
// message.Role = "model"
// isToolCall = true
for _, call := range message.ParseToolCalls() {
args := map[string]interface{}{}
if call.Function.Arguments != "" {
if json.Unmarshal([]byte(call.Function.Arguments), &args) != nil {
return nil, fmt.Errorf("invalid arguments for function %s, args: %s", call.Function.Name, call.Function.Arguments)
}
}
toolCall := GeminiPart{
FunctionCall: &FunctionCall{
FunctionName: call.Function.Name,
Arguments: args,
},
}
parts = append(parts, toolCall)
tool_call_ids[call.ID] = call.Function.Name
}
}
openaiContent := message.ParseContent()
imageNum := 0
for _, part := range openaiContent {
if part.Type == dto.ContentTypeText {
if part.Text == "" {
continue
}
parts = append(parts, GeminiPart{
Text: part.Text,
})
} else if part.Type == dto.ContentTypeImageURL {
imageNum += 1
if constant.GeminiVisionMaxImageNum != -1 && imageNum > constant.GeminiVisionMaxImageNum {
return nil, fmt.Errorf("too many images in the message, max allowed is %d", constant.GeminiVisionMaxImageNum)
}
// 判断是否是url
if strings.HasPrefix(part.GetImageMedia().Url, "http") {
// 是url获取文件的类型和base64编码的数据
fileData, err := service.GetFileBase64FromUrl(part.GetImageMedia().Url)
if err != nil {
return nil, fmt.Errorf("get file base64 from url '%s' failed: %w", part.GetImageMedia().Url, err)
}
// 校验 MimeType 是否在 Gemini 支持的白名单中
if _, ok := geminiSupportedMimeTypes[strings.ToLower(fileData.MimeType)]; !ok {
url := part.GetImageMedia().Url
return nil, fmt.Errorf("mime type is not supported by Gemini: '%s', url: '%s', supported types are: %v", fileData.MimeType, url, getSupportedMimeTypesList())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: fileData.MimeType, // 使用原始的 MimeType因为大小写可能对API有意义
Data: fileData.Base64Data,
},
})
} else {
format, base64String, err := service.DecodeBase64FileData(part.GetImageMedia().Url)
if err != nil {
return nil, fmt.Errorf("decode base64 image data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
}
} else if part.Type == dto.ContentTypeFile {
if part.GetFile().FileId != "" {
return nil, fmt.Errorf("only base64 file is supported in gemini")
}
format, base64String, err := service.DecodeBase64FileData(part.GetFile().FileData)
if err != nil {
return nil, fmt.Errorf("decode base64 file data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: format,
Data: base64String,
},
})
} else if part.Type == dto.ContentTypeInputAudio {
if part.GetInputAudio().Data == "" {
return nil, fmt.Errorf("only base64 audio is supported in gemini")
}
base64String, err := service.DecodeBase64AudioData(part.GetInputAudio().Data)
if err != nil {
return nil, fmt.Errorf("decode base64 audio data failed: %s", err.Error())
}
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: "audio/" + part.GetInputAudio().Format,
Data: base64String,
},
})
}
}
content.Parts = parts
// there's no assistant role in gemini and API shall vomit if Role is not user or model
if content.Role == "assistant" {
content.Role = "model"
}
if len(content.Parts) > 0 {
geminiRequest.Contents = append(geminiRequest.Contents, content)
}
}
if len(system_content) > 0 {
geminiRequest.SystemInstructions = &GeminiChatContent{
Parts: []GeminiPart{
{
Text: strings.Join(system_content, "\n"),
},
},
}
}
return &geminiRequest, nil
}
// Helper function to get a list of supported MIME types for error messages
func getSupportedMimeTypesList() []string {
keys := make([]string, 0, len(geminiSupportedMimeTypes))
for k := range geminiSupportedMimeTypes {
keys = append(keys, k)
}
return keys
}
// cleanFunctionParameters recursively removes unsupported fields from Gemini function parameters.
func cleanFunctionParameters(params interface{}) interface{} {
if params == nil {
return nil
}
switch v := params.(type) {
case map[string]interface{}:
// Create a copy to avoid modifying the original
cleanedMap := make(map[string]interface{})
for k, val := range v {
cleanedMap[k] = val
}
// Remove unsupported root-level fields
delete(cleanedMap, "default")
delete(cleanedMap, "exclusiveMaximum")
delete(cleanedMap, "exclusiveMinimum")
delete(cleanedMap, "$schema")
delete(cleanedMap, "additionalProperties")
// Check and clean 'format' for string types
if propType, typeExists := cleanedMap["type"].(string); typeExists && propType == "string" {
if formatValue, formatExists := cleanedMap["format"].(string); formatExists {
if formatValue != "enum" && formatValue != "date-time" {
delete(cleanedMap, "format")
}
}
}
// Clean properties
if props, ok := cleanedMap["properties"].(map[string]interface{}); ok && props != nil {
cleanedProps := make(map[string]interface{})
for propName, propValue := range props {
cleanedProps[propName] = cleanFunctionParameters(propValue)
}
cleanedMap["properties"] = cleanedProps
}
// Recursively clean items in arrays
if items, ok := cleanedMap["items"].(map[string]interface{}); ok && items != nil {
cleanedMap["items"] = cleanFunctionParameters(items)
}
// Also handle items if it's an array of schemas
if itemsArray, ok := cleanedMap["items"].([]interface{}); ok {
cleanedItemsArray := make([]interface{}, len(itemsArray))
for i, item := range itemsArray {
cleanedItemsArray[i] = cleanFunctionParameters(item)
}
cleanedMap["items"] = cleanedItemsArray
}
// Recursively clean other schema composition keywords
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := cleanedMap[field].([]interface{}); ok {
cleanedNested := make([]interface{}, len(nested))
for i, item := range nested {
cleanedNested[i] = cleanFunctionParameters(item)
}
cleanedMap[field] = cleanedNested
}
}
// Recursively clean patternProperties
if patternProps, ok := cleanedMap["patternProperties"].(map[string]interface{}); ok {
cleanedPatternProps := make(map[string]interface{})
for pattern, schema := range patternProps {
cleanedPatternProps[pattern] = cleanFunctionParameters(schema)
}
cleanedMap["patternProperties"] = cleanedPatternProps
}
// Recursively clean definitions
if definitions, ok := cleanedMap["definitions"].(map[string]interface{}); ok {
cleanedDefinitions := make(map[string]interface{})
for defName, defSchema := range definitions {
cleanedDefinitions[defName] = cleanFunctionParameters(defSchema)
}
cleanedMap["definitions"] = cleanedDefinitions
}
// Recursively clean $defs (newer JSON Schema draft)
if defs, ok := cleanedMap["$defs"].(map[string]interface{}); ok {
cleanedDefs := make(map[string]interface{})
for defName, defSchema := range defs {
cleanedDefs[defName] = cleanFunctionParameters(defSchema)
}
cleanedMap["$defs"] = cleanedDefs
}
// Clean conditional keywords
for _, field := range []string{"if", "then", "else", "not"} {
if nested, ok := cleanedMap[field]; ok {
cleanedMap[field] = cleanFunctionParameters(nested)
}
}
return cleanedMap
case []interface{}:
// Handle arrays of schemas
cleanedArray := make([]interface{}, len(v))
for i, item := range v {
cleanedArray[i] = cleanFunctionParameters(item)
}
return cleanedArray
default:
// Not a map or array, return as is (e.g., could be a primitive)
return params
}
}
func removeAdditionalPropertiesWithDepth(schema interface{}, depth int) interface{} {
if depth >= 5 {
return schema
}
v, ok := schema.(map[string]interface{})
if !ok || len(v) == 0 {
return schema
}
// 删除所有的title字段
delete(v, "title")
delete(v, "$schema")
// 如果type不为object和array则直接返回
if typeVal, exists := v["type"]; !exists || (typeVal != "object" && typeVal != "array") {
return schema
}
switch v["type"] {
case "object":
delete(v, "additionalProperties")
// 处理 properties
if properties, ok := v["properties"].(map[string]interface{}); ok {
for key, value := range properties {
properties[key] = removeAdditionalPropertiesWithDepth(value, depth+1)
}
}
for _, field := range []string{"allOf", "anyOf", "oneOf"} {
if nested, ok := v[field].([]interface{}); ok {
for i, item := range nested {
nested[i] = removeAdditionalPropertiesWithDepth(item, depth+1)
}
}
}
case "array":
if items, ok := v["items"].(map[string]interface{}); ok {
v["items"] = removeAdditionalPropertiesWithDepth(items, depth+1)
}
}
return v
}
func unescapeString(s string) (string, error) {
var result []rune
escaped := false
i := 0
for i < len(s) {
r, size := utf8.DecodeRuneInString(s[i:]) // 正确解码UTF-8字符
if r == utf8.RuneError {
return "", fmt.Errorf("invalid UTF-8 encoding")
}
if escaped {
// 如果是转义符后的字符,检查其类型
switch r {
case '"':
result = append(result, '"')
case '\\':
result = append(result, '\\')
case '/':
result = append(result, '/')
case 'b':
result = append(result, '\b')
case 'f':
result = append(result, '\f')
case 'n':
result = append(result, '\n')
case 'r':
result = append(result, '\r')
case 't':
result = append(result, '\t')
case '\'':
result = append(result, '\'')
default:
// 如果遇到一个非法的转义字符,直接按原样输出
result = append(result, '\\', r)
}
escaped = false
} else {
if r == '\\' {
escaped = true // 记录反斜杠作为转义符
} else {
result = append(result, r)
}
}
i += size // 移动到下一个字符
}
return string(result), nil
}
func unescapeMapOrSlice(data interface{}) interface{} {
switch v := data.(type) {
case map[string]interface{}:
for k, val := range v {
v[k] = unescapeMapOrSlice(val)
}
case []interface{}:
for i, val := range v {
v[i] = unescapeMapOrSlice(val)
}
case string:
if unescaped, err := unescapeString(v); err != nil {
return v
} else {
return unescaped
}
}
return data
}
func getResponseToolCall(item *GeminiPart) *dto.ToolCallResponse {
var argsBytes []byte
var err error
if result, ok := item.FunctionCall.Arguments.(map[string]interface{}); ok {
argsBytes, err = json.Marshal(unescapeMapOrSlice(result))
} else {
argsBytes, err = json.Marshal(item.FunctionCall.Arguments)
}
if err != nil {
return nil
}
return &dto.ToolCallResponse{
ID: fmt.Sprintf("call_%s", common.GetUUID()),
Type: "function",
Function: dto.FunctionResponse{
Arguments: string(argsBytes),
Name: item.FunctionCall.FunctionName,
},
}
}
func responseGeminiChat2OpenAI(c *gin.Context, response *GeminiChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: helper.GetResponseID(c),
Object: "chat.completion",
Created: common.GetTimestamp(),
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
isToolCall := false
for _, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: int(candidate.Index),
Message: dto.Message{
Role: "assistant",
Content: "",
},
FinishReason: constant.FinishReasonStop,
}
if len(candidate.Content.Parts) > 0 {
var texts []string
var toolCalls []dto.ToolCallResponse
for _, part := range candidate.Content.Parts {
if part.FunctionCall != nil {
choice.FinishReason = constant.FinishReasonToolCalls
if call := getResponseToolCall(&part); call != nil {
toolCalls = append(toolCalls, *call)
}
} else if part.Thought {
choice.Message.ReasoningContent = part.Text
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```")
} else {
// 过滤掉空行
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
if len(toolCalls) > 0 {
choice.Message.SetToolCalls(toolCalls)
isToolCall = true
}
choice.Message.SetStringContent(strings.Join(texts, "\n"))
}
if candidate.FinishReason != nil {
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = constant.FinishReasonLength
default:
choice.FinishReason = constant.FinishReasonContentFilter
}
}
if isToolCall {
choice.FinishReason = constant.FinishReasonToolCalls
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) (*dto.ChatCompletionsStreamResponse, bool, bool) {
choices := make([]dto.ChatCompletionsStreamResponseChoice, 0, len(geminiResponse.Candidates))
isStop := false
hasImage := false
for _, candidate := range geminiResponse.Candidates {
if candidate.FinishReason != nil && *candidate.FinishReason == "STOP" {
isStop = true
candidate.FinishReason = nil
}
choice := dto.ChatCompletionsStreamResponseChoice{
Index: int(candidate.Index),
Delta: dto.ChatCompletionsStreamResponseChoiceDelta{
Role: "assistant",
},
}
var texts []string
isTools := false
isThought := false
if candidate.FinishReason != nil {
// p := GeminiConvertFinishReason(*candidate.FinishReason)
switch *candidate.FinishReason {
case "STOP":
choice.FinishReason = &constant.FinishReasonStop
case "MAX_TOKENS":
choice.FinishReason = &constant.FinishReasonLength
default:
choice.FinishReason = &constant.FinishReasonContentFilter
}
}
for _, part := range candidate.Content.Parts {
if part.InlineData != nil {
if strings.HasPrefix(part.InlineData.MimeType, "image") {
imgText := "![image](data:" + part.InlineData.MimeType + ";base64," + part.InlineData.Data + ")"
texts = append(texts, imgText)
hasImage = true
}
} else if part.FunctionCall != nil {
isTools = true
if call := getResponseToolCall(&part); call != nil {
call.SetIndex(len(choice.Delta.ToolCalls))
choice.Delta.ToolCalls = append(choice.Delta.ToolCalls, *call)
}
} else if part.Thought {
isThought = true
texts = append(texts, part.Text)
} else {
if part.ExecutableCode != nil {
texts = append(texts, "```"+part.ExecutableCode.Language+"\n"+part.ExecutableCode.Code+"\n```\n")
} else if part.CodeExecutionResult != nil {
texts = append(texts, "```output\n"+part.CodeExecutionResult.Output+"\n```\n")
} else {
if part.Text != "\n" {
texts = append(texts, part.Text)
}
}
}
}
if isThought {
choice.Delta.SetReasoningContent(strings.Join(texts, "\n"))
} else {
choice.Delta.SetContentString(strings.Join(texts, "\n"))
}
if isTools {
choice.FinishReason = &constant.FinishReasonToolCalls
}
choices = append(choices, choice)
}
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Choices = choices
return &response, isStop, hasImage
}
func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
// responseText := ""
id := helper.GetResponseID(c)
createAt := common.GetTimestamp()
var usage = &dto.Usage{}
var imageCount int
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
var geminiResponse GeminiChatResponse
err := common.UnmarshalJsonStr(data, &geminiResponse)
if err != nil {
common.LogError(c, "error unmarshalling stream response: "+err.Error())
return false
}
response, isStop, hasImage := streamResponseGeminiChat2OpenAI(&geminiResponse)
if hasImage {
imageCount++
}
response.Id = id
response.Created = createAt
response.Model = info.UpstreamModelName
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
}
err = helper.ObjectData(c, response)
if err != nil {
common.LogError(c, err.Error())
}
if isStop {
response := helper.GenerateStopResponse(id, createAt, info.UpstreamModelName, constant.FinishReasonStop)
helper.ObjectData(c, response)
}
return true
})
var response *dto.ChatCompletionsStreamResponse
if imageCount != 0 {
if usage.CompletionTokens == 0 {
usage.CompletionTokens = imageCount * 258
}
}
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
if info.ShouldIncludeUsage {
response = helper.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
err := helper.ObjectData(c, response)
if err != nil {
common.SysError("send final response failed: " + err.Error())
}
}
helper.Done(c)
//resp.Body.Close()
return usage, nil
}
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
println(string(responseBody))
}
var geminiResponse GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Candidates) == 0 {
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
usage := dto.Usage{
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
}
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
if detail.Modality == "AUDIO" {
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
} else if detail.Modality == "TEXT" {
usage.PromptTokensDetails.TextTokens = detail.TokenCount
}
}
fullTextResponse.Usage = usage
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.Write(jsonResponse)
return &usage, nil
}
func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
var geminiResponse GeminiEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
// convert to openai format response
openAIResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: []dto.OpenAIEmbeddingResponseItem{
{
Object: "embedding",
Embedding: geminiResponse.Embedding.Values,
Index: 0,
},
},
Model: info.UpstreamModelName,
}
// calculate usage
// https://ai.google.dev/gemini-api/docs/pricing?hl=zh-cn#text-embedding-004
// Google has not yet clarified how embedding models will be billed
// refer to openai billing method to use input tokens billing
// https://platform.openai.com/docs/guides/embeddings#what-are-embeddings
usage := &dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: 0,
TotalTokens: info.PromptTokens,
}
openAIResponse.Usage = *usage
jsonResponse, jsonErr := common.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}

View File

@@ -0,0 +1,136 @@
package jimeng
import (
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/?Action=CVProcess&Version=2022-08-31", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
return errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
type LogoInfo struct {
AddLogo bool `json:"add_logo,omitempty"`
Position int `json:"position,omitempty"`
Language int `json:"language,omitempty"`
Opacity float64 `json:"opacity,omitempty"`
LogoTextContent string `json:"logo_text_content,omitempty"`
}
type imageRequestPayload struct {
ReqKey string `json:"req_key"` // Service identifier, fixed value: jimeng_high_aes_general_v21_L
Prompt string `json:"prompt"` // Prompt for image generation, supports both Chinese and English
Seed int64 `json:"seed,omitempty"` // Random seed, default -1 (random)
Width int `json:"width,omitempty"` // Image width, default 512, range [256, 768]
Height int `json:"height,omitempty"` // Image height, default 512, range [256, 768]
UsePreLLM bool `json:"use_pre_llm,omitempty"` // Enable text expansion, default true
UseSR bool `json:"use_sr,omitempty"` // Enable super resolution, default true
ReturnURL bool `json:"return_url,omitempty"` // Whether to return image URL (valid for 24 hours)
LogoInfo LogoInfo `json:"logo_info,omitempty"` // Watermark information
ImageUrls []string `json:"image_urls,omitempty"` // Image URLs for input
BinaryData []string `json:"binary_data_base64,omitempty"` // Base64 encoded binary data
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
payload := imageRequestPayload{
ReqKey: request.Model,
Prompt: request.Prompt,
}
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
payload.ReturnURL = true // Default to returning image URLs
}
if len(request.ExtraFields) > 0 {
if err := json.Unmarshal(request.ExtraFields, &payload); err != nil {
return nil, fmt.Errorf("failed to unmarshal extra fields: %w", err)
}
}
return payload, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
fullRequestURL, err := a.GetRequestURL(info)
if err != nil {
return nil, fmt.Errorf("get request url failed: %w", err)
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return nil, fmt.Errorf("new request failed: %w", err)
}
err = Sign(c, req, info.ApiKey)
if err != nil {
return nil, fmt.Errorf("setup request header failed: %w", err)
}
resp, err := channel.DoRequest(c, req, info)
if err != nil {
return nil, fmt.Errorf("do request failed: %w", err)
}
return resp, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == relayconstant.RelayModeImagesGenerations {
usage, err = jimengImageHandler(c, resp, info)
} else if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package jimeng
const (
ChannelName = "jimeng"
)
var ModelList = []string{
"jimeng_high_aes_general_v21_L",
}

View File

@@ -0,0 +1,89 @@
package jimeng
import (
"encoding/json"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
type ImageResponse struct {
Code int `json:"code"`
Message string `json:"message"`
Data struct {
BinaryDataBase64 []string `json:"binary_data_base64"`
ImageUrls []string `json:"image_urls"`
RephraseResult string `json:"rephraser_result"`
RequestID string `json:"request_id"`
// Other fields are omitted for brevity
} `json:"data"`
RequestID string `json:"request_id"`
Status int `json:"status"`
TimeElapsed string `json:"time_elapsed"`
}
func responseJimeng2OpenAIImage(_ *gin.Context, response *ImageResponse, info *relaycommon.RelayInfo) *dto.ImageResponse {
imageResponse := dto.ImageResponse{
Created: info.StartTime.Unix(),
}
for _, base64Data := range response.Data.BinaryDataBase64 {
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
B64Json: base64Data,
})
}
for _, imageUrl := range response.Data.ImageUrls {
imageResponse.Data = append(imageResponse.Data, dto.ImageData{
Url: imageUrl,
})
}
return &imageResponse
}
// jimengImageHandler handles the Jimeng image generation response
func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
var jimengResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// Check if the response indicates an error
if jimengResponse.Code != 10000 {
return nil, types.WithOpenAIError(types.OpenAIError{
Message: jimengResponse.Message,
Type: "jimeng_error",
Param: "",
Code: fmt.Sprintf("%d", jimengResponse.Code),
}, resp.StatusCode)
}
// Convert Jimeng response to OpenAI format
fullTextResponse := responseJimeng2OpenAIImage(c, &jimengResponse, info)
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = c.Writer.Write(jsonResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
return &dto.Usage{}, nil
}

View File

@@ -0,0 +1,176 @@
package jimeng
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"net/url"
"one-api/common"
"sort"
"strings"
"time"
)
// SignRequestForJimeng 对即梦 API 请求进行签名,支持 http.Request 或 header+url+body 方式
//func SignRequestForJimeng(req *http.Request, accessKey, secretKey string) error {
// var bodyBytes []byte
// var err error
//
// if req.Body != nil {
// bodyBytes, err = io.ReadAll(req.Body)
// if err != nil {
// return fmt.Errorf("read request body failed: %w", err)
// }
// _ = req.Body.Close()
// req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // rewind
// } else {
// bodyBytes = []byte{}
// }
//
// return signJimengHeaders(&req.Header, req.Method, req.URL, bodyBytes, accessKey, secretKey)
//}
const HexPayloadHashKey = "HexPayloadHash"
func SetPayloadHash(c *gin.Context, req any) error {
body, err := json.Marshal(req)
if err != nil {
return err
}
common.LogInfo(c, fmt.Sprintf("SetPayloadHash body: %s", body))
payloadHash := sha256.Sum256(body)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
c.Set(HexPayloadHashKey, hexPayloadHash)
return nil
}
func getPayloadHash(c *gin.Context) string {
return c.GetString(HexPayloadHashKey)
}
func Sign(c *gin.Context, req *http.Request, apiKey string) error {
header := req.Header
var bodyBytes []byte
var err error
if req.Body != nil {
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return err
}
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
}
payloadHash := sha256.Sum256(bodyBytes)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
method := c.Request.Method
u := req.URL
keyParts := strings.Split(apiKey, "|")
if len(keyParts) != 2 {
return errors.New("invalid api key format for jimeng: expected 'ak|sk'")
}
accessKey := strings.TrimSpace(keyParts[0])
secretKey := strings.TrimSpace(keyParts[1])
t := time.Now().UTC()
xDate := t.Format("20060102T150405Z")
shortDate := t.Format("20060102")
host := u.Host
header.Set("Host", host)
header.Set("X-Date", xDate)
header.Set("X-Content-Sha256", hexPayloadHash)
// Sort and encode query parameters to create canonical query string
queryParams := u.Query()
sortedKeys := make([]string, 0, len(queryParams))
for k := range queryParams {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
var queryParts []string
for _, k := range sortedKeys {
values := queryParams[k]
sort.Strings(values)
for _, v := range values {
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
}
}
canonicalQueryString := strings.Join(queryParts, "&")
headersToSign := map[string]string{
"host": host,
"x-date": xDate,
"x-content-sha256": hexPayloadHash,
}
if header.Get("Content-Type") == "" {
header.Set("Content-Type", "application/json")
}
headersToSign["content-type"] = header.Get("Content-Type")
var signedHeaderKeys []string
for k := range headersToSign {
signedHeaderKeys = append(signedHeaderKeys, k)
}
sort.Strings(signedHeaderKeys)
var canonicalHeaders strings.Builder
for _, k := range signedHeaderKeys {
canonicalHeaders.WriteString(k)
canonicalHeaders.WriteString(":")
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
canonicalHeaders.WriteString("\n")
}
signedHeaders := strings.Join(signedHeaderKeys, ";")
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
method,
u.Path,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
hexPayloadHash,
)
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
region := "cn-north-1"
serviceName := "cv"
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
xDate,
credentialScope,
hexHashedCanonicalRequest,
)
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(serviceName))
kSigning := hmacSHA256(kService, []byte("request"))
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
accessKey,
credentialScope,
signedHeaders,
signature,
)
header.Set("Authorization", authorization)
return nil
}
// hmacSHA256 计算 HMAC-SHA256
func hmacSHA256(key []byte, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}

View File

@@ -0,0 +1,92 @@
package jina
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.RelayMode == constant.RelayModeRerank {
usage, err = common_handler.RerankHandler(c, info, resp)
} else if info.RelayMode == constant.RelayModeEmbeddings {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package jina
var ModelList = []string{
"jina-clip-v1",
"jina-reranker-v2-base-multilingual",
"jina-reranker-m0",
}
var ChannelName = "jina"

View File

@@ -0,0 +1 @@
package jina

View File

@@ -0,0 +1,9 @@
package lingyiwanwu
// https://platform.lingyiwanwu.com/docs
var ModelList = []string{
"yi-large", "yi-medium", "yi-vision", "yi-medium-200k", "yi-spark", "yi-large-rag", "yi-large-turbo", "yi-large-preview", "yi-large-rag-preview",
}
var ChannelName = "lingyiwanwu"

View File

@@ -0,0 +1,13 @@
package minimax
// https://www.minimaxi.com/document/guides/chat-model/V2?id=65e0736ab2845de20908e2dd
var ModelList = []string{
"abab6.5-chat",
"abab6.5s-chat",
"abab6-chat",
"abab5.5-chat",
"abab5.5s-chat",
}
var ChannelName = "minimax"

View File

@@ -0,0 +1,10 @@
package minimax
import (
"fmt"
relaycommon "one-api/relay/common"
)
func GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1/text/chatcompletion_v2", info.BaseUrl), nil
}

View File

@@ -0,0 +1,88 @@
package mistral
import (
"errors"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Mistral(request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,12 @@
package mistral
var ModelList = []string{
"open-mistral-7b",
"open-mixtral-8x7b",
"mistral-small-latest",
"mistral-medium-latest",
"mistral-large-latest",
"mistral-embed",
}
var ChannelName = "mistral"

View File

@@ -0,0 +1,78 @@
package mistral
import (
"one-api/common"
"one-api/dto"
"regexp"
)
var mistralToolCallIdRegexp = regexp.MustCompile("^[a-zA-Z0-9]{9}$")
func requestOpenAI2Mistral(request *dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
idMap := make(map[string]string)
for _, message := range request.Messages {
// 1. tool_calls.id
toolCalls := message.ParseToolCalls()
if toolCalls != nil {
for i := range toolCalls {
if !mistralToolCallIdRegexp.MatchString(toolCalls[i].ID) {
if newId, ok := idMap[toolCalls[i].ID]; ok {
toolCalls[i].ID = newId
} else {
newId, err := common.GenerateRandomCharsKey(9)
if err == nil {
idMap[toolCalls[i].ID] = newId
toolCalls[i].ID = newId
}
}
}
}
message.SetToolCalls(toolCalls)
}
// 2. tool_call_id
if message.ToolCallId != "" {
if newId, ok := idMap[message.ToolCallId]; ok {
message.ToolCallId = newId
} else {
if !mistralToolCallIdRegexp.MatchString(message.ToolCallId) {
newId, err := common.GenerateRandomCharsKey(9)
if err == nil {
idMap[message.ToolCallId] = newId
message.ToolCallId = newId
}
}
}
}
mediaMessages := message.ParseContent()
if message.Role == "assistant" && message.ToolCalls != nil && message.Content == "" {
mediaMessages = []dto.MediaContent{}
}
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.GetImageMedia()
mediaMessage.ImageUrl = imageUrl.Url
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
Tools: request.Tools,
ToolChoice: request.ToolChoice,
}
}

View File

@@ -0,0 +1,106 @@
package mokaai
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return request, nil
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
// https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
suffix := "chat/"
if strings.HasPrefix(info.UpstreamModelName, "m3e") {
suffix = "embeddings"
}
fullRequestURL := fmt.Sprintf("%s/%s", info.BaseUrl, suffix)
return fullRequestURL, nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
switch info.RelayMode {
case constant.RelayModeEmbeddings:
baiduEmbeddingRequest := embeddingRequestOpenAI2Moka(*request)
return baiduEmbeddingRequest, nil
default:
return nil, errors.New("not implemented")
}
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeEmbeddings:
return mokaEmbeddingHandler(c, info, resp)
default:
// err, usage = mokaHandler(c, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,9 @@
package mokaai
var ModelList = []string{
"m3e-large",
"m3e-base",
"m3e-small",
}
var ChannelName = "mokaai"

View File

@@ -0,0 +1,82 @@
package mokaai
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
func embeddingRequestOpenAI2Moka(request dto.GeneralOpenAIRequest) *dto.EmbeddingRequest {
var input []string // Change input to []string
switch v := request.Input.(type) {
case string:
input = []string{v} // Convert string to []string
case []string:
input = v // Already a []string, no conversion needed
case []interface{}:
for _, part := range v {
if str, ok := part.(string); ok {
input = append(input, str) // Append each string to the slice
}
}
}
return &dto.EmbeddingRequest{
Input: input,
Model: request.Model,
}
}
func embeddingResponseMoka2OpenAI(response *dto.EmbeddingResponse) *dto.OpenAIEmbeddingResponse {
openAIEmbeddingResponse := dto.OpenAIEmbeddingResponse{
Object: "list",
Data: make([]dto.OpenAIEmbeddingResponseItem, 0, len(response.Data)),
Model: "baidu-embedding",
Usage: response.Usage,
}
for _, item := range response.Data {
openAIEmbeddingResponse.Data = append(openAIEmbeddingResponse.Data, dto.OpenAIEmbeddingResponseItem{
Object: item.Object,
Index: item.Index,
Embedding: item.Embedding,
})
}
return &openAIEmbeddingResponse
}
func mokaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var baiduResponse dto.EmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &baiduResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// if baiduResponse.ErrorMsg != "" {
// return &dto.OpenAIErrorWithStatusCode{
// Error: dto.OpenAIError{
// Type: "baidu_error",
// Param: "",
// },
// StatusCode: resp.StatusCode,
// }, nil
// }
fullTextResponse := embeddingResponseMoka2OpenAI(&baiduResponse)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}

View File

@@ -0,0 +1,9 @@
package moonshot
var ModelList = []string{
"moonshot-v1-8k",
"moonshot-v1-32k",
"moonshot-v1-128k",
}
var ChannelName = "moonshot"

View File

@@ -0,0 +1,97 @@
package ollama
import (
"errors"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case relayconstant.RelayModeEmbeddings:
return info.BaseUrl + "/api/embed", nil
default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return requestOpenAI2Ollama(*request)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return requestOpenAI2Embeddings(request), nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
if info.RelayMode == relayconstant.RelayModeEmbeddings {
usage, err = ollamaEmbeddingHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package ollama
var ModelList = []string{
"llama3-7b",
}
var ChannelName = "ollama"

View File

@@ -0,0 +1,45 @@
package ollama
import "one-api/dto"
type OllamaRequest struct {
Model string `json:"model,omitempty"`
Messages []dto.Message `json:"messages,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
Seed float64 `json:"seed,omitempty"`
Topp float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Stop any `json:"stop,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
Tools []dto.ToolCallRequest `json:"tools,omitempty"`
ResponseFormat any `json:"response_format,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
Suffix any `json:"suffix,omitempty"`
StreamOptions *dto.StreamOptions `json:"stream_options,omitempty"`
Prompt any `json:"prompt,omitempty"`
}
type Options struct {
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}
type OllamaEmbeddingRequest struct {
Model string `json:"model,omitempty"`
Input []string `json:"input"`
Options *Options `json:"options,omitempty"`
}
type OllamaEmbeddingResponse struct {
Error string `json:"error,omitempty"`
Model string `json:"model"`
Embedding [][]float64 `json:"embeddings,omitempty"`
}

View File

@@ -0,0 +1,132 @@
package ollama
import (
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func requestOpenAI2Ollama(request dto.GeneralOpenAIRequest) (*OllamaRequest, error) {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
if !message.IsStringContent() {
mediaMessages := message.ParseContent()
for j, mediaMessage := range mediaMessages {
if mediaMessage.Type == dto.ContentTypeImageURL {
imageUrl := mediaMessage.GetImageMedia()
// check if not base64
if strings.HasPrefix(imageUrl.Url, "http") {
fileData, err := service.GetFileBase64FromUrl(imageUrl.Url)
if err != nil {
return nil, err
}
imageUrl.Url = fmt.Sprintf("data:%s;base64,%s", fileData.MimeType, fileData.Base64Data)
}
mediaMessage.ImageUrl = imageUrl
mediaMessages[j] = mediaMessage
}
}
message.SetMediaContent(mediaMessages)
}
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
ToolCalls: message.ToolCalls,
ToolCallId: message.ToolCallId,
})
}
str, ok := request.Stop.(string)
var Stop []string
if ok {
Stop = []string{str}
} else {
Stop, _ = request.Stop.([]string)
}
return &OllamaRequest{
Model: request.Model,
Messages: messages,
Stream: request.Stream,
Temperature: request.Temperature,
Seed: request.Seed,
Topp: request.TopP,
TopK: request.TopK,
Stop: Stop,
Tools: request.Tools,
MaxTokens: request.MaxTokens,
ResponseFormat: request.ResponseFormat,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
Prompt: request.Prompt,
StreamOptions: request.StreamOptions,
Suffix: request.Suffix,
}, nil
}
func requestOpenAI2Embeddings(request dto.EmbeddingRequest) *OllamaEmbeddingRequest {
return &OllamaEmbeddingRequest{
Model: request.Model,
Input: request.ParseInput(),
Options: &Options{
Seed: int(request.Seed),
Temperature: request.Temperature,
TopP: request.TopP,
FrequencyPenalty: request.FrequencyPenalty,
PresencePenalty: request.PresencePenalty,
},
}
}
func ollamaEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var ollamaEmbeddingResponse OllamaEmbeddingResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.CloseResponseBodyGracefully(resp)
err = common.Unmarshal(responseBody, &ollamaEmbeddingResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if ollamaEmbeddingResponse.Error != "" {
return nil, types.NewError(fmt.Errorf("ollama error: %s", ollamaEmbeddingResponse.Error), types.ErrorCodeBadResponseBody)
}
flattenedEmbeddings := flattenEmbeddings(ollamaEmbeddingResponse.Embedding)
data := make([]dto.OpenAIEmbeddingResponseItem, 0, 1)
data = append(data, dto.OpenAIEmbeddingResponseItem{
Embedding: flattenedEmbeddings,
Object: "embedding",
})
usage := &dto.Usage{
TotalTokens: info.PromptTokens,
CompletionTokens: 0,
PromptTokens: info.PromptTokens,
}
embeddingResponse := &dto.OpenAIEmbeddingResponse{
Object: "list",
Data: data,
Model: info.UpstreamModelName,
Usage: *usage,
}
doResponseBody, err := common.Marshal(embeddingResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
common.IOCopyBytesGracefully(c, resp, doResponseBody)
return usage, nil
}
func flattenEmbeddings(embeddings [][]float64) []float64 {
flattened := []float64{}
for _, row := range embeddings {
flattened = append(flattened, row...)
}
return flattened
}

View File

@@ -0,0 +1,491 @@
package openai
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/ai360"
"one-api/relay/channel/lingyiwanwu"
"one-api/relay/channel/minimax"
"one-api/relay/channel/moonshot"
"one-api/relay/channel/openrouter"
"one-api/relay/channel/xinference"
relaycommon "one-api/relay/common"
"one-api/relay/common_handler"
relayconstant "one-api/relay/constant"
"one-api/service"
"one-api/types"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
ChannelType int
ResponseFormat string
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
//if !strings.Contains(request.Model, "claude") {
// return nil, fmt.Errorf("you are using openai channel type with path /v1/messages, only claude model supported convert, but got %s", request.Model)
//}
aiRequest, err := service.ClaudeToOpenAIRequest(*request, info)
if err != nil {
return nil, err
}
if info.SupportStreamOptions {
aiRequest.StreamOptions = &dto.StreamOptions{
IncludeUsage: true,
}
}
return a.ConvertOpenAIRequest(c, info, aiRequest)
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.ChannelType = info.ChannelType
// initialize ThinkingContentInfo when thinking_to_content is enabled
if info.ChannelSetting.ThinkingToContent {
info.ThinkingContentInfo = relaycommon.ThinkingContentInfo{
IsFirstThinkingContent: true,
SendLastThinkingContent: false,
HasSentThinkingContent: false,
}
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayFormat == relaycommon.RelayFormatClaude {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
}
if info.RelayMode == relayconstant.RelayModeRealtime {
if strings.HasPrefix(info.BaseUrl, "https://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "https://")
baseUrl = "wss://" + baseUrl
info.BaseUrl = baseUrl
} else if strings.HasPrefix(info.BaseUrl, "http://") {
baseUrl := strings.TrimPrefix(info.BaseUrl, "http://")
baseUrl = "ws://" + baseUrl
info.BaseUrl = baseUrl
}
}
switch info.ChannelType {
case constant.ChannelTypeAzure:
apiVersion := info.ApiVersion
if apiVersion == "" {
apiVersion = constant.AzureDefaultAPIVersion
}
// https://learn.microsoft.com/en-us/azure/cognitive-services/openai/chatgpt-quickstart?pivots=rest-api&tabs=command-line#rest-api
requestURL := strings.Split(info.RequestURLPath, "?")[0]
requestURL = fmt.Sprintf("%s?api-version=%s", requestURL, apiVersion)
task := strings.TrimPrefix(requestURL, "/v1/")
// 特殊处理 responses API
if info.RelayMode == relayconstant.RelayModeResponses {
requestURL = fmt.Sprintf("/openai/v1/responses?api-version=preview")
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
}
model_ := info.UpstreamModelName
// 2025年5月10日后创建的渠道不移除.
if info.ChannelCreateTime < constant.AzureNoRemoveDotTime {
model_ = strings.Replace(model_, ".", "", -1)
}
// https://github.com/songquanpeng/one-api/issues/67
requestURL = fmt.Sprintf("/openai/deployments/%s/%s", model_, task)
if info.RelayMode == relayconstant.RelayModeRealtime {
requestURL = fmt.Sprintf("/openai/realtime?deployment=%s&api-version=%s", model_, apiVersion)
}
return relaycommon.GetFullRequestURL(info.BaseUrl, requestURL, info.ChannelType), nil
case constant.ChannelTypeMiniMax:
return minimax.GetRequestURL(info)
case constant.ChannelTypeCustom:
url := info.BaseUrl
url = strings.Replace(url, "{model}", info.UpstreamModelName, -1)
return url, nil
default:
return relaycommon.GetFullRequestURL(info.BaseUrl, info.RequestURLPath, info.ChannelType), nil
}
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, header *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, header)
if info.ChannelType == constant.ChannelTypeAzure {
header.Set("api-key", info.ApiKey)
return nil
}
if info.ChannelType == constant.ChannelTypeOpenAI && "" != info.Organization {
header.Set("OpenAI-Organization", info.Organization)
}
if info.RelayMode == relayconstant.RelayModeRealtime {
swp := c.Request.Header.Get("Sec-WebSocket-Protocol")
if swp != "" {
items := []string{
"realtime",
"openai-insecure-api-key." + info.ApiKey,
"openai-beta.realtime-v1",
}
header.Set("Sec-WebSocket-Protocol", strings.Join(items, ","))
//req.Header.Set("Sec-WebSocket-Key", c.Request.Header.Get("Sec-WebSocket-Key"))
//req.Header.Set("Sec-Websocket-Extensions", c.Request.Header.Get("Sec-Websocket-Extensions"))
//req.Header.Set("Sec-Websocket-Version", c.Request.Header.Get("Sec-Websocket-Version"))
} else {
header.Set("openai-beta", "realtime=v1")
header.Set("Authorization", "Bearer "+info.ApiKey)
}
} else {
header.Set("Authorization", "Bearer "+info.ApiKey)
}
if info.ChannelType == constant.ChannelTypeOpenRouter {
header.Set("HTTP-Referer", "https://www.newapi.ai")
header.Set("X-Title", "New API")
}
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if info.ChannelType != constant.ChannelTypeOpenAI && info.ChannelType != constant.ChannelTypeAzure {
request.StreamOptions = nil
}
if info.ChannelType == constant.ChannelTypeOpenRouter {
if len(request.Usage) == 0 {
request.Usage = json.RawMessage(`{"include":true}`)
}
}
if strings.HasPrefix(request.Model, "o") {
if request.MaxCompletionTokens == 0 && request.MaxTokens != 0 {
request.MaxCompletionTokens = request.MaxTokens
request.MaxTokens = 0
}
request.Temperature = nil
if strings.HasSuffix(request.Model, "-high") {
request.ReasoningEffort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.ReasoningEffort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.ReasoningEffort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
info.ReasoningEffort = request.ReasoningEffort
info.UpstreamModelName = request.Model
// o系列模型developer适配o1-mini除外
if !strings.HasPrefix(request.Model, "o1-mini") && !strings.HasPrefix(request.Model, "o1-preview") {
//修改第一个Message的内容将system改为developer
if len(request.Messages) > 0 && request.Messages[0].Role == "system" {
request.Messages[0].Role = "developer"
}
}
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
a.ResponseFormat = request.ResponseFormat
if info.RelayMode == relayconstant.RelayModeAudioSpeech {
jsonData, err := json.Marshal(request)
if err != nil {
return nil, fmt.Errorf("error marshalling object: %w", err)
}
return bytes.NewReader(jsonData), nil
} else {
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// 添加文件字段
file, header, err := c.Request.FormFile("file")
if err != nil {
return nil, errors.New("file is required")
}
defer file.Close()
part, err := writer.CreateFormFile("file", header.Filename)
if err != nil {
return nil, errors.New("create form file failed")
}
if _, err := io.Copy(part, file); err != nil {
return nil, errors.New("copy file failed")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return &requestBody, nil
}
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case relayconstant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// 模型后缀转换 reasoning effort
if strings.HasSuffix(request.Model, "-high") {
request.Reasoning.Effort = "high"
request.Model = strings.TrimSuffix(request.Model, "-high")
} else if strings.HasSuffix(request.Model, "-low") {
request.Reasoning.Effort = "low"
request.Model = strings.TrimSuffix(request.Model, "-low")
} else if strings.HasSuffix(request.Model, "-medium") {
request.Reasoning.Effort = "medium"
request.Model = strings.TrimSuffix(request.Model, "-medium")
}
return request, nil
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
if info.RelayMode == relayconstant.RelayModeAudioTranscription ||
info.RelayMode == relayconstant.RelayModeAudioTranslation ||
info.RelayMode == relayconstant.RelayModeImagesEdits {
return channel.DoFormRequest(a, c, info, requestBody)
} else if info.RelayMode == relayconstant.RelayModeRealtime {
return channel.DoWssRequest(a, c, info, requestBody)
} else {
return channel.DoApiRequest(a, c, info, requestBody)
}
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case relayconstant.RelayModeRealtime:
err, usage = OpenaiRealtimeHandler(c, info)
case relayconstant.RelayModeAudioSpeech:
usage = OpenaiTTSHandler(c, resp, info)
case relayconstant.RelayModeAudioTranslation:
fallthrough
case relayconstant.RelayModeAudioTranscription:
err, usage = OpenaiSTTHandler(c, resp, info, a.ResponseFormat)
case relayconstant.RelayModeImagesGenerations, relayconstant.RelayModeImagesEdits:
usage, err = OpenaiHandlerWithUsage(c, info, resp)
case relayconstant.RelayModeRerank:
usage, err = common_handler.RerankHandler(c, info, resp)
case relayconstant.RelayModeResponses:
if info.IsStream {
usage, err = OaiResponsesStreamHandler(c, info, resp)
} else {
usage, err = OaiResponsesHandler(c, info, resp)
}
default:
if info.IsStream {
usage, err = OaiStreamHandler(c, info, resp)
} else {
usage, err = OpenaiHandler(c, info, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
switch a.ChannelType {
case constant.ChannelType360:
return ai360.ModelList
case constant.ChannelTypeMoonshot:
return moonshot.ModelList
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ModelList
case constant.ChannelTypeMiniMax:
return minimax.ModelList
case constant.ChannelTypeXinference:
return xinference.ModelList
case constant.ChannelTypeOpenRouter:
return openrouter.ModelList
default:
return ModelList
}
}
func (a *Adaptor) GetChannelName() string {
switch a.ChannelType {
case constant.ChannelType360:
return ai360.ChannelName
case constant.ChannelTypeMoonshot:
return moonshot.ChannelName
case constant.ChannelTypeLingYiWanWu:
return lingyiwanwu.ChannelName
case constant.ChannelTypeMiniMax:
return minimax.ChannelName
case constant.ChannelTypeXinference:
return xinference.ChannelName
case constant.ChannelTypeOpenRouter:
return openrouter.ChannelName
default:
return ChannelName
}
}

View File

@@ -0,0 +1,35 @@
package openai
var ModelList = []string{
"gpt-3.5-turbo", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-1106", "gpt-3.5-turbo-0125",
"gpt-3.5-turbo-16k", "gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-instruct",
"gpt-4", "gpt-4-0613", "gpt-4-1106-preview", "gpt-4-0125-preview",
"gpt-4-32k", "gpt-4-32k-0613",
"gpt-4-turbo-preview", "gpt-4-turbo", "gpt-4-turbo-2024-04-09",
"gpt-4-vision-preview",
"chatgpt-4o-latest",
"gpt-4o", "gpt-4o-2024-05-13", "gpt-4o-2024-08-06", "gpt-4o-2024-11-20",
"gpt-4o-mini", "gpt-4o-mini-2024-07-18",
"gpt-4.5-preview", "gpt-4.5-preview-2025-02-27",
"o1-preview", "o1-preview-2024-09-12",
"o1-mini", "o1-mini-2024-09-12",
"o3-mini", "o3-mini-2025-01-31",
"o3-mini-high", "o3-mini-2025-01-31-high",
"o3-mini-low", "o3-mini-2025-01-31-low",
"o3-mini-medium", "o3-mini-2025-01-31-medium",
"o1", "o1-2024-12-17",
"gpt-4o-audio-preview", "gpt-4o-audio-preview-2024-10-01",
"gpt-4o-realtime-preview", "gpt-4o-realtime-preview-2024-10-01", "gpt-4o-realtime-preview-2024-12-17",
"gpt-4o-mini-realtime-preview", "gpt-4o-mini-realtime-preview-2024-12-17",
"text-embedding-ada-002", "text-embedding-3-small", "text-embedding-3-large",
"text-curie-001", "text-babbage-001", "text-ada-001",
"text-moderation-latest", "text-moderation-stable",
"text-davinci-edit-001",
"davinci-002", "babbage-002",
"dall-e-3",
"whisper-1",
"tts-1", "tts-1-1106", "tts-1-hd", "tts-1-hd-1106",
}
var ChannelName = "openai"

View File

@@ -0,0 +1,196 @@
package openai
import (
"encoding/json"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant"
"one-api/relay/helper"
"one-api/service"
"strings"
"github.com/gin-gonic/gin"
)
// 辅助函数
func handleStreamFormat(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
info.SendResponseCount++
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
return sendStreamData(c, info, data, forceFormat, thinkToContent)
case relaycommon.RelayFormatClaude:
return handleClaudeFormat(c, data, info)
}
return nil
}
func handleClaudeFormat(c *gin.Context, data string, info *relaycommon.RelayInfo) error {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(data), &streamResponse); err != nil {
return err
}
if streamResponse.Usage != nil {
info.ClaudeConvertInfo.Usage = streamResponse.Usage
}
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
return nil
}
func ProcessStreamResponse(streamResponse dto.ChatCompletionsStreamResponse, responseTextBuilder *strings.Builder, toolCount *int) error {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
return nil
}
func processTokens(relayMode int, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
streamResp := "[" + strings.Join(streamItems, ",") + "]"
switch relayMode {
case relayconstant.RelayModeChatCompletions:
return processChatCompletions(streamResp, streamItems, responseTextBuilder, toolCount)
case relayconstant.RelayModeCompletions:
return processCompletions(streamResp, streamItems, responseTextBuilder)
}
return nil
}
func processChatCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder, toolCount *int) error {
var streamResponses []dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
return err
}
if err := ProcessStreamResponse(streamResponse, responseTextBuilder, toolCount); err != nil {
common.SysError("error processing stream response: " + err.Error())
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Delta.GetContentString())
responseTextBuilder.WriteString(choice.Delta.GetReasoningContent())
if choice.Delta.ToolCalls != nil {
if len(choice.Delta.ToolCalls) > *toolCount {
*toolCount = len(choice.Delta.ToolCalls)
}
for _, tool := range choice.Delta.ToolCalls {
responseTextBuilder.WriteString(tool.Function.Name)
responseTextBuilder.WriteString(tool.Function.Arguments)
}
}
}
}
return nil
}
func processCompletions(streamResp string, streamItems []string, responseTextBuilder *strings.Builder) error {
var streamResponses []dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(streamResp), &streamResponses); err != nil {
// 一次性解析失败,逐个解析
common.SysError("error unmarshalling stream response: " + err.Error())
for _, item := range streamItems {
var streamResponse dto.CompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(item), &streamResponse); err != nil {
continue
}
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
// 批量处理所有响应
for _, streamResponse := range streamResponses {
for _, choice := range streamResponse.Choices {
responseTextBuilder.WriteString(choice.Text)
}
}
return nil
}
func handleLastResponse(lastStreamData string, responseId *string, createAt *int64,
systemFingerprint *string, model *string, usage **dto.Usage,
containStreamUsage *bool, info *relaycommon.RelayInfo,
shouldSendLastResp *bool) error {
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &lastStreamResponse); err != nil {
return err
}
*responseId = lastStreamResponse.Id
*createAt = lastStreamResponse.Created
*systemFingerprint = lastStreamResponse.GetSystemFingerprint()
*model = lastStreamResponse.Model
if service.ValidUsage(lastStreamResponse.Usage) {
*containStreamUsage = true
*usage = lastStreamResponse.Usage
if !info.ShouldIncludeUsage {
*shouldSendLastResp = false
}
}
return nil
}
func handleFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, lastStreamData string,
responseId string, createAt int64, model string, systemFingerprint string,
usage *dto.Usage, containStreamUsage bool) {
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if info.ShouldIncludeUsage && !containStreamUsage {
response := helper.GenerateFinalUsageResponse(responseId, createAt, model, *usage)
response.SetSystemFingerprint(systemFingerprint)
helper.ObjectData(c, response)
}
helper.Done(c)
case relaycommon.RelayFormatClaude:
info.ClaudeConvertInfo.Done = true
var streamResponse dto.ChatCompletionsStreamResponse
if err := json.Unmarshal(common.StringToByteSlice(lastStreamData), &streamResponse); err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
return
}
info.ClaudeConvertInfo.Usage = usage
claudeResponses := service.StreamResponseOpenAI2Claude(&streamResponse, info)
for _, resp := range claudeResponses {
helper.ClaudeData(c, *resp)
}
}
}
func sendResponsesStreamData(c *gin.Context, streamResponse dto.ResponsesStreamResponse, data string) {
if data == "" {
return
}
helper.ResponseChunkData(c, streamResponse, data)
}

View File

@@ -0,0 +1,587 @@
package openai
import (
"bytes"
"fmt"
"io"
"math"
"mime/multipart"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"os"
"path/filepath"
"strings"
"one-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
)
func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, forceFormat bool, thinkToContent bool) error {
if data == "" {
return nil
}
if !forceFormat && !thinkToContent {
return helper.StringData(c, data)
}
var lastStreamResponse dto.ChatCompletionsStreamResponse
if err := common.UnmarshalJsonStr(data, &lastStreamResponse); err != nil {
return err
}
if !thinkToContent {
return helper.ObjectData(c, lastStreamResponse)
}
hasThinkingContent := false
hasContent := false
var thinkingContent strings.Builder
for _, choice := range lastStreamResponse.Choices {
if len(choice.Delta.GetReasoningContent()) > 0 {
hasThinkingContent = true
thinkingContent.WriteString(choice.Delta.GetReasoningContent())
}
if len(choice.Delta.GetContentString()) > 0 {
hasContent = true
}
}
// Handle think to content conversion
if info.ThinkingContentInfo.IsFirstThinkingContent {
if hasThinkingContent {
response := lastStreamResponse.Copy()
for i := range response.Choices {
// send `think` tag with thinking content
response.Choices[i].Delta.SetContentString("<think>\n" + thinkingContent.String())
response.Choices[i].Delta.ReasoningContent = nil
response.Choices[i].Delta.Reasoning = nil
}
info.ThinkingContentInfo.IsFirstThinkingContent = false
info.ThinkingContentInfo.HasSentThinkingContent = true
return helper.ObjectData(c, response)
}
}
if lastStreamResponse.Choices == nil || len(lastStreamResponse.Choices) == 0 {
return helper.ObjectData(c, lastStreamResponse)
}
// Process each choice
for i, choice := range lastStreamResponse.Choices {
// Handle transition from thinking to content
// only send `</think>` tag when previous thinking content has been sent
if hasContent && !info.ThinkingContentInfo.SendLastThinkingContent && info.ThinkingContentInfo.HasSentThinkingContent {
response := lastStreamResponse.Copy()
for j := range response.Choices {
response.Choices[j].Delta.SetContentString("\n</think>\n")
response.Choices[j].Delta.ReasoningContent = nil
response.Choices[j].Delta.Reasoning = nil
}
info.ThinkingContentInfo.SendLastThinkingContent = true
helper.ObjectData(c, response)
}
// Convert reasoning content to regular content if any
if len(choice.Delta.GetReasoningContent()) > 0 {
lastStreamResponse.Choices[i].Delta.SetContentString(choice.Delta.GetReasoningContent())
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
lastStreamResponse.Choices[i].Delta.Reasoning = nil
} else if !hasThinkingContent && !hasContent {
// flush thinking content
lastStreamResponse.Choices[i].Delta.ReasoningContent = nil
lastStreamResponse.Choices[i].Delta.Reasoning = nil
}
}
return helper.ObjectData(c, lastStreamResponse)
}
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}
defer common.CloseResponseBodyGracefully(resp)
model := info.UpstreamModelName
var responseId string
var createAt int64 = 0
var systemFingerprint string
var containStreamUsage bool
var responseTextBuilder strings.Builder
var toolCount int
var usage = &dto.Usage{}
var streamItems []string // store stream items
var forceFormat bool
var thinkToContent bool
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if info.ChannelSetting.ThinkingToContent {
thinkToContent = true
}
var (
lastStreamData string
)
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
if lastStreamData != "" {
err := handleStreamFormat(c, info, lastStreamData, forceFormat, thinkToContent)
if err != nil {
common.SysError("error handling stream format: " + err.Error())
}
}
lastStreamData = data
streamItems = append(streamItems, data)
return true
})
// 处理最后的响应
shouldSendLastResp := true
if err := handleLastResponse(lastStreamData, &responseId, &createAt, &systemFingerprint, &model, &usage,
&containStreamUsage, info, &shouldSendLastResp); err != nil {
common.SysError("error handling last response: " + err.Error())
}
if shouldSendLastResp && info.RelayFormat == relaycommon.RelayFormatOpenAI {
_ = sendStreamData(c, info, lastStreamData, forceFormat, thinkToContent)
}
// 处理token计算
if err := processTokens(info.RelayMode, streamItems, &responseTextBuilder, &toolCount); err != nil {
common.SysError("error processing tokens: " + err.Error())
}
if !containStreamUsage {
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
usage.CompletionTokens += toolCount * 7
} else {
if info.ChannelType == constant.ChannelTypeDeepSeek {
if usage.PromptCacheHitTokens != 0 {
usage.PromptTokensDetails.CachedTokens = usage.PromptCacheHitTokens
}
}
}
handleFinalResponse(c, info, lastStreamData, responseId, createAt, model, systemFingerprint, usage, containStreamUsage)
return usage, nil
}
func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
}
forceFormat := false
if info.ChannelSetting.ForceFormat {
forceFormat = true
}
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
completionTokens := 0
for _, choice := range simpleResponse.Choices {
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
completionTokens += ctkm
}
simpleResponse.Usage = dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
}
switch info.RelayFormat {
case relaycommon.RelayFormatOpenAI:
if forceFormat {
responseBody, err = common.Marshal(simpleResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
} else {
break
}
case relaycommon.RelayFormatClaude:
claudeResp := service.ResponseOpenAI2Claude(&simpleResponse, info)
claudeRespStr, err := common.Marshal(claudeResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
responseBody = claudeRespStr
}
common.IOCopyBytesGracefully(c, resp, responseBody)
return &simpleResponse.Usage, nil
}
func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) *dto.Usage {
// the status code has been judged before, if there is a body reading failure,
// it should be regarded as a non-recoverable error, so it should not return err for external retry.
// Analogous to nginx's load balancing, it will only retry if it can't be requested or
// if the upstream returns a specific status code, once the upstream has already written the header,
// the subsequent failure of the response body should be regarded as a non-recoverable error,
// and can be terminated directly.
defer common.CloseResponseBodyGracefully(resp)
usage := &dto.Usage{}
usage.PromptTokens = info.PromptTokens
usage.TotalTokens = info.PromptTokens
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
c.Writer.WriteHeaderNow()
_, err := io.Copy(c.Writer, resp.Body)
if err != nil {
common.LogError(c, err.Error())
}
return usage
}
func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, responseFormat string) (*types.NewAPIError, *dto.Usage) {
defer common.CloseResponseBodyGracefully(resp)
// count tokens by audio file duration
audioTokens, err := countAudioTokens(c)
if err != nil {
return types.NewError(err, types.ErrorCodeCountTokenFailed), nil
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
usage := &dto.Usage{}
usage.PromptTokens = audioTokens
usage.CompletionTokens = 0
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
return nil, usage
}
func countAudioTokens(c *gin.Context) (int, error) {
body, err := common.GetRequestBody(c)
if err != nil {
return 0, errors.WithStack(err)
}
var reqBody struct {
File *multipart.FileHeader `form:"file" binding:"required"`
}
c.Request.Body = io.NopCloser(bytes.NewReader(body))
if err = c.ShouldBind(&reqBody); err != nil {
return 0, errors.WithStack(err)
}
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
reqFp, err := reqBody.File.Open()
if err != nil {
return 0, errors.WithStack(err)
}
defer reqFp.Close()
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
if err != nil {
return 0, errors.WithStack(err)
}
defer os.Remove(tmpFp.Name())
_, err = io.Copy(tmpFp, reqFp)
if err != nil {
return 0, errors.WithStack(err)
}
if err = tmpFp.Close(); err != nil {
return 0, errors.WithStack(err)
}
duration, err := common.GetAudioDuration(c.Request.Context(), tmpFp.Name(), ext)
if err != nil {
return 0, errors.WithStack(err)
}
return int(math.Round(math.Ceil(duration) / 60.0 * 1000)), nil // 1 minute 相当于 1k tokens
}
func OpenaiRealtimeHandler(c *gin.Context, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.RealtimeUsage) {
if info == nil || info.ClientWs == nil || info.TargetWs == nil {
return types.NewError(fmt.Errorf("invalid websocket connection"), types.ErrorCodeBadResponse), nil
}
info.IsStream = true
clientConn := info.ClientWs
targetConn := info.TargetWs
clientClosed := make(chan struct{})
targetClosed := make(chan struct{})
sendChan := make(chan []byte, 100)
receiveChan := make(chan []byte, 100)
errChan := make(chan error, 2)
usage := &dto.RealtimeUsage{}
localUsage := &dto.RealtimeUsage{}
sumUsage := &dto.RealtimeUsage{}
gopool.Go(func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("panic in client reader: %v", r)
}
}()
for {
select {
case <-c.Done():
return
default:
_, message, err := clientConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from client: %v", err)
}
close(clientClosed)
return
}
realtimeEvent := &dto.RealtimeEvent{}
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdate {
if realtimeEvent.Session != nil {
if realtimeEvent.Session.Tools != nil {
info.RealtimeTools = realtimeEvent.Session.Tools
}
}
}
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = helper.WssString(c, targetConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to target: %v", err)
return
}
select {
case sendChan <- message:
default:
}
}
}
})
gopool.Go(func() {
defer func() {
if r := recover(); r != nil {
errChan <- fmt.Errorf("panic in target reader: %v", r)
}
}()
for {
select {
case <-c.Done():
return
default:
_, message, err := targetConn.ReadMessage()
if err != nil {
if !websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
errChan <- fmt.Errorf("error reading from target: %v", err)
}
close(targetClosed)
return
}
info.SetFirstResponseTime()
realtimeEvent := &dto.RealtimeEvent{}
err = common.Unmarshal(message, realtimeEvent)
if err != nil {
errChan <- fmt.Errorf("error unmarshalling message: %v", err)
return
}
if realtimeEvent.Type == dto.RealtimeEventTypeResponseDone {
realtimeUsage := realtimeEvent.Response.Usage
if realtimeUsage != nil {
usage.TotalTokens += realtimeUsage.TotalTokens
usage.InputTokens += realtimeUsage.InputTokens
usage.OutputTokens += realtimeUsage.OutputTokens
usage.InputTokenDetails.AudioTokens += realtimeUsage.InputTokenDetails.AudioTokens
usage.InputTokenDetails.CachedTokens += realtimeUsage.InputTokenDetails.CachedTokens
usage.InputTokenDetails.TextTokens += realtimeUsage.InputTokenDetails.TextTokens
usage.OutputTokenDetails.AudioTokens += realtimeUsage.OutputTokenDetails.AudioTokens
usage.OutputTokenDetails.TextTokens += realtimeUsage.OutputTokenDetails.TextTokens
err := preConsumeUsage(c, info, usage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
usage = &dto.RealtimeUsage{}
localUsage = &dto.RealtimeUsage{}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
info.IsFirstRequest = false
localUsage.InputTokens += textToken + audioToken
localUsage.InputTokenDetails.TextTokens += textToken
localUsage.InputTokenDetails.AudioTokens += audioToken
err = preConsumeUsage(c, info, localUsage, sumUsage)
if err != nil {
errChan <- fmt.Errorf("error consume usage: %v", err)
return
}
// 本次计费完成,清除
localUsage = &dto.RealtimeUsage{}
// print now usage
}
common.LogInfo(c, fmt.Sprintf("realtime streaming sumUsage: %v", sumUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
common.LogInfo(c, fmt.Sprintf("realtime streaming localUsage: %v", localUsage))
} else if realtimeEvent.Type == dto.RealtimeEventTypeSessionUpdated || realtimeEvent.Type == dto.RealtimeEventTypeSessionCreated {
realtimeSession := realtimeEvent.Session
if realtimeSession != nil {
// update audio format
info.InputAudioFormat = common.GetStringIfEmpty(realtimeSession.InputAudioFormat, info.InputAudioFormat)
info.OutputAudioFormat = common.GetStringIfEmpty(realtimeSession.OutputAudioFormat, info.OutputAudioFormat)
}
} else {
textToken, audioToken, err := service.CountTokenRealtime(info, *realtimeEvent, info.UpstreamModelName)
if err != nil {
errChan <- fmt.Errorf("error counting text token: %v", err)
return
}
common.LogInfo(c, fmt.Sprintf("type: %s, textToken: %d, audioToken: %d", realtimeEvent.Type, textToken, audioToken))
localUsage.TotalTokens += textToken + audioToken
localUsage.OutputTokens += textToken + audioToken
localUsage.OutputTokenDetails.TextTokens += textToken
localUsage.OutputTokenDetails.AudioTokens += audioToken
}
err = helper.WssString(c, clientConn, string(message))
if err != nil {
errChan <- fmt.Errorf("error writing to client: %v", err)
return
}
select {
case receiveChan <- message:
default:
}
}
}
})
select {
case <-clientClosed:
case <-targetClosed:
case err := <-errChan:
//return service.OpenAIErrorWrapper(err, "realtime_error", http.StatusInternalServerError), nil
common.LogError(c, "realtime error: "+err.Error())
case <-c.Done():
}
if usage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, usage, sumUsage)
}
if localUsage.TotalTokens != 0 {
_ = preConsumeUsage(c, info, localUsage, sumUsage)
}
// check usage total tokens, if 0, use local usage
return nil, sumUsage
}
func preConsumeUsage(ctx *gin.Context, info *relaycommon.RelayInfo, usage *dto.RealtimeUsage, totalUsage *dto.RealtimeUsage) error {
if usage == nil || totalUsage == nil {
return fmt.Errorf("invalid usage pointer")
}
totalUsage.TotalTokens += usage.TotalTokens
totalUsage.InputTokens += usage.InputTokens
totalUsage.OutputTokens += usage.OutputTokens
totalUsage.InputTokenDetails.CachedTokens += usage.InputTokenDetails.CachedTokens
totalUsage.InputTokenDetails.TextTokens += usage.InputTokenDetails.TextTokens
totalUsage.InputTokenDetails.AudioTokens += usage.InputTokenDetails.AudioTokens
totalUsage.OutputTokenDetails.TextTokens += usage.OutputTokenDetails.TextTokens
totalUsage.OutputTokenDetails.AudioTokens += usage.OutputTokenDetails.AudioTokens
// clear usage
err := service.PreWssConsumeQuota(ctx, info, usage)
return err
}
func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
// Once we've written to the client, we should not return errors anymore
// because the upstream has already consumed resources and returned content
// We should still perform billing even if parsing fails
// format
if usageResp.InputTokens > 0 {
usageResp.PromptTokens += usageResp.InputTokens
}
if usageResp.OutputTokens > 0 {
usageResp.CompletionTokens += usageResp.OutputTokens
}
if usageResp.InputTokensDetails != nil {
usageResp.PromptTokensDetails.ImageTokens += usageResp.InputTokensDetails.ImageTokens
usageResp.PromptTokensDetails.TextTokens += usageResp.InputTokensDetails.TextTokens
}
return &usageResp.Usage, nil
}

View File

@@ -0,0 +1,97 @@
package openai
import (
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
defer common.CloseResponseBodyGracefully(resp)
// read response body
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if responsesResponse.Error != nil {
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
// compute usage
usage := dto.Usage{}
usage.PromptTokens = responsesResponse.Usage.InputTokens
usage.CompletionTokens = responsesResponse.Usage.OutputTokens
usage.TotalTokens = responsesResponse.Usage.TotalTokens
// 解析 Tools 用量
for _, tool := range responsesResponse.Tools {
info.ResponsesUsageInfo.BuiltInTools[common.Interface2String(tool["type"])].CallCount++
}
return &usage, nil
}
func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
}
var usage = &dto.Usage{}
var responseTextBuilder strings.Builder
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
// 检查当前数据是否包含 completed 状态和 usage 信息
var streamResponse dto.ResponsesStreamResponse
if err := common.UnmarshalJsonStr(data, &streamResponse); err == nil {
sendResponsesStreamData(c, streamResponse, data)
switch streamResponse.Type {
case "response.completed":
usage.PromptTokens = streamResponse.Response.Usage.InputTokens
usage.CompletionTokens = streamResponse.Response.Usage.OutputTokens
usage.TotalTokens = streamResponse.Response.Usage.TotalTokens
case "response.output_text.delta":
// 处理输出文本
responseTextBuilder.WriteString(streamResponse.Delta)
case dto.ResponsesOutputTypeItemDone:
// 函数调用处理
if streamResponse.Item != nil {
switch streamResponse.Item.Type {
case dto.BuildInCallWebSearchCall:
info.ResponsesUsageInfo.BuiltInTools[dto.BuildInToolWebSearchPreview].CallCount++
}
}
}
}
return true
})
if usage.CompletionTokens == 0 {
// 计算输出文本的 token 数量
tempStr := responseTextBuilder.String()
if len(tempStr) > 0 {
// 非正常结束,使用输出文本的 token 数量
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
usage.CompletionTokens = completionTokens
}
}
return usage, nil
}

View File

@@ -0,0 +1,5 @@
package openrouter
var ModelList = []string{}
var ChannelName = "openrouter"

View File

@@ -0,0 +1,9 @@
package openrouter
type RequestReasoning struct {
// One of the following (not both):
Effort string `json:"effort,omitempty"` // Can be "high", "medium", or "low" (OpenAI-style)
MaxTokens int `json:"max_tokens,omitempty"` // Specific token limit (Anthropic-style)
// Optional: Default is false. All models support this.
Exclude bool `json:"exclude,omitempty"` // Set to true to exclude reasoning tokens from response
}

View File

@@ -0,0 +1,91 @@
package palm
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/v1beta2/models/chat-bison-001:generateMessage", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("x-goog-api-key", info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
var responseText string
err, responseText = palmStreamHandler(c, resp)
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
} else {
usage, err = palmHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package palm
var ModelList = []string{
"PaLM-2",
}
var ChannelName = "google palm"

38
relay/channel/palm/dto.go Normal file
View File

@@ -0,0 +1,38 @@
package palm
import "one-api/dto"
type PaLMChatMessage struct {
Author string `json:"author"`
Content string `json:"content"`
}
type PaLMFilter struct {
Reason string `json:"reason"`
Message string `json:"message"`
}
type PaLMPrompt struct {
Messages []PaLMChatMessage `json:"messages"`
}
type PaLMChatRequest struct {
Prompt PaLMPrompt `json:"prompt"`
Temperature *float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK uint `json:"topK,omitempty"`
}
type PaLMError struct {
Code int `json:"code"`
Message string `json:"message"`
Status string `json:"status"`
}
type PaLMChatResponse struct {
Candidates []PaLMChatMessage `json:"candidates"`
Messages []dto.Message `json:"messages"`
Filters []PaLMFilter `json:"filters"`
Error PaLMError `json:"error"`
}

View File

@@ -0,0 +1,162 @@
package palm
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"github.com/gin-gonic/gin"
)
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#request-body
// https://developers.generativeai.google/api/rest/generativelanguage/models/generateMessage#response-body
func requestOpenAI2PaLM(textRequest dto.GeneralOpenAIRequest) *PaLMChatRequest {
palmRequest := PaLMChatRequest{
Prompt: PaLMPrompt{
Messages: make([]PaLMChatMessage, 0, len(textRequest.Messages)),
},
Temperature: textRequest.Temperature,
CandidateCount: textRequest.N,
TopP: textRequest.TopP,
TopK: textRequest.MaxTokens,
}
for _, message := range textRequest.Messages {
palmMessage := PaLMChatMessage{
Content: message.StringContent(),
}
if message.Role == "user" {
palmMessage.Author = "0"
} else {
palmMessage.Author = "1"
}
palmRequest.Prompt.Messages = append(palmRequest.Prompt.Messages, palmMessage)
}
return &palmRequest
}
func responsePaLM2OpenAI(response *PaLMChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Choices: make([]dto.OpenAITextResponseChoice, 0, len(response.Candidates)),
}
for i, candidate := range response.Candidates {
choice := dto.OpenAITextResponseChoice{
Index: i,
Message: dto.Message{
Role: "assistant",
Content: candidate.Content,
},
FinishReason: "stop",
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponsePaLM2OpenAI(palmResponse *PaLMChatResponse) *dto.ChatCompletionsStreamResponse {
var choice dto.ChatCompletionsStreamResponseChoice
if len(palmResponse.Candidates) > 0 {
choice.Delta.SetContentString(palmResponse.Candidates[0].Content)
}
choice.FinishReason = &constant.FinishReasonStop
var response dto.ChatCompletionsStreamResponse
response.Object = "chat.completion.chunk"
response.Model = "palm2"
response.Choices = []dto.ChatCompletionsStreamResponseChoice{choice}
return &response
}
func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, string) {
responseText := ""
responseId := helper.GetResponseID(c)
createdTime := common.GetTimestamp()
dataChan := make(chan string)
stopChan := make(chan bool)
go func() {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
common.SysError("error reading stream response: " + err.Error())
stopChan <- true
return
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
stopChan <- true
return
}
fullTextResponse := streamResponsePaLM2OpenAI(&palmResponse)
fullTextResponse.Id = responseId
fullTextResponse.Created = createdTime
if len(palmResponse.Candidates) > 0 {
responseText = palmResponse.Candidates[0].Content
}
jsonResponse, err := json.Marshal(fullTextResponse)
if err != nil {
common.SysError("error marshalling stream response: " + err.Error())
stopChan <- true
return
}
dataChan <- string(jsonResponse)
stopChan <- true
}()
helper.SetEventStreamHeaders(c)
c.Stream(func(w io.Writer) bool {
select {
case data := <-dataChan:
c.Render(-1, common.CustomEvent{Data: "data: " + data})
return true
case <-stopChan:
c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
return false
}
})
common.CloseResponseBodyGracefully(resp)
return nil, responseText
}
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, types.WithOpenAIError(types.OpenAIError{
Message: palmResponse.Error.Message,
Type: palmResponse.Error.Status,
Param: "",
Code: palmResponse.Error.Code,
}, resp.StatusCode)
}
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, info.UpstreamModelName)
usage := dto.Usage{
PromptTokens: info.PromptTokens,
CompletionTokens: completionTokens,
TotalTokens: info.PromptTokens + completionTokens,
}
fullTextResponse.Usage = usage
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &usage, nil
}

View File

@@ -0,0 +1,92 @@
package perplexity
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/chat/completions", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if request.TopP >= 1 {
request.TopP = 0.99
}
return requestOpenAI2Perplexity(*request), nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,7 @@
package perplexity
var ModelList = []string{
"llama-3-sonar-small-32k-chat", "llama-3-sonar-small-32k-online", "llama-3-sonar-large-32k-chat", "llama-3-sonar-large-32k-online", "llama-3-8b-instruct", "llama-3-70b-instruct", "mixtral-8x7b-instruct",
}
var ChannelName = "perplexity"

View File

@@ -0,0 +1,21 @@
package perplexity
import "one-api/dto"
func requestOpenAI2Perplexity(request dto.GeneralOpenAIRequest) *dto.GeneralOpenAIRequest {
messages := make([]dto.Message, 0, len(request.Messages))
for _, message := range request.Messages {
messages = append(messages, dto.Message{
Role: message.Role,
Content: message.Content,
})
}
return &dto.GeneralOpenAIRequest{
Model: request.Model,
Stream: request.Stream,
Messages: messages,
Temperature: request.Temperature,
TopP: request.TopP,
MaxTokens: request.MaxTokens,
}
}

View File

@@ -0,0 +1,104 @@
package siliconflow
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
if info.RelayMode == constant.RelayModeRerank {
return fmt.Sprintf("%s/v1/rerank", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeEmbeddings {
return fmt.Sprintf("%s/v1/embeddings", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeChatCompletions {
return fmt.Sprintf("%s/v1/chat/completions", info.BaseUrl), nil
} else if info.RelayMode == constant.RelayModeCompletions {
return fmt.Sprintf("%s/v1/completions", info.BaseUrl), nil
}
return "", errors.New("invalid relay mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", fmt.Sprintf("Bearer %s", info.ApiKey))
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeRerank:
usage, err = siliconflowRerankHandler(c, info, resp)
case constant.RelayModeCompletions:
fallthrough
case constant.RelayModeChatCompletions:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,51 @@
package siliconflow
var ModelList = []string{
"THUDM/glm-4-9b-chat",
//"stabilityai/stable-diffusion-xl-base-1.0",
//"TencentARC/PhotoMaker",
"InstantX/InstantID",
//"stabilityai/stable-diffusion-2-1",
//"stabilityai/sd-turbo",
//"stabilityai/sdxl-turbo",
"ByteDance/SDXL-Lightning",
"deepseek-ai/deepseek-llm-67b-chat",
"Qwen/Qwen1.5-14B-Chat",
"Qwen/Qwen1.5-7B-Chat",
"Qwen/Qwen1.5-110B-Chat",
"Qwen/Qwen1.5-32B-Chat",
"01-ai/Yi-1.5-6B-Chat",
"01-ai/Yi-1.5-9B-Chat-16K",
"01-ai/Yi-1.5-34B-Chat-16K",
"THUDM/chatglm3-6b",
"deepseek-ai/DeepSeek-V2-Chat",
"Qwen/Qwen2-72B-Instruct",
"Qwen/Qwen2-7B-Instruct",
"Qwen/Qwen2-57B-A14B-Instruct",
//"stabilityai/stable-diffusion-3-medium",
"deepseek-ai/DeepSeek-Coder-V2-Instruct",
"Qwen/Qwen2-1.5B-Instruct",
"internlm/internlm2_5-7b-chat",
"BAAI/bge-large-en-v1.5",
"BAAI/bge-large-zh-v1.5",
"Pro/Qwen/Qwen2-7B-Instruct",
"Pro/Qwen/Qwen2-1.5B-Instruct",
"Pro/Qwen/Qwen1.5-7B-Chat",
"Pro/THUDM/glm-4-9b-chat",
"Pro/THUDM/chatglm3-6b",
"Pro/01-ai/Yi-1.5-9B-Chat-16K",
"Pro/01-ai/Yi-1.5-6B-Chat",
"Pro/google/gemma-2-9b-it",
"Pro/internlm/internlm2_5-7b-chat",
"Pro/meta-llama/Meta-Llama-3-8B-Instruct",
"Pro/mistralai/Mistral-7B-Instruct-v0.2",
"black-forest-labs/FLUX.1-schnell",
"FunAudioLLM/SenseVoiceSmall",
"netease-youdao/bce-embedding-base_v1",
"BAAI/bge-m3",
"internlm/internlm2_5-20b-chat",
"Qwen/Qwen2-Math-72B-Instruct",
"netease-youdao/bce-reranker-base_v1",
"BAAI/bge-reranker-v2-m3",
}
var ChannelName = "siliconflow"

View File

@@ -0,0 +1,17 @@
package siliconflow
import "one-api/dto"
type SFTokens struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
}
type SFMeta struct {
Tokens SFTokens `json:"tokens"`
}
type SFRerankResponse struct {
Results []dto.RerankResponseResult `json:"results"`
Meta SFMeta `json:"meta"`
}

View File

@@ -0,0 +1,44 @@
package siliconflow
import (
"encoding/json"
"io"
"net/http"
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/gin-gonic/gin"
)
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,
CompletionTokens: siliconflowResp.Meta.Tokens.OutputTokens,
TotalTokens: siliconflowResp.Meta.Tokens.InputTokens + siliconflowResp.Meta.Tokens.OutputTokens,
}
rerankResp := &dto.RerankResponse{
Results: siliconflowResp.Results,
Usage: *usage,
}
jsonResponse, err := json.Marshal(rerankResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}

View File

@@ -0,0 +1,380 @@
package jimeng
import (
"bytes"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"one-api/model"
"sort"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type requestPayload struct {
ReqKey string `json:"req_key"`
BinaryDataBase64 []string `json:"binary_data_base64,omitempty"`
ImageUrls []string `json:"image_urls,omitempty"`
Prompt string `json:"prompt,omitempty"`
Seed int64 `json:"seed"`
AspectRatio string `json:"aspect_ratio"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Data struct {
TaskID string `json:"task_id"`
} `json:"data"`
}
type responseTask struct {
Code int `json:"code"`
Data struct {
BinaryDataBase64 []interface{} `json:"binary_data_base64"`
ImageUrls interface{} `json:"image_urls"`
RespData string `json:"resp_data"`
Status string `json:"status"`
VideoUrl string `json:"video_url"`
} `json:"data"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Status int `json:"status"`
TimeElapsed string `json:"time_elapsed"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
accessKey string
secretKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
// apiKey format: "access_key|secret_key"
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 2 {
a.accessKey = strings.TrimSpace(keyParts[0])
a.secretKey = strings.TrimSpace(keyParts[1])
}
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate
info.Action = action
req := relaycommon.TaskSubmitReq{}
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
return fmt.Sprintf("%s/?Action=CVSync2AsyncSubmitTask&Version=2022-08-31", a.baseURL), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
return a.signRequest(req, a.accessKey, a.secretKey)
}
// BuildRequestBody converts request into Jimeng specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(relaycommon.TaskSubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, errors.Wrap(err, "convert request payload failed")
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
_ = resp.Body.Close()
// Parse Jimeng response
var jResp responsePayload
if err := json.Unmarshal(responseBody, &jResp); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if jResp.Code != 10000 {
taskErr = service.TaskErrorWrapper(fmt.Errorf(jResp.Message), fmt.Sprintf("%d", jResp.Code), http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{"task_id": jResp.Data.TaskID})
return jResp.Data.TaskID, responseBody, nil
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
uri := fmt.Sprintf("%s/?Action=CVSync2AsyncGetResult&Version=2022-08-31", baseUrl)
payload := map[string]string{
"req_key": "jimeng_vgfm_t2v_l20", // This is fixed value from doc: https://www.volcengine.com/docs/85621/1544774
"task_id": taskID,
}
payloadBytes, err := json.Marshal(payload)
if err != nil {
return nil, errors.Wrap(err, "marshal fetch task payload failed")
}
req, err := http.NewRequest(http.MethodPost, uri, bytes.NewBuffer(payloadBytes))
if err != nil {
return nil, err
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Content-Type", "application/json")
keyParts := strings.Split(key, "|")
if len(keyParts) != 2 {
return nil, fmt.Errorf("invalid api key format for jimeng: expected 'ak|sk'")
}
accessKey := strings.TrimSpace(keyParts[0])
secretKey := strings.TrimSpace(keyParts[1])
if err := a.signRequest(req, accessKey, secretKey); err != nil {
return nil, errors.Wrap(err, "sign request failed")
}
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"jimeng_vgfm_t2v_l20"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "jimeng"
}
func (a *TaskAdaptor) signRequest(req *http.Request, accessKey, secretKey string) error {
var bodyBytes []byte
var err error
if req.Body != nil {
bodyBytes, err = io.ReadAll(req.Body)
if err != nil {
return errors.Wrap(err, "read request body failed")
}
_ = req.Body.Close()
req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) // Rewind
} else {
bodyBytes = []byte{}
}
payloadHash := sha256.Sum256(bodyBytes)
hexPayloadHash := hex.EncodeToString(payloadHash[:])
t := time.Now().UTC()
xDate := t.Format("20060102T150405Z")
shortDate := t.Format("20060102")
req.Header.Set("Host", req.URL.Host)
req.Header.Set("X-Date", xDate)
req.Header.Set("X-Content-Sha256", hexPayloadHash)
// Sort and encode query parameters to create canonical query string
queryParams := req.URL.Query()
sortedKeys := make([]string, 0, len(queryParams))
for k := range queryParams {
sortedKeys = append(sortedKeys, k)
}
sort.Strings(sortedKeys)
var queryParts []string
for _, k := range sortedKeys {
values := queryParams[k]
sort.Strings(values)
for _, v := range values {
queryParts = append(queryParts, fmt.Sprintf("%s=%s", url.QueryEscape(k), url.QueryEscape(v)))
}
}
canonicalQueryString := strings.Join(queryParts, "&")
headersToSign := map[string]string{
"host": req.URL.Host,
"x-date": xDate,
"x-content-sha256": hexPayloadHash,
}
if req.Header.Get("Content-Type") != "" {
headersToSign["content-type"] = req.Header.Get("Content-Type")
}
var signedHeaderKeys []string
for k := range headersToSign {
signedHeaderKeys = append(signedHeaderKeys, k)
}
sort.Strings(signedHeaderKeys)
var canonicalHeaders strings.Builder
for _, k := range signedHeaderKeys {
canonicalHeaders.WriteString(k)
canonicalHeaders.WriteString(":")
canonicalHeaders.WriteString(strings.TrimSpace(headersToSign[k]))
canonicalHeaders.WriteString("\n")
}
signedHeaders := strings.Join(signedHeaderKeys, ";")
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
req.Method,
req.URL.Path,
canonicalQueryString,
canonicalHeaders.String(),
signedHeaders,
hexPayloadHash,
)
hashedCanonicalRequest := sha256.Sum256([]byte(canonicalRequest))
hexHashedCanonicalRequest := hex.EncodeToString(hashedCanonicalRequest[:])
region := "cn-north-1"
serviceName := "cv"
credentialScope := fmt.Sprintf("%s/%s/%s/request", shortDate, region, serviceName)
stringToSign := fmt.Sprintf("HMAC-SHA256\n%s\n%s\n%s",
xDate,
credentialScope,
hexHashedCanonicalRequest,
)
kDate := hmacSHA256([]byte(secretKey), []byte(shortDate))
kRegion := hmacSHA256(kDate, []byte(region))
kService := hmacSHA256(kRegion, []byte(serviceName))
kSigning := hmacSHA256(kService, []byte("request"))
signature := hex.EncodeToString(hmacSHA256(kSigning, []byte(stringToSign)))
authorization := fmt.Sprintf("HMAC-SHA256 Credential=%s/%s, SignedHeaders=%s, Signature=%s",
accessKey,
credentialScope,
signedHeaders,
signature,
)
req.Header.Set("Authorization", authorization)
return nil
}
func hmacSHA256(key []byte, data []byte) []byte {
h := hmac.New(sha256.New, key)
h.Write(data)
return h.Sum(nil)
}
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
r := requestPayload{
ReqKey: "jimeng_vgfm_i2v_l20",
Prompt: req.Prompt,
AspectRatio: "16:9", // Default aspect ratio
Seed: -1, // Default to random
}
// Handle one-of image_urls or binary_data_base64
if req.Image != "" {
if strings.HasPrefix(req.Image, "http") {
r.ImageUrls = []string{req.Image}
} else {
r.BinaryDataBase64 = []string{req.Image}
}
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resTask := responseTask{}
if err := json.Unmarshal(respBody, &resTask); err != nil {
return nil, errors.Wrap(err, "unmarshal task result failed")
}
taskResult := relaycommon.TaskInfo{}
if resTask.Code == 10000 {
taskResult.Code = 0
} else {
taskResult.Code = resTask.Code // todo uni code
taskResult.Reason = resTask.Message
taskResult.Status = model.TaskStatusFailure
taskResult.Progress = "100%"
}
switch resTask.Data.Status {
case "in_queue":
taskResult.Status = model.TaskStatusQueued
taskResult.Progress = "10%"
case "done":
taskResult.Status = model.TaskStatusSuccess
taskResult.Progress = "100%"
}
taskResult.Url = resTask.Data.VideoUrl
return &taskResult, nil
}

View File

@@ -0,0 +1,346 @@
package kling
import (
"bytes"
"encoding/json"
"fmt"
"github.com/samber/lo"
"io"
"net/http"
"one-api/model"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/pkg/errors"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
)
// ============================
// Request / Response structures
// ============================
type SubmitReq struct {
Prompt string `json:"prompt"`
Model string `json:"model,omitempty"`
Mode string `json:"mode,omitempty"`
Image string `json:"image,omitempty"`
Size string `json:"size,omitempty"`
Duration int `json:"duration,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
type requestPayload struct {
Prompt string `json:"prompt,omitempty"`
Image string `json:"image,omitempty"`
Mode string `json:"mode,omitempty"`
Duration string `json:"duration,omitempty"`
AspectRatio string `json:"aspect_ratio,omitempty"`
ModelName string `json:"model_name,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
}
type responsePayload struct {
Code int `json:"code"`
Message string `json:"message"`
RequestId string `json:"request_id"`
Data struct {
TaskId string `json:"task_id"`
TaskStatus string `json:"task_status"`
TaskStatusMsg string `json:"task_status_msg"`
TaskResult struct {
Videos []struct {
Id string `json:"id"`
Url string `json:"url"`
Duration string `json:"duration"`
} `json:"videos"`
} `json:"task_result"`
CreatedAt int64 `json:"created_at"`
UpdatedAt int64 `json:"updated_at"`
} `json:"data"`
}
// ============================
// Adaptor implementation
// ============================
type TaskAdaptor struct {
ChannelType int
accessKey string
secretKey string
baseURL string
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
a.baseURL = info.BaseUrl
// apiKey format: "access_key|secret_key"
keyParts := strings.Split(info.ApiKey, "|")
if len(keyParts) == 2 {
a.accessKey = strings.TrimSpace(keyParts[0])
a.secretKey = strings.TrimSpace(keyParts[1])
}
}
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
// Accept only POST /v1/video/generations as "generate" action.
action := constant.TaskActionGenerate
info.Action = action
var req SubmitReq
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if strings.TrimSpace(req.Prompt) == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
return
}
// Store into context for later usage
c.Set("task_request", req)
return nil
}
// BuildRequestURL constructs the upstream URL.
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
path := lo.Ternary(info.Action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
return fmt.Sprintf("%s%s", a.baseURL, path), nil
}
// BuildRequestHeader sets required headers.
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
token, err := a.createJWTToken()
if err != nil {
return fmt.Errorf("failed to create JWT token: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return nil
}
// BuildRequestBody converts request into Kling specific format.
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
v, exists := c.Get("task_request")
if !exists {
return nil, fmt.Errorf("request not found in context")
}
req := v.(SubmitReq)
body, err := a.convertToRequestPayload(&req)
if err != nil {
return nil, err
}
data, err := json.Marshal(body)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
// DoRequest delegates to common helper.
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
if action := c.GetString("action"); action != "" {
info.Action = action
}
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
// DoResponse handles upstream response, returns taskID etc.
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
// Attempt Kling response parse first.
var kResp responsePayload
if err := json.Unmarshal(responseBody, &kResp); err == nil && kResp.Code == 0 {
c.JSON(http.StatusOK, gin.H{"task_id": kResp.Data.TaskId})
return kResp.Data.TaskId, responseBody, nil
}
// Fallback generic task response.
var generic dto.TaskResponse[string]
if err := json.Unmarshal(responseBody, &generic); err != nil {
taskErr = service.TaskErrorWrapper(errors.Wrapf(err, "body: %s", responseBody), "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if !generic.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(generic.Message), generic.Code, http.StatusInternalServerError)
return
}
c.JSON(http.StatusOK, gin.H{"task_id": generic.Data})
return generic.Data, responseBody, nil
}
// FetchTask fetch task status
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
taskID, ok := body["task_id"].(string)
if !ok {
return nil, fmt.Errorf("invalid task_id")
}
action, ok := body["action"].(string)
if !ok {
return nil, fmt.Errorf("invalid action")
}
path := lo.Ternary(action == constant.TaskActionGenerate, "/v1/videos/image2video", "/v1/videos/text2video")
url := fmt.Sprintf("%s%s/%s", baseUrl, path, taskID)
req, err := http.NewRequest(http.MethodGet, url, nil)
if err != nil {
return nil, err
}
token, err := a.createJWTTokenWithKey(key)
if err != nil {
token = key
}
req.Header.Set("Accept", "application/json")
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("User-Agent", "kling-sdk/1.0")
return service.GetHttpClient().Do(req)
}
func (a *TaskAdaptor) GetModelList() []string {
return []string{"kling-v1", "kling-v1-6", "kling-v2-master"}
}
func (a *TaskAdaptor) GetChannelName() string {
return "kling"
}
// ============================
// helpers
// ============================
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
r := requestPayload{
Prompt: req.Prompt,
Image: req.Image,
Mode: defaultString(req.Mode, "std"),
Duration: fmt.Sprintf("%d", defaultInt(req.Duration, 5)),
AspectRatio: a.getAspectRatio(req.Size),
ModelName: req.Model,
CfgScale: 0.5,
}
if r.ModelName == "" {
r.ModelName = "kling-v1"
}
metadata := req.Metadata
medaBytes, err := json.Marshal(metadata)
if err != nil {
return nil, errors.Wrap(err, "metadata marshal metadata failed")
}
err = json.Unmarshal(medaBytes, &r)
if err != nil {
return nil, errors.Wrap(err, "unmarshal metadata failed")
}
return &r, nil
}
func (a *TaskAdaptor) getAspectRatio(size string) string {
switch size {
case "1024x1024", "512x512":
return "1:1"
case "1280x720", "1920x1080":
return "16:9"
case "720x1280", "1080x1920":
return "9:16"
default:
return "1:1"
}
}
func defaultString(s, def string) string {
if strings.TrimSpace(s) == "" {
return def
}
return s
}
func defaultInt(v int, def int) int {
if v == 0 {
return def
}
return v
}
// ============================
// JWT helpers
// ============================
func (a *TaskAdaptor) createJWTToken() (string, error) {
return a.createJWTTokenWithKeys(a.accessKey, a.secretKey)
}
func (a *TaskAdaptor) createJWTTokenWithKey(apiKey string) (string, error) {
parts := strings.Split(apiKey, "|")
if len(parts) != 2 {
return "", fmt.Errorf("invalid API key format, expected 'access_key,secret_key'")
}
return a.createJWTTokenWithKeys(strings.TrimSpace(parts[0]), strings.TrimSpace(parts[1]))
}
func (a *TaskAdaptor) createJWTTokenWithKeys(accessKey, secretKey string) (string, error) {
if accessKey == "" || secretKey == "" {
return "", fmt.Errorf("access key and secret key are required")
}
now := time.Now().Unix()
claims := jwt.MapClaims{
"iss": accessKey,
"exp": now + 1800, // 30 minutes
"nbf": now - 5,
}
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["typ"] = "JWT"
return token.SignedString([]byte(secretKey))
}
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
resPayload := responsePayload{}
err := json.Unmarshal(respBody, &resPayload)
if err != nil {
return nil, errors.Wrap(err, "failed to unmarshal response body")
}
taskInfo := &relaycommon.TaskInfo{}
taskInfo.Code = resPayload.Code
taskInfo.TaskID = resPayload.Data.TaskId
taskInfo.Reason = resPayload.Message
//任务状态枚举值submitted已提交、processing处理中、succeed成功、failed失败
status := resPayload.Data.TaskStatus
switch status {
case "submitted":
taskInfo.Status = model.TaskStatusSubmitted
case "processing":
taskInfo.Status = model.TaskStatusInProgress
case "succeed":
taskInfo.Status = model.TaskStatusSuccess
case "failed":
taskInfo.Status = model.TaskStatusFailure
default:
return nil, fmt.Errorf("unknown task status: %s", status)
}
if videos := resPayload.Data.TaskResult.Videos; len(videos) > 0 {
video := videos[0]
taskInfo.Url = video.Url
}
return taskInfo, nil
}

View File

@@ -0,0 +1,176 @@
package suno
import (
"bytes"
"context"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"time"
)
type TaskAdaptor struct {
ChannelType int
}
func (a *TaskAdaptor) ParseTaskResult([]byte) (*relaycommon.TaskInfo, error) {
return nil, fmt.Errorf("not implement") // todo implement this method if needed
}
func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {
a.ChannelType = info.ChannelType
}
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) {
action := strings.ToUpper(c.Param("action"))
var sunoRequest *dto.SunoSubmitReq
err := common.UnmarshalBodyReusable(c, &sunoRequest)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
err = actionValidate(c, sunoRequest, action)
if err != nil {
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
return
}
if sunoRequest.ContinueClipId != "" {
if sunoRequest.TaskID == "" {
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("task id is empty"), "invalid_request", http.StatusBadRequest)
return
}
info.OriginTaskID = sunoRequest.TaskID
}
info.Action = action
c.Set("task_request", sunoRequest)
return nil
}
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) {
baseURL := info.BaseUrl
fullRequestURL := fmt.Sprintf("%s%s", baseURL, "/suno/submit/"+info.Action)
return fullRequestURL, nil
}
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error {
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.TaskRelayInfo) (io.Reader, error) {
sunoRequest, ok := c.Get("task_request")
if !ok {
err := common.UnmarshalBodyReusable(c, &sunoRequest)
if err != nil {
return nil, err
}
}
data, err := json.Marshal(sunoRequest)
if err != nil {
return nil, err
}
return bytes.NewReader(data), nil
}
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) {
return channel.DoTaskApiRequest(a, c, info, requestBody)
}
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
return
}
var sunoResponse dto.TaskResponse[string]
err = json.Unmarshal(responseBody, &sunoResponse)
if err != nil {
taskErr = service.TaskErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
return
}
if !sunoResponse.IsSuccess() {
taskErr = service.TaskErrorWrapper(fmt.Errorf(sunoResponse.Message), sunoResponse.Code, http.StatusInternalServerError)
return
}
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, bytes.NewBuffer(responseBody))
if err != nil {
taskErr = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
return
}
return sunoResponse.Data, nil, nil
}
func (a *TaskAdaptor) GetModelList() []string {
return ModelList
}
func (a *TaskAdaptor) GetChannelName() string {
return ChannelName
}
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
requestUrl := fmt.Sprintf("%s/suno/fetch", baseUrl)
byteBody, err := json.Marshal(body)
if err != nil {
return nil, err
}
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(byteBody))
if err != nil {
common.SysError(fmt.Sprintf("Get Task error: %v", err))
return nil, err
}
defer req.Body.Close()
// 设置超时时间
timeout := time.Second * 15
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
// 使用带有超时的 context 创建新的请求
req = req.WithContext(ctx)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+key)
resp, err := service.GetHttpClient().Do(req)
if err != nil {
return nil, err
}
return resp, nil
}
func actionValidate(c *gin.Context, sunoRequest *dto.SunoSubmitReq, action string) (err error) {
switch action {
case constant.SunoActionMusic:
if sunoRequest.Mv == "" {
sunoRequest.Mv = "chirp-v3-0"
}
case constant.SunoActionLyrics:
if sunoRequest.Prompt == "" {
err = fmt.Errorf("prompt_empty")
return
}
default:
err = fmt.Errorf("invalid_action")
}
return
}

View File

@@ -0,0 +1,7 @@
package suno
var ModelList = []string{
"suno_music", "suno_lyrics",
}
var ChannelName = "suno"

View File

@@ -0,0 +1,113 @@
package tencent
import (
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/relay/channel"
relaycommon "one-api/relay/common"
"one-api/types"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
Sign string
AppID int64
Action string
Version string
Timestamp int64
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
a.Action = "ChatCompletions"
a.Version = "2023-09-01"
a.Timestamp = common.GetTimestamp()
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
return fmt.Sprintf("%s/", info.BaseUrl), nil
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", a.Sign)
req.Set("X-TC-Action", a.Action)
req.Set("X-TC-Version", a.Version)
req.Set("X-TC-Timestamp", strconv.FormatInt(a.Timestamp, 10))
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
apiKey := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
apiKey = strings.TrimPrefix(apiKey, "Bearer ")
appId, secretId, secretKey, err := parseTencentConfig(apiKey)
a.AppID = appId
if err != nil {
return nil, err
}
tencentRequest := requestOpenAI2Tencent(a, *request)
// we have to calculate the sign here
a.Sign = getTencentSign(*tencentRequest, a, secretId, secretKey)
return tencentRequest, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
usage, err = tencentStreamHandler(c, info, resp)
} else {
usage, err = tencentHandler(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,10 @@
package tencent
var ModelList = []string{
"hunyuan-lite",
"hunyuan-standard",
"hunyuan-standard-256K",
"hunyuan-pro",
}
var ChannelName = "tencent"

View File

@@ -0,0 +1,75 @@
package tencent
type TencentMessage struct {
Role string `json:"Role"`
Content string `json:"Content"`
}
type TencentChatRequest struct {
// 模型名称,可选值包括 hunyuan-lite、hunyuan-standard、hunyuan-standard-256K、hunyuan-pro。
// 各模型介绍请阅读 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 中的说明。
//
// 注意:
// 不同的模型计费不同,请根据 [购买指南](https://cloud.tencent.com/document/product/1729/97731) 按需调用。
Model *string `json:"Model"`
// 聊天上下文信息。
// 说明:
// 1. 长度最多为 40按对话时间从旧到新在数组中排列。
// 2. Message.Role 可选值system、user、assistant。
// 其中system 角色可选如存在则必须位于列表的最开始。user 和 assistant 需交替出现(一问一答),以 user 提问开始和结束,且 Content 不能为空。Role 的顺序示例:[system可选 user assistant user assistant user ...]。
// 3. Messages 中 Content 总长度不能超过模型输入长度上限(可参考 [产品概述](https://cloud.tencent.com/document/product/1729/104753) 文档),超过则会截断最前面的内容,只保留尾部内容。
Messages []*TencentMessage `json:"Messages"`
// 流式调用开关。
// 说明:
// 1. 未传值时默认为非流式调用false
// 2. 流式调用时以 SSE 协议增量返回结果(返回值取 Choices[n].Delta 中的值,需要拼接增量数据才能获得完整结果)。
// 3. 非流式调用时:
// 调用方式与普通 HTTP 请求无异。
// 接口响应耗时较长,**如需更低时延建议设置为 true**。
// 只返回一次最终结果(返回值取 Choices[n].Message 中的值)。
//
// 注意:
// 通过 SDK 调用时,流式和非流式调用需用**不同的方式**获取返回值,具体参考 SDK 中的注释或示例(在各语言 SDK 代码仓库的 examples/hunyuan/v20230901/ 目录中)。
Stream *bool `json:"Stream,omitempty"`
// 说明:
// 1. 影响输出文本的多样性,取值越大,生成文本的多样性越强。
// 2. 取值区间为 [0.0, 1.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
TopP *float64 `json:"TopP,omitempty"`
// 说明:
// 1. 较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定。
// 2. 取值区间为 [0.0, 2.0],未传值时使用各模型推荐值。
// 3. 非必要不建议使用,不合理的取值会影响效果。
Temperature *float64 `json:"Temperature,omitempty"`
}
type TencentError struct {
Code int `json:"Code"`
Message string `json:"Message"`
}
type TencentUsage struct {
PromptTokens int `json:"PromptTokens"`
CompletionTokens int `json:"CompletionTokens"`
TotalTokens int `json:"TotalTokens"`
}
type TencentResponseChoices struct {
FinishReason string `json:"FinishReason,omitempty"` // 流式结束标志位,为 stop 则表示尾包
Messages TencentMessage `json:"Message,omitempty"` // 内容,同步模式返回内容,流模式为 null 输出 content 内容总数最多支持 1024token。
Delta TencentMessage `json:"Delta,omitempty"` // 内容,流模式返回内容,同步模式为 null 输出 content 内容总数最多支持 1024token。
}
type TencentChatResponse struct {
Choices []TencentResponseChoices `json:"Choices,omitempty"` // 结果
Created int64 `json:"Created,omitempty"` // unix 时间戳的字符串
Id string `json:"Id,omitempty"` // 会话 id
Usage TencentUsage `json:"Usage,omitempty"` // token 数量
Error TencentError `json:"Error,omitempty"` // 错误信息 注意:此字段可能返回 null表示取不到有效值
Note string `json:"Note,omitempty"` // 注释
ReqID string `json:"Req_id,omitempty"` // 唯一请求 Id每次请求都会返回。用于反馈接口入参
}
type TencentChatResponseSB struct {
Response TencentChatResponse `json:"Response,omitempty"`
}

View File

@@ -0,0 +1,233 @@
package tencent
import (
"bufio"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"one-api/types"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// https://cloud.tencent.com/document/product/1729/97732
func requestOpenAI2Tencent(a *Adaptor, request dto.GeneralOpenAIRequest) *TencentChatRequest {
messages := make([]*TencentMessage, 0, len(request.Messages))
for i := 0; i < len(request.Messages); i++ {
message := request.Messages[i]
messages = append(messages, &TencentMessage{
Content: message.StringContent(),
Role: message.Role,
})
}
var req = TencentChatRequest{
Stream: &request.Stream,
Messages: messages,
Model: &request.Model,
}
if request.TopP != 0 {
req.TopP = &request.TopP
}
req.Temperature = request.Temperature
return &req
}
func responseTencent2OpenAI(response *TencentChatResponse) *dto.OpenAITextResponse {
fullTextResponse := dto.OpenAITextResponse{
Id: response.Id,
Object: "chat.completion",
Created: common.GetTimestamp(),
Usage: dto.Usage{
PromptTokens: response.Usage.PromptTokens,
CompletionTokens: response.Usage.CompletionTokens,
TotalTokens: response.Usage.TotalTokens,
},
}
if len(response.Choices) > 0 {
choice := dto.OpenAITextResponseChoice{
Index: 0,
Message: dto.Message{
Role: "assistant",
Content: response.Choices[0].Messages.Content,
},
FinishReason: response.Choices[0].FinishReason,
}
fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
}
return &fullTextResponse
}
func streamResponseTencent2OpenAI(TencentResponse *TencentChatResponse) *dto.ChatCompletionsStreamResponse {
response := dto.ChatCompletionsStreamResponse{
Object: "chat.completion.chunk",
Created: common.GetTimestamp(),
Model: "tencent-hunyuan",
}
if len(TencentResponse.Choices) > 0 {
var choice dto.ChatCompletionsStreamResponseChoice
choice.Delta.SetContentString(TencentResponse.Choices[0].Delta.Content)
if TencentResponse.Choices[0].FinishReason == "stop" {
choice.FinishReason = &constant.FinishReasonStop
}
response.Choices = append(response.Choices, choice)
}
return &response
}
func tencentStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var responseText string
scanner := bufio.NewScanner(resp.Body)
scanner.Split(bufio.ScanLines)
helper.SetEventStreamHeaders(c)
for scanner.Scan() {
data := scanner.Text()
if len(data) < 5 || !strings.HasPrefix(data, "data:") {
continue
}
data = strings.TrimPrefix(data, "data:")
var tencentResponse TencentChatResponse
err := json.Unmarshal([]byte(data), &tencentResponse)
if err != nil {
common.SysError("error unmarshalling stream response: " + err.Error())
continue
}
response := streamResponseTencent2OpenAI(&tencentResponse)
if len(response.Choices) != 0 {
responseText += response.Choices[0].Delta.GetContentString()
}
err = helper.ObjectData(c, response)
if err != nil {
common.SysError(err.Error())
}
}
if err := scanner.Err(); err != nil {
common.SysError("error reading stream: " + err.Error())
}
helper.Done(c)
common.CloseResponseBodyGracefully(resp)
return service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens), nil
}
func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
if tencentSb.Response.Error.Code != 0 {
return nil, types.WithOpenAIError(types.OpenAIError{
Message: tencentSb.Response.Error.Message,
Code: tencentSb.Response.Error.Code,
}, resp.StatusCode)
}
fullTextResponse := responseTencent2OpenAI(&tencentSb.Response)
jsonResponse, err := common.Marshal(fullTextResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return &fullTextResponse.Usage, nil
}
func parseTencentConfig(config string) (appId int64, secretId string, secretKey string, err error) {
parts := strings.Split(config, "|")
if len(parts) != 3 {
err = errors.New("invalid tencent config")
return
}
appId, err = strconv.ParseInt(parts[0], 10, 64)
secretId = parts[1]
secretKey = parts[2]
return
}
func sha256hex(s string) string {
b := sha256.Sum256([]byte(s))
return hex.EncodeToString(b[:])
}
func hmacSha256(s, key string) string {
hashed := hmac.New(sha256.New, []byte(key))
hashed.Write([]byte(s))
return string(hashed.Sum(nil))
}
func getTencentSign(req TencentChatRequest, adaptor *Adaptor, secId, secKey string) string {
// build canonical request string
host := "hunyuan.tencentcloudapi.com"
httpRequestMethod := "POST"
canonicalURI := "/"
canonicalQueryString := ""
canonicalHeaders := fmt.Sprintf("content-type:%s\nhost:%s\nx-tc-action:%s\n",
"application/json", host, strings.ToLower(adaptor.Action))
signedHeaders := "content-type;host;x-tc-action"
payload, _ := json.Marshal(req)
hashedRequestPayload := sha256hex(string(payload))
canonicalRequest := fmt.Sprintf("%s\n%s\n%s\n%s\n%s\n%s",
httpRequestMethod,
canonicalURI,
canonicalQueryString,
canonicalHeaders,
signedHeaders,
hashedRequestPayload)
// build string to sign
algorithm := "TC3-HMAC-SHA256"
requestTimestamp := strconv.FormatInt(adaptor.Timestamp, 10)
timestamp, _ := strconv.ParseInt(requestTimestamp, 10, 64)
t := time.Unix(timestamp, 0).UTC()
// must be the format 2006-01-02, ref to package time for more info
date := t.Format("2006-01-02")
credentialScope := fmt.Sprintf("%s/%s/tc3_request", date, "hunyuan")
hashedCanonicalRequest := sha256hex(canonicalRequest)
string2sign := fmt.Sprintf("%s\n%s\n%s\n%s",
algorithm,
requestTimestamp,
credentialScope,
hashedCanonicalRequest)
// sign string
secretDate := hmacSha256(date, "TC3"+secKey)
secretService := hmacSha256("hunyuan", secretDate)
secretKey := hmacSha256("tc3_request", secretService)
signature := hex.EncodeToString([]byte(hmacSha256(string2sign, secretKey)))
// build authorization
authorization := fmt.Sprintf("%s Credential=%s/%s, SignedHeaders=%s, Signature=%s",
algorithm,
secId,
credentialScope,
signedHeaders,
signature)
return authorization
}

View File

@@ -0,0 +1,262 @@
package vertex
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/claude"
"one-api/relay/channel/gemini"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/setting/model_setting"
"one-api/types"
"strings"
"github.com/gin-gonic/gin"
)
const (
RequestModeClaude = 1
RequestModeGemini = 2
RequestModeLlama = 3
)
var claudeModelMap = map[string]string{
"claude-3-sonnet-20240229": "claude-3-sonnet@20240229",
"claude-3-opus-20240229": "claude-3-opus@20240229",
"claude-3-haiku-20240307": "claude-3-haiku@20240307",
"claude-3-5-sonnet-20240620": "claude-3-5-sonnet@20240620",
"claude-3-5-sonnet-20241022": "claude-3-5-sonnet-v2@20241022",
"claude-3-7-sonnet-20250219": "claude-3-7-sonnet@20250219",
"claude-sonnet-4-20250514": "claude-sonnet-4@20250514",
"claude-opus-4-20250514": "claude-opus-4@20250514",
}
const anthropicVersion = "vertex-2023-10-16"
type Adaptor struct {
RequestMode int
AccountCredentials Credentials
}
func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) {
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
c.Set("request_model", v)
} else {
c.Set("request_model", request.Model)
}
vertexClaudeReq := copyRequest(request, anthropicVersion)
return vertexClaudeReq, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
}
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
adc := &Credentials{}
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
return "", fmt.Errorf("failed to decode credentials file: %w", err)
}
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
parts := strings.Split(info.UpstreamModelName, "-thinking-")
info.UpstreamModelName = parts[0]
} else if strings.HasSuffix(info.UpstreamModelName, "-thinking") { // 旧的适配
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-thinking")
} else if strings.HasSuffix(info.UpstreamModelName, "-nothinking") {
info.UpstreamModelName = strings.TrimSuffix(info.UpstreamModelName, "-nothinking")
}
}
if info.IsStream {
suffix = "streamGenerateContent?alt=sse"
} else {
suffix = "generateContent"
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
adc.ProjectID,
info.UpstreamModelName,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
region,
adc.ProjectID,
region,
info.UpstreamModelName,
suffix,
), nil
}
} else if a.RequestMode == RequestModeClaude {
if info.IsStream {
suffix = "streamRawPredict?alt=sse"
} else {
suffix = "rawPredict"
}
model := info.UpstreamModelName
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
model = v
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
adc.ProjectID,
model,
suffix,
), nil
} else {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
region,
adc.ProjectID,
region,
model,
suffix,
), nil
}
} else if a.RequestMode == RequestModeLlama {
return fmt.Sprintf(
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
region,
adc.ProjectID,
region,
), nil
}
return "", errors.New("unsupported request mode")
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
accessToken, err := getAccessToken(a, info)
if err != nil {
return err
}
req.Set("Authorization", "Bearer "+accessToken)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
if a.RequestMode == RequestModeClaude {
claudeReq, err := claude.RequestOpenAI2ClaudeMessage(*request)
if err != nil {
return nil, err
}
vertexClaudeReq := copyRequest(claudeReq, anthropicVersion)
c.Set("request_model", claudeReq.Model)
info.UpstreamModelName = claudeReq.Model
return vertexClaudeReq, nil
} else if a.RequestMode == RequestModeGemini {
geminiRequest, err := gemini.CovertGemini2OpenAI(*request, info)
if err != nil {
return nil, err
}
c.Set("request_model", request.Model)
return geminiRequest, nil
} else if a.RequestMode == RequestModeLlama {
return request, nil
}
return nil, errors.New("unsupported request mode")
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
if info.IsStream {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationStreamHandler(c, info, resp)
} else {
usage, err = gemini.GeminiChatStreamHandler(c, info, resp)
}
case RequestModeLlama:
usage, err = openai.OaiStreamHandler(c, info, resp)
}
} else {
switch a.RequestMode {
case RequestModeClaude:
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
case RequestModeGemini:
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
usage, err = gemini.GeminiChatHandler(c, info, resp)
}
case RequestModeLlama:
usage, err = openai.OpenaiHandler(c, info, resp)
}
}
return
}
func (a *Adaptor) GetModelList() []string {
var modelList []string
for i, s := range ModelList {
modelList = append(modelList, s)
ModelList[i] = s
}
for i, s := range claude.ModelList {
modelList = append(modelList, s)
claude.ModelList[i] = s
}
for i, s := range gemini.ModelList {
modelList = append(modelList, s)
gemini.ModelList[i] = s
}
return modelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

View File

@@ -0,0 +1,15 @@
package vertex
var ModelList = []string{
//"claude-3-sonnet-20240229",
//"claude-3-opus-20240229",
//"claude-3-haiku-20240307",
//"claude-3-5-sonnet-20240620",
//"gemini-1.5-pro-latest", "gemini-1.5-flash-latest",
//"gemini-1.5-pro-001", "gemini-1.5-flash-001", "gemini-pro", "gemini-pro-vision",
"meta/llama3-405b-instruct-maas",
}
var ChannelName = "vertex-ai"

View File

@@ -0,0 +1,37 @@
package vertex
import (
"one-api/dto"
)
type VertexAIClaudeRequest struct {
AnthropicVersion string `json:"anthropic_version"`
Messages []dto.ClaudeMessage `json:"messages"`
System any `json:"system,omitempty"`
MaxTokens uint `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools any `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
Thinking *dto.Thinking `json:"thinking,omitempty"`
}
func copyRequest(req *dto.ClaudeRequest, version string) *VertexAIClaudeRequest {
return &VertexAIClaudeRequest{
AnthropicVersion: version,
System: req.System,
Messages: req.Messages,
MaxTokens: req.MaxTokens,
Stream: req.Stream,
Temperature: req.Temperature,
TopP: req.TopP,
TopK: req.TopK,
StopSequences: req.StopSequences,
Tools: req.Tools,
ToolChoice: req.ToolChoice,
Thinking: req.Thinking,
}
}

View File

@@ -0,0 +1,19 @@
package vertex
import "one-api/common"
func GetModelRegion(other string, localModelName string) string {
// if other is json string
if common.IsJsonObject(other) {
m, err := common.StrToMap(other)
if err != nil {
return other // return original if parsing fails
}
if m[localModelName] != nil {
return m[localModelName].(string)
} else {
return m["default"].(string)
}
}
return other
}

View File

@@ -0,0 +1,134 @@
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"github.com/bytedance/gopkg/cache/asynccache"
"github.com/golang-jwt/jwt"
"net/http"
"net/url"
relaycommon "one-api/relay/common"
"one-api/service"
"strings"
"fmt"
"time"
)
type Credentials struct {
ProjectID string `json:"project_id"`
PrivateKeyID string `json:"private_key_id"`
PrivateKey string `json:"private_key"`
ClientEmail string `json:"client_email"`
ClientID string `json:"client_id"`
}
var Cache = asynccache.NewAsyncCache(asynccache.Options{
RefreshDuration: time.Minute * 35,
EnableExpire: true,
ExpireDuration: time.Minute * 30,
Fetcher: func(key string) (interface{}, error) {
return nil, errors.New("not found")
},
})
func getAccessToken(a *Adaptor, info *relaycommon.RelayInfo) (string, error) {
cacheKey := fmt.Sprintf("access-token-%d", info.ChannelId)
val, err := Cache.Get(cacheKey)
if err == nil {
return val.(string), nil
}
signedJWT, err := createSignedJWT(a.AccountCredentials.ClientEmail, a.AccountCredentials.PrivateKey)
if err != nil {
return "", fmt.Errorf("failed to create signed JWT: %w", err)
}
newToken, err := exchangeJwtForAccessToken(signedJWT, info)
if err != nil {
return "", fmt.Errorf("failed to exchange JWT for access token: %w", err)
}
if err := Cache.SetDefault(cacheKey, newToken); err {
return newToken, nil
}
return newToken, nil
}
func createSignedJWT(email, privateKeyPEM string) (string, error) {
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----BEGIN PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "-----END PRIVATE KEY-----", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\r", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\n", "")
privateKeyPEM = strings.ReplaceAll(privateKeyPEM, "\\n", "")
block, _ := pem.Decode([]byte("-----BEGIN PRIVATE KEY-----\n" + privateKeyPEM + "\n-----END PRIVATE KEY-----"))
if block == nil {
return "", fmt.Errorf("failed to parse PEM block containing the private key")
}
privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return "", err
}
rsaPrivateKey, ok := privateKey.(*rsa.PrivateKey)
if !ok {
return "", fmt.Errorf("not an RSA private key")
}
now := time.Now()
claims := jwt.MapClaims{
"iss": email,
"scope": "https://www.googleapis.com/auth/cloud-platform",
"aud": "https://www.googleapis.com/oauth2/v4/token",
"exp": now.Add(time.Minute * 35).Unix(),
"iat": now.Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
signedToken, err := token.SignedString(rsaPrivateKey)
if err != nil {
return "", err
}
return signedToken, nil
}
func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (string, error) {
authURL := "https://www.googleapis.com/oauth2/v4/token"
data := url.Values{}
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
data.Set("assertion", signedJWT)
var client *http.Client
var err error
if info.ChannelSetting.Proxy != "" {
client, err = service.NewProxyHttpClient(info.ChannelSetting.Proxy)
if err != nil {
return "", fmt.Errorf("new proxy http client failed: %w", err)
}
} else {
client = service.GetHttpClient()
}
resp, err := client.PostForm(authURL, data)
if err != nil {
return "", err
}
defer resp.Body.Close()
var result map[string]interface{}
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", err
}
if accessToken, ok := result["access_token"].(string); ok {
return accessToken, nil
}
return "", fmt.Errorf("failed to get access token: %v", result)
}

View File

@@ -0,0 +1,251 @@
package volcengine
import (
"bytes"
"errors"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/textproto"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
relaycommon "one-api/relay/common"
"one-api/relay/constant"
"one-api/types"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
)
type Adaptor struct {
}
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
//TODO implement me
panic("implement me")
return nil, nil
}
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
//TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
switch info.RelayMode {
case constant.RelayModeImagesEdits:
var requestBody bytes.Buffer
writer := multipart.NewWriter(&requestBody)
writer.WriteField("model", request.Model)
// 获取所有表单字段
formData := c.Request.PostForm
// 遍历表单字段并打印输出
for key, values := range formData {
if key == "model" {
continue
}
for _, value := range values {
writer.WriteField(key, value)
}
}
// Parse the multipart form to handle both single image and multiple images
if err := c.Request.ParseMultipartForm(32 << 20); err != nil { // 32MB max memory
return nil, errors.New("failed to parse multipart form")
}
if c.Request.MultipartForm != nil && c.Request.MultipartForm.File != nil {
// Check if "image" field exists in any form, including array notation
var imageFiles []*multipart.FileHeader
var exists bool
// First check for standard "image" field
if imageFiles, exists = c.Request.MultipartForm.File["image"]; !exists || len(imageFiles) == 0 {
// If not found, check for "image[]" field
if imageFiles, exists = c.Request.MultipartForm.File["image[]"]; !exists || len(imageFiles) == 0 {
// If still not found, iterate through all fields to find any that start with "image["
foundArrayImages := false
for fieldName, files := range c.Request.MultipartForm.File {
if strings.HasPrefix(fieldName, "image[") && len(files) > 0 {
foundArrayImages = true
for _, file := range files {
imageFiles = append(imageFiles, file)
}
}
}
// If no image fields found at all
if !foundArrayImages && (len(imageFiles) == 0) {
return nil, errors.New("image is required")
}
}
}
// Process all image files
for i, fileHeader := range imageFiles {
file, err := fileHeader.Open()
if err != nil {
return nil, fmt.Errorf("failed to open image file %d: %w", i, err)
}
defer file.Close()
// If multiple images, use image[] as the field name
fieldName := "image"
if len(imageFiles) > 1 {
fieldName = "image[]"
}
// Determine MIME type based on file extension
mimeType := detectImageMimeType(fileHeader.Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="%s"; filename="%s"`, fieldName, fileHeader.Filename))
h.Set("Content-Type", mimeType)
part, err := writer.CreatePart(h)
if err != nil {
return nil, fmt.Errorf("create form part failed for image %d: %w", i, err)
}
if _, err := io.Copy(part, file); err != nil {
return nil, fmt.Errorf("copy file failed for image %d: %w", i, err)
}
}
// Handle mask file if present
if maskFiles, exists := c.Request.MultipartForm.File["mask"]; exists && len(maskFiles) > 0 {
maskFile, err := maskFiles[0].Open()
if err != nil {
return nil, errors.New("failed to open mask file")
}
defer maskFile.Close()
// Determine MIME type for mask file
mimeType := detectImageMimeType(maskFiles[0].Filename)
// Create a form file with the appropriate content type
h := make(textproto.MIMEHeader)
h.Set("Content-Disposition", fmt.Sprintf(`form-data; name="mask"; filename="%s"`, maskFiles[0].Filename))
h.Set("Content-Type", mimeType)
maskPart, err := writer.CreatePart(h)
if err != nil {
return nil, errors.New("create form file failed for mask")
}
if _, err := io.Copy(maskPart, maskFile); err != nil {
return nil, errors.New("copy mask file failed")
}
}
} else {
return nil, errors.New("no multipart form data found")
}
// 关闭 multipart 编写器以设置分界线
writer.Close()
c.Request.Header.Set("Content-Type", writer.FormDataContentType())
return bytes.NewReader(requestBody.Bytes()), nil
default:
return request, nil
}
}
// detectImageMimeType determines the MIME type based on the file extension
func detectImageMimeType(filename string) string {
ext := strings.ToLower(filepath.Ext(filename))
switch ext {
case ".jpg", ".jpeg":
return "image/jpeg"
case ".png":
return "image/png"
case ".webp":
return "image/webp"
default:
// Try to detect from extension if possible
if strings.HasPrefix(ext, ".jp") {
return "image/jpeg"
}
// Default to png as a fallback
return "image/png"
}
}
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if strings.HasPrefix(info.UpstreamModelName, "bot") {
return fmt.Sprintf("%s/api/v3/bots/chat/completions", info.BaseUrl), nil
}
return fmt.Sprintf("%s/api/v3/chat/completions", info.BaseUrl), nil
case constant.RelayModeEmbeddings:
return fmt.Sprintf("%s/api/v3/embeddings", info.BaseUrl), nil
case constant.RelayModeImagesGenerations:
return fmt.Sprintf("%s/api/v3/images/generations", info.BaseUrl), nil
default:
}
return "", fmt.Errorf("unsupported relay mode: %d", info.RelayMode)
}
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
channel.SetupApiRequestHeader(info, c, req)
req.Set("Authorization", "Bearer "+info.ApiKey)
return nil
}
func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) {
if request == nil {
return nil, errors.New("request is nil")
}
return request, nil
}
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
return nil, nil
}
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
return request, nil
}
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
// TODO implement me
return nil, errors.New("not implemented")
}
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
return channel.DoApiRequest(a, c, info, requestBody)
}
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) {
switch info.RelayMode {
case constant.RelayModeChatCompletions:
if info.IsStream {
usage, err = openai.OaiStreamHandler(c, info, resp)
} else {
usage, err = openai.OpenaiHandler(c, info, resp)
}
case constant.RelayModeEmbeddings:
usage, err = openai.OpenaiHandler(c, info, resp)
case constant.RelayModeImagesGenerations, constant.RelayModeImagesEdits:
usage, err = openai.OpenaiHandlerWithUsage(c, info, resp)
}
return
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}
func (a *Adaptor) GetChannelName() string {
return ChannelName
}

Some files were not shown because too many files have changed in this diff Show More