fix: auto ban

This commit is contained in:
CaIon
2025-07-30 18:39:19 +08:00
parent 95d46d1dfc
commit 0cd93d67ff
15 changed files with 99 additions and 94 deletions

View File

@@ -132,12 +132,12 @@ func aliImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
var aliTaskResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliTaskResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliTaskResponse.Message != "" {

View File

@@ -34,14 +34,14 @@ func ConvertRerankRequest(request dto.RerankRequest) *AliRerankRequest {
func RerankHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*types.NewAPIError, *dto.Usage) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
var aliResponse AliRerankResponse
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {

View File

@@ -43,7 +43,7 @@ func aliEmbeddingHandler(c *gin.Context, resp *http.Response) (*types.NewAPIErro
var fullTextResponse dto.FlexibleEmbeddingResponse
err := json.NewDecoder(resp.Body).Decode(&fullTextResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
@@ -179,12 +179,12 @@ func aliHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError, *dto.U
var aliResponse AliResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &aliResponse)
if err != nil {
return types.NewError(err, types.ErrorCodeBadResponseBody), nil
return types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError), nil
}
if aliResponse.Code != "" {
return types.WithOpenAIError(types.OpenAIError{

View File

@@ -1,12 +1,10 @@
package gemini
import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel"
"one-api/relay/channel/openai"
@@ -212,60 +210,6 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
return nil, types.NewError(errors.New("not implemented"), types.ErrorCodeBadResponseBody)
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := json.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewError(errors.New("no images generated"), types.ErrorCodeBadResponseBody)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}
func (a *Adaptor) GetModelList() []string {
return ModelList
}

View File

@@ -907,7 +907,7 @@ func GeminiChatStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *
func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
@@ -916,10 +916,10 @@ func GeminiChatHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.R
var geminiResponse GeminiChatResponse
err = common.Unmarshal(responseBody, &geminiResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Candidates) == 0 {
return nil, types.NewError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(errors.New("no candidates returned"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
fullTextResponse := responseGeminiChat2OpenAI(c, &geminiResponse)
fullTextResponse.Model = info.UpstreamModelName
@@ -956,12 +956,12 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewError(readErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
var geminiResponse GeminiEmbeddingResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
@@ -991,9 +991,63 @@ func GeminiEmbeddingHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *h
jsonResponse, jsonErr := common.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
common.IOCopyBytesGracefully(c, resp, jsonResponse)
return usage, nil
}
func GeminiImageHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
return nil, types.NewOpenAIError(readErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
_ = resp.Body.Close()
var geminiResponse GeminiImageResponse
if jsonErr := common.Unmarshal(responseBody, &geminiResponse); jsonErr != nil {
return nil, types.NewOpenAIError(jsonErr, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if len(geminiResponse.Predictions) == 0 {
return nil, types.NewOpenAIError(errors.New("no images generated"), types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// convert to openai format response
openAIResponse := dto.ImageResponse{
Created: common.GetTimestamp(),
Data: make([]dto.ImageData, 0, len(geminiResponse.Predictions)),
}
for _, prediction := range geminiResponse.Predictions {
if prediction.RaiFilteredReason != "" {
continue // skip filtered image
}
openAIResponse.Data = append(openAIResponse.Data, dto.ImageData{
B64Json: prediction.BytesBase64Encoded,
})
}
jsonResponse, jsonErr := json.Marshal(openAIResponse)
if jsonErr != nil {
return nil, types.NewError(jsonErr, types.ErrorCodeBadResponseBody)
}
c.Writer.Header().Set("Content-Type", "application/json")
c.Writer.WriteHeader(resp.StatusCode)
_, _ = c.Writer.Write(jsonResponse)
// https://github.com/google-gemini/cookbook/blob/719a27d752aac33f39de18a8d3cb42a70874917e/quickstarts/Counting_Tokens.ipynb
// each image has fixed 258 tokens
const imageTokens = 258
generatedImages := len(openAIResponse.Data)
usage := &dto.Usage{
PromptTokens: imageTokens * generatedImages, // each generated image has fixed 258 tokens
CompletionTokens: 0, // image generation does not calculate completion tokens
TotalTokens: imageTokens * generatedImages,
}
return usage, nil
}

View File

@@ -52,13 +52,13 @@ func jimengImageHandler(c *gin.Context, resp *http.Response, info *relaycommon.R
var jimengResponse ImageResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &jimengResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// Check if the response indicates an error

View File

@@ -109,7 +109,7 @@ func sendStreamData(c *gin.Context, info *relaycommon.RelayInfo, data string, fo
func OaiStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
if resp == nil || resp.Body == nil {
common.LogError(c, "invalid response or response body")
return nil, types.NewError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse)
return nil, types.NewOpenAIError(fmt.Errorf("invalid response"), types.ErrorCodeBadResponse, http.StatusInternalServerError)
}
defer common.CloseResponseBodyGracefully(resp)
@@ -178,11 +178,11 @@ func OpenaiHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var simpleResponse dto.OpenAITextResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
err = common.Unmarshal(responseBody, &simpleResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if simpleResponse.Error != nil && simpleResponse.Error.Type != "" {
return nil, types.WithOpenAIError(*simpleResponse.Error, resp.StatusCode)
@@ -263,7 +263,7 @@ func OpenaiSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return types.NewError(err, types.ErrorCodeReadResponseBodyFailed), nil
return types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError), nil
}
// 写入新的 response body
common.IOCopyBytesGracefully(c, resp, responseBody)
@@ -547,13 +547,13 @@ func OpenaiHandlerWithUsage(c *gin.Context, info *relaycommon.RelayInfo, resp *h
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
var usageResp dto.SimpleResponse
err = common.Unmarshal(responseBody, &usageResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
// 写入新的 response body

View File

@@ -22,11 +22,11 @@ func OaiResponsesHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http
var responsesResponse dto.OpenAIResponsesResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
err = common.Unmarshal(responseBody, &responsesResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if responsesResponse.Error != nil {
return nil, types.WithOpenAIError(*responsesResponse.Error, resp.StatusCode)

View File

@@ -127,13 +127,13 @@ func palmStreamHandler(c *gin.Context, resp *http.Response) (*types.NewAPIError,
func palmHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
var palmResponse PaLMChatResponse
err = json.Unmarshal(responseBody, &palmResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if palmResponse.Error.Code != 0 || len(palmResponse.Candidates) == 0 {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -15,13 +15,13 @@ import (
func siliconflowRerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
var siliconflowResp SFRerankResponse
err = json.Unmarshal(responseBody, &siliconflowResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
usage := &dto.Usage{
PromptTokens: siliconflowResp.Meta.Tokens.InputTokens,

View File

@@ -136,12 +136,12 @@ func tencentHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Resp
var tencentSb TencentChatResponseSB
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &tencentSb)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if tencentSb.Response.Error.Code != 0 {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -67,11 +67,10 @@ func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInf
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
if strings.HasPrefix(info.UpstreamModelName, "claude") {
a.RequestMode = RequestModeClaude
} else if strings.HasPrefix(info.UpstreamModelName, "gemini") {
a.RequestMode = RequestModeGemini
} else if strings.Contains(info.UpstreamModelName, "llama") {
a.RequestMode = RequestModeLlama
}
a.RequestMode = RequestModeGemini
}
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
@@ -83,6 +82,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
a.AccountCredentials = *adc
suffix := ""
if a.RequestMode == RequestModeGemini {
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
// 新增逻辑:处理 -thinking-<budget> 格式
if strings.Contains(info.UpstreamModelName, "-thinking-") {
@@ -100,6 +100,11 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
} else {
suffix = "generateContent"
}
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
suffix = "predict"
}
if region == "global" {
return fmt.Sprintf(
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
@@ -231,6 +236,9 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
if info.RelayMode == constant.RelayModeGemini {
usage, err = gemini.GeminiTextGenerationHandler(c, info, resp)
} else {
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
return gemini.GeminiImageHandler(c, info, resp)
}
usage, err = gemini.GeminiChatHandler(c, info, resp)
}
case RequestModeLlama:

View File

@@ -220,12 +220,12 @@ func zhipuHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respon
var zhipuResponse ZhipuResponse
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
err = json.Unmarshal(responseBody, &zhipuResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
if !zhipuResponse.Success {
return nil, types.WithOpenAIError(types.OpenAIError{

View File

@@ -16,7 +16,7 @@ import (
func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeReadResponseBodyFailed)
return nil, types.NewOpenAIError(err, types.ErrorCodeReadResponseBodyFailed, http.StatusInternalServerError)
}
common.CloseResponseBodyGracefully(resp)
if common.DebugEnabled {
@@ -27,7 +27,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
var xinRerankResponse xinference.XinRerankResponse
err = common.Unmarshal(responseBody, &xinRerankResponse)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaRespResults := make([]dto.RerankResponseResult, len(xinRerankResponse.Results))
for i, result := range xinRerankResponse.Results {
@@ -62,7 +62,7 @@ func RerankHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Respo
} else {
err = common.Unmarshal(responseBody, &jinaResp)
if err != nil {
return nil, types.NewError(err, types.ErrorCodeBadResponseBody)
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
}
jinaResp.Usage.PromptTokens = jinaResp.Usage.TotalTokens
}

View File

@@ -2,7 +2,6 @@ package relay
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
@@ -203,7 +202,7 @@ func GeminiHelper(c *gin.Context) (newAPIError *types.NewAPIError) {
}
requestBody = bytes.NewReader(body)
} else {
jsonData, err := json.Marshal(req)
jsonData, err := common.Marshal(req)
if err != nil {
return types.NewError(err, types.ErrorCodeConvertRequestFailed)
}