gemini stream
This commit is contained in:
@@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
|
||||
|
||||
func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
|
||||
if info.RelayMode == constant.RelayModeGemini {
|
||||
err, usage = GeminiTextGenerationHandler(c, resp, info)
|
||||
return usage, err
|
||||
if info.IsStream {
|
||||
return GeminiTextGenerationStreamHandler(c, resp, info)
|
||||
} else {
|
||||
return GeminiTextGenerationHandler(c, resp, info)
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
|
||||
@@ -7,20 +7,21 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
|
||||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
// 读取响应体
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
if common.DebugEnabled {
|
||||
@@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
}
|
||||
|
||||
// 解析为 Gemini 原生响应格式
|
||||
var geminiResponse dto.GeminiTextGenerationResponse
|
||||
var geminiResponse GeminiChatResponse
|
||||
err = common.DecodeJson(responseBody, &geminiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 检查是否有候选响应
|
||||
if len(geminiResponse.Candidates) == 0 {
|
||||
return &dto.OpenAIErrorWithStatusCode{
|
||||
return nil, &dto.OpenAIErrorWithStatusCode{
|
||||
Error: dto.OpenAIError{
|
||||
Message: "No candidates returned",
|
||||
Type: "server_error",
|
||||
@@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
Code: 500,
|
||||
},
|
||||
StatusCode: resp.StatusCode,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 计算使用量(基于 UsageMetadata)
|
||||
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||||
}
|
||||
|
||||
// 设置模型版本
|
||||
if geminiResponse.ModelVersion == "" {
|
||||
geminiResponse.ModelVersion = info.UpstreamModelName
|
||||
}
|
||||
|
||||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||||
jsonResponse, err := json.Marshal(geminiResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
|
||||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 设置响应头并写入响应
|
||||
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
_, err = c.Writer.Write(jsonResponse)
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
|
||||
return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
return nil, &usage
|
||||
return &usage, nil
|
||||
}
|
||||
|
||||
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||||
var usage = &dto.Usage{}
|
||||
var imageCount int
|
||||
|
||||
helper.SetEventStreamHeaders(c)
|
||||
|
||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||
var geminiResponse GeminiChatResponse
|
||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||||
return false
|
||||
}
|
||||
|
||||
// 统计图片数量
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||
imageCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 更新使用量统计
|
||||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
|
||||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
err = helper.ObjectData(c, geminiResponse)
|
||||
if err != nil {
|
||||
common.LogError(c, err.Error())
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
if imageCount != 0 {
|
||||
if usage.CompletionTokens == 0 {
|
||||
usage.CompletionTokens = imageCount * 258
|
||||
}
|
||||
}
|
||||
|
||||
// 计算最终使用量
|
||||
usage.PromptTokensDetails.TextTokens = usage.PromptTokens
|
||||
usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
||||
|
||||
// 结束流式响应
|
||||
helper.Done(c)
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel/gemini"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/service"
|
||||
@@ -17,8 +18,8 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
|
||||
request := &dto.GeminiTextGenerationRequest{}
|
||||
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
|
||||
request := &gemini.GeminiChatRequest{}
|
||||
err := common.UnmarshalBodyReusable(c, request)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque
|
||||
return request, nil
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
|
||||
// 流模式
|
||||
// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
|
||||
func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
|
||||
if c.Query("alt") == "sse" {
|
||||
relayInfo.IsStream = true
|
||||
}
|
||||
|
||||
// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
|
||||
// relayInfo.IsStream = true
|
||||
// }
|
||||
}
|
||||
|
||||
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
|
||||
var inputTexts []string
|
||||
for _, content := range textRequest.Contents {
|
||||
for _, part := range content.Parts {
|
||||
@@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf
|
||||
return sensitiveWords, err
|
||||
}
|
||||
|
||||
func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
|
||||
// 计算输入 token 数量
|
||||
var inputTexts []string
|
||||
for _, content := range req.Contents {
|
||||
@@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
|
||||
relayInfo := relaycommon.GenRelayInfo(c)
|
||||
|
||||
// 检查 Gemini 流式模式
|
||||
checkGeminiStreamMode(c, relayInfo)
|
||||
|
||||
if setting.ShouldCheckPromptSensitive() {
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
|
||||
sensitiveWords, err := checkGeminiInputSensitive(req)
|
||||
if err != nil {
|
||||
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
|
||||
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
|
||||
@@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||
c.Set("prompt_tokens", promptTokens)
|
||||
}
|
||||
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
|
||||
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
|
||||
if err != nil {
|
||||
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user