gemini stream

This commit is contained in:
creamlike1024
2025-05-26 14:50:50 +08:00
parent 738a9a4558
commit d90e4bef63
4 changed files with 93 additions and 92 deletions

View File

@@ -8,6 +8,7 @@ import (
"net/http"
"one-api/common"
"one-api/dto"
"one-api/relay/channel/gemini"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/service"
@@ -17,8 +18,8 @@ import (
"github.com/gin-gonic/gin"
)
func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
request := &dto.GeminiTextGenerationRequest{}
func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
request := &gemini.GeminiChatRequest{}
err := common.UnmarshalBodyReusable(c, request)
if err != nil {
return nil, err
@@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque
return request, nil
}
func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
// 流模式
// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
// checkGeminiStreamMode marks the relay as streaming when the request asks
// for server-sent events via the native Gemini query parameter
// (e.g. /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx).
func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
	wantsSSE := c.Query("alt") == "sse"
	if wantsSSE {
		relayInfo.IsStream = true
	}
	// Alternative detection by URL path, intentionally disabled:
	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
	// 	relayInfo.IsStream = true
	// }
}
func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
var inputTexts []string
for _, content := range textRequest.Contents {
for _, part := range content.Parts {
@@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf
return sensitiveWords, err
}
func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
// 计算输入 token 数量
var inputTexts []string
for _, content := range req.Contents {
@@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
relayInfo := relaycommon.GenRelayInfo(c)
// 检查 Gemini 流式模式
checkGeminiStreamMode(c, relayInfo)
if setting.ShouldCheckPromptSensitive() {
sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
sensitiveWords, err := checkGeminiInputSensitive(req)
if err != nil {
common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
@@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
c.Set("prompt_tokens", promptTokens)
}
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
if err != nil {
return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
}