gemini stream

creamlike1024
2025-05-26 14:50:50 +08:00
parent 738a9a4558
commit d90e4bef63
4 changed files with 93 additions and 92 deletions

View File

@@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeGemini {
-		err, usage = GeminiTextGenerationHandler(c, resp, info)
-		return usage, err
+		if info.IsStream {
+			return GeminiTextGenerationStreamHandler(c, resp, info)
+		} else {
+			return GeminiTextGenerationHandler(c, resp, info)
+		}
 	}
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
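
This dispatch works because the commit also flips the Gemini handlers to Go's conventional (result, error) return order (see the next file): DoResponse itself declares (usage any, err *dto.OpenAIErrorWithStatusCode), so a handler returning (usage, error) can be forwarded with a bare return instead of the old swap. A minimal runnable sketch of the pattern, with stand-in types:

package main

import "fmt"

// Stand-ins for dto.Usage and dto.OpenAIErrorWithStatusCode.
type usage struct{ TotalTokens int }
type relayError struct{ Message string }

// New ordering: result first, error last.
func handler() (*usage, *relayError) { return &usage{TotalTokens: 16}, nil }

// A DoResponse-style wrapper can forward the handler's results directly;
// *usage is assignable to any, so the bare return compiles.
func doResponse() (any, *relayError) { return handler() }

func main() {
	u, err := doResponse()
	fmt.Println(u, err) // &{16} <nil>
}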

View File

@@ -7,20 +7,21 @@ import (
"one-api/common"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
"github.com/gin-gonic/gin"
)
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
// 读取响应体
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
}
err = resp.Body.Close()
if err != nil {
return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
}
if common.DebugEnabled {
@@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	}

 	// Parse into Gemini's native response format
-	var geminiResponse dto.GeminiTextGenerationResponse
+	var geminiResponse GeminiChatResponse
 	err = common.DecodeJson(responseBody, &geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}

 	// Check whether any candidates were returned
 	if len(geminiResponse.Candidates) == 0 {
-		return &dto.OpenAIErrorWithStatusCode{
+		return nil, &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
@@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 				Code:    500,
 			},
 			StatusCode: resp.StatusCode,
-		}, nil
+		}
 	}

 	// Compute usage (based on UsageMetadata)
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 		TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
 	}

-	// Set the model version
-	if geminiResponse.ModelVersion == "" {
-		geminiResponse.ModelVersion = info.UpstreamModelName
-	}
-
 	// Return the Gemini-native JSON response directly
 	jsonResponse, err := json.Marshal(geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}

 	// Set response headers and write the response
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = c.Writer.Write(jsonResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 	}

-	return nil, &usage
+	return &usage, nil
 }
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+		// Count generated images
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+			}
+		}
+
+		// Update usage statistics
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+		}
+
+		// Send the GeminiChatResponse chunk downstream as-is
+		err = helper.ObjectData(c, geminiResponse)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// Compute final usage
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+	// End the streaming response
+	helper.Done(c)
+
+	return usage, nil
+}
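
For reference, a worked sketch of the final accounting above, with illustrative numbers (the 258-token-per-image fallback is the constant from the handler, presumably Gemini's per-image output cost). Note that whenever a chunk reported a TotalTokenCount, the closing subtraction recomputes CompletionTokens and supersedes the image fallback:

package main

import "fmt"

func main() {
	// Illustrative values: the last SSE chunk reported these counts,
	// and two inline images were seen while scanning the stream.
	promptTokens, totalTokens := 10, 300
	imageCount, completionTokens := 2, 0

	if imageCount != 0 && completionTokens == 0 {
		completionTokens = imageCount * 258 // image fallback: 516
	}
	// Final subtraction, as in the handler; it overrides the fallback
	// whenever a total was reported upstream.
	completionTokens = totalTokens - promptTokens

	fmt.Println(completionTokens) // 290
}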

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/gemini"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -17,8 +18,8 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
request := &dto.GeminiTextGenerationRequest{}
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
request := &gemini.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
@@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque
 	return request, nil
 }

-func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
+// Stream mode:
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if c.Query("alt") == "sse" {
+		relayInfo.IsStream = true
+	}
+	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+	// 	relayInfo.IsStream = true
+	// }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
 	var inputTexts []string
 	for _, content := range textRequest.Contents {
 		for _, part := range content.Parts {
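
A sketch of how the checkGeminiStreamMode added above behaves, using gin's test helpers (imports elided; constructing relaycommon.RelayInfo directly is an assumption for illustration, since the relay path builds it via GenRelayInfo):

func TestCheckGeminiStreamMode(t *testing.T) {
	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = httptest.NewRequest(http.MethodPost,
		"/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx", nil)

	info := &relaycommon.RelayInfo{}
	checkGeminiStreamMode(c, info)
	if !info.IsStream {
		t.Fatal("expected IsStream to be true when alt=sse is present")
	}

	// Without alt=sse the request stays non-streaming.
	c2, _ := gin.CreateTestContext(httptest.NewRecorder())
	c2.Request = httptest.NewRequest(http.MethodPost,
		"/v1beta/models/gemini-2.0-flash:generateContent?key=xxx", nil)
	info2 := &relaycommon.RelayInfo{}
	checkGeminiStreamMode(c2, info2)
	if info2.IsStream {
		t.Fatal("expected IsStream to be false without alt=sse")
	}
}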
@@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf
 	return sensitiveWords, err
 }

-func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
 	// Count input tokens
 	var inputTexts []string
 	for _, content := range req.Contents {
@@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 	relayInfo := relaycommon.GenRelayInfo(c)

+	// Check Gemini streaming mode
+	checkGeminiStreamMode(c, relayInfo)
+
 	if setting.ShouldCheckPromptSensitive() {
-		sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
+		sensitiveWords, err := checkGeminiInputSensitive(req)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
 			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
@@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}
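
The only change in this last hunk is the explicit int(...) conversion: the Gemini-native request type evidently declares MaxOutputTokens with a different integer type than ModelPriceHelper's int parameter (the field's exact type isn't shown in this diff), and Go never converts numeric types implicitly. A minimal illustration with a stand-in width:

package main

import "fmt"

// Stand-in for helper.ModelPriceHelper's max-tokens parameter.
func price(maxOutputTokens int) int { return maxOutputTokens }

func main() {
	// Stand-in width; the real field type isn't shown in the diff.
	var maxOutputTokens uint = 1024
	// price(maxOutputTokens) would not compile: Go requires an
	// explicit conversion, hence the int(...) in the call above.
	fmt.Println(price(int(maxOutputTokens)))
}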