fix: auto ban

This commit is contained in:
CaIon
2025-07-30 18:39:19 +08:00
parent 95d46d1dfc
commit 0cd93d67ff
15 changed files with 99 additions and 94 deletions

View File

@@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
var aliTaskResponse AliResponse var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse) err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
} }
if aliTaskResponse.Message != "" { if aliTaskResponse.Message != "" {

View File

@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) { func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse) err = json.Unmarshal(responseBody, &aliResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
} }
if aliResponse.Code != "" { if aliResponse.Code != "" {

View File

@@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
var fullTextResponse dto.FlexibleEmbeddingResponse var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse) err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
@@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
var aliResponse AliResponse var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse) err = json.Unmarshal(responseBody, &aliResponse)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
} }
if aliResponse.Code != "" { if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{ return types.WithOpenAIError(types.OpenAIError{

View File

@@ -1,12 +1,10 @@
package gemini package gemini
import ( import (
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
"one-api/common"
"one-api/dto" "one-api/dto"
"one-api/relay/channel" "one-api/relay/channel"
"one-api/relay/channel/openai" "one-api/relay/channel/openai"
@@ -212,60 +210,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody) return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
} }
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string { func (a *Adaptor) GetModelList() []string {
return ModelList return ModelList
} }

View File

@@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled { if common.DebugEnabled {
@@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var geminiResponse GeminiChatResponse var geminiResponse GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse) err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if len(geminiResponse.Candidates) == 0 { if len(geminiResponse.Candidates) == 0 {
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse) fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName fullTextResponse.Model = info.UpstreamModelName
@@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, readErr := io.ReadAll(resp.Body) responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil { if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
var geminiResponse GeminiEmbeddingResponse var geminiResponse GeminiEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil { if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
// convert to openai format response // convert to openai format response
@@ -991,9 +991,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
jsonResponse, jsonErr := common.Marshal(openAIResponse) jsonResponse, jsonErr := common.Marshal(openAIResponse)
if jsonErr != nil { if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
common.IOCopyBytesGracefully(c, resp, jsonResponse) common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil return usage, nil
} }
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}

View File

@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
var jimengResponse ImageResponse var jimengResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse) err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
// Check if the response indicates an error // Check if the response indicates an error

View File

@@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil { if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body") common.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse) return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
} }
defer common.CloseResponseBodyGracefully(resp) defer common.CloseResponseBodyGracefully(resp)
@@ -178,11 +178,11 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var simpleResponse dto.OpenAITextResponse var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
err = common.Unmarshal(responseBody, &simpleResponse) err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" { if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode) return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
@@ -263,7 +263,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
} }
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
} }
// 写入新的 response body // 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody) common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -547,13 +547,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
var usageResp dto.SimpleResponse var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp) err = common.Unmarshal(responseBody, &usageResp)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
// 写入新的 response body // 写入新的 response body

View File

@@ -22,11 +22,11 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var responsesResponse dto.OpenAIResponsesResponse var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
err = common.Unmarshal(responseBody, &responsesResponse) err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if responsesResponse.Error != nil { if responsesResponse.Error != nil {
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode) return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)

View File

@@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse) err = json.Unmarshal(responseBody, &palmResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 { if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, types.WithOpenAIError(types.OpenAIError{ return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -15,13 +15,13 @@ import (
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp) err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
usage := &dto.Usage{ usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens, PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,

View File

@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
var tencentSb TencentChatResponseSB var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb) err = json.Unmarshal(responseBody, &tencentSb)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if tencentSb.Response.Error.Code != 0 { if tencentSb.Response.Error.Code != 0 {
return nil, types.WithOpenAIError(types.OpenAIError{ return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -67,11 +67,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) Init(info *relaycommon.RelayInfo) { func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") { if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") { } else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama a.RequestMode = RequestModeLlama
} }
a.RequestMode = RequestModeGemini
} }
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -83,6 +82,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
a.AccountCredentials = *adc a.AccountCredentials = *adc
suffix := "" suffix := ""
if a.RequestMode == RequestModeGemini { if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled { if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式 // 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") { if strings.Contains(info.UpstreamModelName, "-thinking-") {
@@ -100,6 +100,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else { } else {
suffix = "generateContent" suffix = "generateContent"
} }
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict"
}
if region == "global" { if region == "global" {
return fmt.Sprintf( return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s", "https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
@@ -231,6 +236,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayMode == constant.RelayModeGemini { if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp) usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
} else { } else {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return gemini.GeminiImageHandler(c, info, resp)
}
usage, err = gemini.GeminiChatHandler(c, info, resp) usage, err = gemini.GeminiChatHandler(c, info, resp)
} }
case RequestModeLlama: case RequestModeLlama:

View File

@@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
var zhipuResponse ZhipuResponse var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse) err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
if !zhipuResponse.Success { if !zhipuResponse.Success {
return nil, types.WithOpenAIError(types.OpenAIError{ return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -16,7 +16,7 @@ import (
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) { func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body) responseBody, err := io.ReadAll(resp.Body)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed) return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
} }
common.CloseResponseBodyGracefully(resp) common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled { if common.DebugEnabled {
@@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var xinRerankResponse xinference.XinRerankResponse var xinRerankResponse xinference.XinRerankResponse
err = common.Unmarshal(responseBody, &xinRerankResponse) err = common.Unmarshal(responseBody, &xinRerankResponse)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results)) jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results { for i, result := range xinRerankResponse.Results {
@@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} else { } else {
err = common.Unmarshal(responseBody, &jinaResp) err = common.Unmarshal(responseBody, &jinaResp)
if err != nil { if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody) return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
} }
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
} }

View File

@@ -2,7 +2,6 @@ package relay
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -203,7 +202,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
} }
requestBody = bytes.NewReader(body) requestBody = bytes.NewReader(body)
} else { } else {
jsonData, err := json.Marshal(req) jsonData, err := common.Marshal(req)
if err != nil { if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed) return types.NewError(err, types.ErrorCodeConvertRequestFailed)
} }