@@ -40,6 +40,8 @@ func relayHandler(c *gin.Context, relayMode int) *dto.OpenAIErrorWithStatusCode
|
|||||||
err = relay.EmbeddingHelper(c)
|
err = relay.EmbeddingHelper(c)
|
||||||
case relayconstant.RelayModeResponses:
|
case relayconstant.RelayModeResponses:
|
||||||
err = relay.ResponsesHelper(c)
|
err = relay.ResponsesHelper(c)
|
||||||
|
case relayconstant.RelayModeGemini:
|
||||||
|
err = relay.GeminiHelper(c)
|
||||||
default:
|
default:
|
||||||
err = relay.TextHelper(c)
|
err = relay.TextHelper(c)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/gin-contrib/sessions"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/model"
|
"one-api/model"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-contrib/sessions"
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
func validUserInfo(username string, role int) bool {
|
func validUserInfo(username string, role int) bool {
|
||||||
@@ -182,6 +183,13 @@ func TokenAuth() func(c *gin.Context) {
|
|||||||
c.Request.Header.Set("Authorization", "Bearer "+key)
|
c.Request.Header.Set("Authorization", "Bearer "+key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// gemini api 从query中获取key
|
||||||
|
if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||||
|
skKey := c.Query("key")
|
||||||
|
if skKey != "" {
|
||||||
|
c.Request.Header.Set("Authorization", "Bearer "+skKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
key := c.Request.Header.Get("Authorization")
|
key := c.Request.Header.Get("Authorization")
|
||||||
parts := make([]string, 0)
|
parts := make([]string, 0)
|
||||||
key = strings.TrimPrefix(key, "Bearer ")
|
key = strings.TrimPrefix(key, "Bearer ")
|
||||||
|
|||||||
@@ -162,6 +162,14 @@ func getModelRequest(c *gin.Context) (*ModelRequest, bool, error) {
|
|||||||
}
|
}
|
||||||
c.Set("platform", string(constant.TaskPlatformSuno))
|
c.Set("platform", string(constant.TaskPlatformSuno))
|
||||||
c.Set("relay_mode", relayMode)
|
c.Set("relay_mode", relayMode)
|
||||||
|
} else if strings.HasPrefix(c.Request.URL.Path, "/v1beta/models/") {
|
||||||
|
// Gemini API 路径处理: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
|
relayMode := relayconstant.RelayModeGemini
|
||||||
|
modelName := extractModelNameFromGeminiPath(c.Request.URL.Path)
|
||||||
|
if modelName != "" {
|
||||||
|
modelRequest.Model = modelName
|
||||||
|
}
|
||||||
|
c.Set("relay_mode", relayMode)
|
||||||
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
} else if !strings.HasPrefix(c.Request.URL.Path, "/v1/audio/transcriptions") && !strings.HasPrefix(c.Request.URL.Path, "/v1/images/edits") {
|
||||||
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
err = common.UnmarshalBodyReusable(c, &modelRequest)
|
||||||
}
|
}
|
||||||
@@ -244,3 +252,31 @@ func SetupContextForSelectedChannel(c *gin.Context, channel *model.Channel, mode
|
|||||||
c.Set("bot_id", channel.Other)
|
c.Set("bot_id", channel.Other)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractModelNameFromGeminiPath 从 Gemini API URL 路径中提取模型名
|
||||||
|
// 输入格式: /v1beta/models/gemini-2.0-flash:generateContent
|
||||||
|
// 输出: gemini-2.0-flash
|
||||||
|
func extractModelNameFromGeminiPath(path string) string {
|
||||||
|
// 查找 "/models/" 的位置
|
||||||
|
modelsPrefix := "/models/"
|
||||||
|
modelsIndex := strings.Index(path, modelsPrefix)
|
||||||
|
if modelsIndex == -1 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从 "/models/" 之后开始提取
|
||||||
|
startIndex := modelsIndex + len(modelsPrefix)
|
||||||
|
if startIndex >= len(path) {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找 ":" 的位置,模型名在 ":" 之前
|
||||||
|
colonIndex := strings.Index(path[startIndex:], ":")
|
||||||
|
if colonIndex == -1 {
|
||||||
|
// 如果没有找到 ":",返回从 "/models/" 到路径结尾的部分
|
||||||
|
return path[startIndex:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回模型名部分
|
||||||
|
return path[startIndex : startIndex+colonIndex]
|
||||||
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
"one-api/relay/channel"
|
"one-api/relay/channel"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/constant"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -165,6 +166,14 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
|
if info.IsStream {
|
||||||
|
return GeminiTextGenerationStreamHandler(c, resp, info)
|
||||||
|
} else {
|
||||||
|
return GeminiTextGenerationHandler(c, resp, info)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||||
return GeminiImageHandler(c, resp, info)
|
return GeminiImageHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
|
|||||||
128
relay/channel/gemini/relay-gemini-native.go
Normal file
128
relay/channel/gemini/relay-gemini-native.go
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
|
"one-api/service"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
// 读取响应体
|
||||||
|
responseBody, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
err = resp.Body.Close()
|
||||||
|
if err != nil {
|
||||||
|
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
if common.DebugEnabled {
|
||||||
|
println(string(responseBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析为 Gemini 原生响应格式
|
||||||
|
var geminiResponse GeminiChatResponse
|
||||||
|
err = common.DecodeJson(responseBody, &geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否有候选响应
|
||||||
|
if len(geminiResponse.Candidates) == 0 {
|
||||||
|
return nil, &dto.OpenAIErrorWithStatusCode{
|
||||||
|
Error: dto.OpenAIError{
|
||||||
|
Message: "No candidates returned",
|
||||||
|
Type: "server_error",
|
||||||
|
Param: "",
|
||||||
|
Code: 500,
|
||||||
|
},
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算使用量(基于 UsageMetadata)
|
||||||
|
usage := dto.Usage{
|
||||||
|
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||||||
|
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
|
||||||
|
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||||
|
jsonResponse, err := json.Marshal(geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置响应头并写入响应
|
||||||
|
c.Writer.Header().Set("Content-Type", "application/json")
|
||||||
|
c.Writer.WriteHeader(resp.StatusCode)
|
||||||
|
_, err = c.Writer.Write(jsonResponse)
|
||||||
|
if err != nil {
|
||||||
|
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &usage, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
var usage = &dto.Usage{}
|
||||||
|
var imageCount int
|
||||||
|
|
||||||
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
|
var geminiResponse GeminiChatResponse
|
||||||
|
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 统计图片数量
|
||||||
|
for _, candidate := range geminiResponse.Candidates {
|
||||||
|
for _, part := range candidate.Content.Parts {
|
||||||
|
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||||
|
imageCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新使用量统计
|
||||||
|
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||||
|
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||||
|
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||||
|
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接发送 GeminiChatResponse 响应
|
||||||
|
err = helper.ObjectData(c, geminiResponse)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, err.Error())
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if imageCount != 0 {
|
||||||
|
if usage.CompletionTokens == 0 {
|
||||||
|
usage.CompletionTokens = imageCount * 258
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 计算最终使用量
|
||||||
|
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||||
|
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||||
|
|
||||||
|
// 结束流式响应
|
||||||
|
helper.Done(c)
|
||||||
|
|
||||||
|
return usage, nil
|
||||||
|
}
|
||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/relay/channel/gemini"
|
"one-api/relay/channel/gemini"
|
||||||
"one-api/relay/channel/openai"
|
"one-api/relay/channel/openai"
|
||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/constant"
|
||||||
"one-api/setting/model_setting"
|
"one-api/setting/model_setting"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -201,7 +202,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
case RequestModeClaude:
|
case RequestModeClaude:
|
||||||
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
err, usage = claude.ClaudeStreamHandler(c, resp, info, claude.RequestModeMessage)
|
||||||
case RequestModeGemini:
|
case RequestModeGemini:
|
||||||
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
|
usage, err = gemini.GeminiTextGenerationStreamHandler(c, resp, info)
|
||||||
|
} else {
|
||||||
|
err, usage = gemini.GeminiChatStreamHandler(c, resp, info)
|
||||||
|
}
|
||||||
case RequestModeLlama:
|
case RequestModeLlama:
|
||||||
err, usage = openai.OaiStreamHandler(c, resp, info)
|
err, usage = openai.OaiStreamHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
@@ -210,7 +215,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
case RequestModeClaude:
|
case RequestModeClaude:
|
||||||
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
err, usage = claude.ClaudeHandler(c, resp, claude.RequestModeMessage, info)
|
||||||
case RequestModeGemini:
|
case RequestModeGemini:
|
||||||
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
if info.RelayMode == constant.RelayModeGemini {
|
||||||
|
usage, err = gemini.GeminiTextGenerationHandler(c, resp, info)
|
||||||
|
} else {
|
||||||
|
err, usage = gemini.GeminiChatHandler(c, resp, info)
|
||||||
|
}
|
||||||
case RequestModeLlama:
|
case RequestModeLlama:
|
||||||
err, usage = openai.OpenaiHandler(c, resp, info)
|
err, usage = openai.OpenaiHandler(c, resp, info)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ const (
|
|||||||
RelayModeResponses
|
RelayModeResponses
|
||||||
|
|
||||||
RelayModeRealtime
|
RelayModeRealtime
|
||||||
|
|
||||||
|
RelayModeGemini
|
||||||
)
|
)
|
||||||
|
|
||||||
func Path2RelayMode(path string) int {
|
func Path2RelayMode(path string) int {
|
||||||
@@ -75,6 +77,8 @@ func Path2RelayMode(path string) int {
|
|||||||
relayMode = RelayModeRerank
|
relayMode = RelayModeRerank
|
||||||
} else if strings.HasPrefix(path, "/v1/realtime") {
|
} else if strings.HasPrefix(path, "/v1/realtime") {
|
||||||
relayMode = RelayModeRealtime
|
relayMode = RelayModeRealtime
|
||||||
|
} else if strings.HasPrefix(path, "/v1beta/models") {
|
||||||
|
relayMode = RelayModeGemini
|
||||||
}
|
}
|
||||||
return relayMode
|
return relayMode
|
||||||
}
|
}
|
||||||
|
|||||||
157
relay/relay-gemini.go
Normal file
157
relay/relay-gemini.go
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
package relay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"one-api/common"
|
||||||
|
"one-api/dto"
|
||||||
|
"one-api/relay/channel/gemini"
|
||||||
|
relaycommon "one-api/relay/common"
|
||||||
|
"one-api/relay/helper"
|
||||||
|
"one-api/service"
|
||||||
|
"one-api/setting"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||||
|
request := &gemini.GeminiChatRequest{}
|
||||||
|
err := common.UnmarshalBodyReusable(c, request)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(request.Contents) == 0 {
|
||||||
|
return nil, errors.New("contents is required")
|
||||||
|
}
|
||||||
|
return request, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 流模式
|
||||||
|
// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
|
||||||
|
func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||||
|
if c.Query("alt") == "sse" {
|
||||||
|
relayInfo.IsStream = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
|
||||||
|
// relayInfo.IsStream = true
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||||
|
var inputTexts []string
|
||||||
|
for _, content := range textRequest.Contents {
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(inputTexts) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
sensitiveWords, err := service.CheckSensitiveInput(inputTexts)
|
||||||
|
return sensitiveWords, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||||
|
// 计算输入 token 数量
|
||||||
|
var inputTexts []string
|
||||||
|
for _, content := range req.Contents {
|
||||||
|
for _, part := range content.Parts {
|
||||||
|
if part.Text != "" {
|
||||||
|
inputTexts = append(inputTexts, part.Text)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
|
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||||
|
info.PromptTokens = inputTokens
|
||||||
|
return inputTokens, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
|
req, err := getAndValidateGeminiRequest(c)
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, fmt.Sprintf("getAndValidateGeminiRequest error: %s", err.Error()))
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "invalid_gemini_request", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
relayInfo := relaycommon.GenRelayInfo(c)
|
||||||
|
|
||||||
|
// 检查 Gemini 流式模式
|
||||||
|
checkGeminiStreamMode(c, relayInfo)
|
||||||
|
|
||||||
|
if setting.ShouldCheckPromptSensitive() {
|
||||||
|
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||||
|
if err != nil {
|
||||||
|
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// model mapped 模型映射
|
||||||
|
err = helper.ModelMappedHelper(c, relayInfo)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "model_mapped_error", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
if value, exists := c.Get("prompt_tokens"); exists {
|
||||||
|
promptTokens := value.(int)
|
||||||
|
relayInfo.SetPromptTokens(promptTokens)
|
||||||
|
} else {
|
||||||
|
promptTokens, err := getGeminiInputTokens(req, relayInfo)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
c.Set("prompt_tokens", promptTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// pre consume quota
|
||||||
|
preConsumedQuota, userQuota, openaiErr := preConsumeQuota(c, priceData.ShouldPreConsumedQuota, relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if openaiErr != nil {
|
||||||
|
returnPreConsumedQuota(c, relayInfo, userQuota, preConsumedQuota)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
adaptor := GetAdaptor(relayInfo.ApiType)
|
||||||
|
if adaptor == nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(fmt.Errorf("invalid api type: %d", relayInfo.ApiType), "invalid_api_type", http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
adaptor.Init(relayInfo)
|
||||||
|
|
||||||
|
requestBody, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "marshal_text_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := adaptor.DoRequest(c, relayInfo, bytes.NewReader(requestBody))
|
||||||
|
if err != nil {
|
||||||
|
common.LogError(c, "Do gemini request failed: "+err.Error())
|
||||||
|
return service.OpenAIErrorWrapperLocal(err, "do_request_failed", http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
|
||||||
|
usage, openaiErr := adaptor.DoResponse(c, resp.(*http.Response), relayInfo)
|
||||||
|
if openaiErr != nil {
|
||||||
|
return openaiErr
|
||||||
|
}
|
||||||
|
|
||||||
|
postConsumeQuota(c, relayInfo, usage.(*dto.Usage), preConsumedQuota, userQuota, priceData, "")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -79,6 +79,14 @@ func SetRelayRouter(router *gin.Engine) {
|
|||||||
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
|
relaySunoRouter.GET("/fetch/:id", controller.RelayTask)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
relayGeminiRouter := router.Group("/v1beta")
|
||||||
|
relayGeminiRouter.Use(middleware.TokenAuth())
|
||||||
|
relayGeminiRouter.Use(middleware.ModelRequestRateLimit())
|
||||||
|
relayGeminiRouter.Use(middleware.Distribute())
|
||||||
|
{
|
||||||
|
// Gemini API 路径格式: /v1beta/models/{model_name}:{action}
|
||||||
|
relayGeminiRouter.POST("/models/*path", controller.Relay)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
func registerMjRouterGroup(relayMjRouter *gin.RouterGroup) {
|
||||||
|
|||||||
Reference in New Issue
Block a user