gemini stream
@@ -1,69 +0,0 @@
-package dto
-
-import "encoding/json"
-
-type GeminiPart struct {
-	Text string `json:"text"`
-}
-
-type GeminiContent struct {
-	Parts []GeminiPart `json:"parts"`
-	Role  string       `json:"role"`
-}
-
-type GeminiCandidate struct {
-	Content      GeminiContent `json:"content"`
-	FinishReason string        `json:"finishReason"`
-	AvgLogprobs  float64       `json:"avgLogprobs"`
-}
-
-type GeminiTokenDetails struct {
-	Modality   string `json:"modality"`
-	TokenCount int    `json:"tokenCount"`
-}
-
-type GeminiUsageMetadata struct {
-	PromptTokenCount        int                  `json:"promptTokenCount"`
-	CandidatesTokenCount    int                  `json:"candidatesTokenCount"`
-	TotalTokenCount         int                  `json:"totalTokenCount"`
-	PromptTokensDetails     []GeminiTokenDetails `json:"promptTokensDetails"`
-	CandidatesTokensDetails []GeminiTokenDetails `json:"candidatesTokensDetails"`
-}
-
-type GeminiTextGenerationResponse struct {
-	Candidates    []GeminiCandidate   `json:"candidates"`
-	UsageMetadata GeminiUsageMetadata `json:"usageMetadata"`
-	ModelVersion  string              `json:"modelVersion"`
-	ResponseID    string              `json:"responseId"`
-}
-
-type GeminiGenerationConfig struct {
-	StopSequences              []string         `json:"stopSequences,omitempty"`
-	ResponseMimeType           string           `json:"responseMimeType,omitempty"`
-	ResponseSchema             *json.RawMessage `json:"responseSchema,omitempty"`
-	ResponseModalities         *json.RawMessage `json:"responseModalities,omitempty"`
-	CandidateCount             int              `json:"candidateCount,omitempty"`
-	MaxOutputTokens            int              `json:"maxOutputTokens,omitempty"`
-	Temperature                float64          `json:"temperature,omitempty"`
-	TopP                       float64          `json:"topP,omitempty"`
-	TopK                       int              `json:"topK,omitempty"`
-	Seed                       int              `json:"seed,omitempty"`
-	PresencePenalty            float64          `json:"presencePenalty,omitempty"`
-	FrequencyPenalty           float64          `json:"frequencyPenalty,omitempty"`
-	ResponseLogprobs           bool             `json:"responseLogprobs,omitempty"`
-	LogProbs                   int              `json:"logProbs,omitempty"`
-	EnableEnhancedCivicAnswers bool             `json:"enableEnhancedCivicAnswers,omitempty"`
-	SpeechConfig               *json.RawMessage `json:"speechConfig,omitempty"`
-	ThinkingConfig             *json.RawMessage `json:"thinkingConfig,omitempty"`
-	MediaResolution            *json.RawMessage `json:"mediaResolution,omitempty"`
-}
-
-type GeminiTextGenerationRequest struct {
-	Contents          []GeminiContent        `json:"contents"`
-	Tools             *json.RawMessage       `json:"tools,omitempty"`
-	ToolConfig        *json.RawMessage       `json:"toolConfig,omitempty"`
-	SafetySettings    *json.RawMessage       `json:"safetySettings,omitempty"`
-	SystemInstruction *json.RawMessage       `json:"systemInstruction,omitempty"`
-	GenerationConfig  GeminiGenerationConfig `json:"generationConfig,omitempty"`
-	CachedContent     *json.RawMessage       `json:"cachedContent,omitempty"`
-}
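
This hunk deletes the Gemini native-format DTOs from the dto package; the hunks below switch all callers to GeminiChatRequest/GeminiChatResponse from the one-api/relay/channel/gemini package. Worth noting is the passthrough pattern these types used: every field the relay never interprets (tools, safetySettings, systemInstruction, thinkingConfig, and so on) is a *json.RawMessage, so the bytes are forwarded to the upstream untouched. A minimal self-contained sketch of that round-trip (the type names here are illustrative, not the project's):

package main

import (
	"encoding/json"
	"fmt"
)

// proxyRequest mirrors the passthrough style of the deleted DTOs: fields
// the relay never interprets stay as raw JSON bytes.
type proxyRequest struct {
	Contents json.RawMessage  `json:"contents"`
	Tools    *json.RawMessage `json:"tools,omitempty"`
}

func main() {
	in := []byte(`{"contents":[{"parts":[{"text":"hi"}]}],"tools":[{"googleSearch":{}}]}`)

	var req proxyRequest
	if err := json.Unmarshal(in, &req); err != nil {
		panic(err)
	}

	// Re-marshal: the unparsed "tools" payload survives byte-for-byte, so
	// upstream-only options need no matching Go structs.
	out, _ := json.Marshal(req)
	fmt.Println(string(out))
}
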
@@ -167,8 +167,11 @@ func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, request
 
 func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *dto.OpenAIErrorWithStatusCode) {
 	if info.RelayMode == constant.RelayModeGemini {
-		err, usage = GeminiTextGenerationHandler(c, resp, info)
-		return usage, err
+		if info.IsStream {
+			return GeminiTextGenerationStreamHandler(c, resp, info)
+		} else {
+			return GeminiTextGenerationHandler(c, resp, info)
+		}
 	}
 
 	if strings.HasPrefix(info.UpstreamModelName, "imagen") {
@@ -7,20 +7,21 @@ import (
 	"one-api/common"
 	"one-api/dto"
 	relaycommon "one-api/relay/common"
+	"one-api/relay/helper"
 	"one-api/service"
 
 	"github.com/gin-gonic/gin"
 )
 
-func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
+func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
 	// read the response body
 	responseBody, err := io.ReadAll(resp.Body)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
 	}
 	err = resp.Body.Close()
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError)
 	}
 
 	if common.DebugEnabled {
@@ -28,15 +29,15 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	}
 
 	// parse into the Gemini native response format
-	var geminiResponse dto.GeminiTextGenerationResponse
+	var geminiResponse GeminiChatResponse
 	err = common.DecodeJson(responseBody, &geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
 	}
 
 	// check whether any candidates were returned
 	if len(geminiResponse.Candidates) == 0 {
-		return &dto.OpenAIErrorWithStatusCode{
+		return nil, &dto.OpenAIErrorWithStatusCode{
 			Error: dto.OpenAIError{
 				Message: "No candidates returned",
 				Type:    "server_error",
@@ -44,7 +45,7 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 				Code:    500,
 			},
 			StatusCode: resp.StatusCode,
-		}, nil
+		}
 	}
 
 	// compute usage (based on UsageMetadata)
@@ -54,15 +55,10 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 		TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
 	}
 
-	// set the model version
-	if geminiResponse.ModelVersion == "" {
-		geminiResponse.ModelVersion = info.UpstreamModelName
-	}
-
 	// return the Gemini native JSON response directly
 	jsonResponse, err := json.Marshal(geminiResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
 	}
 
 	// set response headers and write the response
@@ -70,8 +66,63 @@ func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *rela
 	c.Writer.WriteHeader(resp.StatusCode)
 	_, err = c.Writer.Write(jsonResponse)
 	if err != nil {
-		return service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError), nil
+		return nil, service.OpenAIErrorWrapper(err, "write_response_failed", http.StatusInternalServerError)
 	}
 
-	return nil, &usage
+	return &usage, nil
+}
+
+func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
+	var usage = &dto.Usage{}
+	var imageCount int
+
+	helper.SetEventStreamHeaders(c)
+
+	helper.StreamScannerHandler(c, resp, info, func(data string) bool {
+		var geminiResponse GeminiChatResponse
+		err := common.DecodeJsonStr(data, &geminiResponse)
+		if err != nil {
+			common.LogError(c, "error unmarshalling stream response: "+err.Error())
+			return false
+		}
+
+		// count image parts
+		for _, candidate := range geminiResponse.Candidates {
+			for _, part := range candidate.Content.Parts {
+				if part.InlineData != nil && part.InlineData.MimeType != "" {
+					imageCount++
+				}
+			}
+		}
+
+		// update usage statistics
+		if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
+			usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
+			usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount
+			usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
+		}
+
+		// forward the GeminiChatResponse chunk as-is
+		err = helper.ObjectData(c, geminiResponse)
+		if err != nil {
+			common.LogError(c, err.Error())
+		}
+
+		return true
+	})
+
+	if imageCount != 0 {
+		if usage.CompletionTokens == 0 {
+			usage.CompletionTokens = imageCount * 258
+		}
+	}
+
+	// compute final usage
+	usage.PromptTokensDetails.TextTokens = usage.PromptTokens
+	usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
+
+	// end the streaming response
+	helper.Done(c)
+
+	return usage, nil
 }
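
The new GeminiTextGenerationStreamHandler relays each upstream chunk as a server-sent event, and only falls back to imageCount * 258 when the upstream omits candidate token counts; 258 matches the per-image token figure in Google's Gemini documentation. The helpers it leans on (SetEventStreamHeaders, ObjectData, Done) are defined elsewhere in the repo; a plausible minimal equivalent, assuming the usual `data: <json>` SSE framing (an assumption, not the project's code), would be:

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
)

// setEventStreamHeaders marks the response as an SSE stream.
func setEventStreamHeaders(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
}

// objectData writes one object as a `data: <json>` SSE frame and flushes.
func objectData(w http.ResponseWriter, v any) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	if _, err := fmt.Fprintf(w, "data: %s\n\n", b); err != nil {
		return err
	}
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
	return nil
}

// done terminates the stream the way OpenAI-style proxies usually do.
func done(w http.ResponseWriter) {
	fmt.Fprint(w, "data: [DONE]\n\n")
	if f, ok := w.(http.Flusher); ok {
		f.Flush()
	}
}

func main() {
	rec := httptest.NewRecorder()
	setEventStreamHeaders(rec)
	_ = objectData(rec, map[string]any{"candidates": []any{}})
	done(rec)
	fmt.Print(rec.Body.String()) // data: {"candidates":[]} ... data: [DONE]
}
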
@@ -8,6 +8,7 @@ import (
 	"net/http"
 	"one-api/common"
 	"one-api/dto"
+	"one-api/relay/channel/gemini"
 	relaycommon "one-api/relay/common"
 	"one-api/relay/helper"
 	"one-api/service"
@@ -17,8 +18,8 @@ import (
 	"github.com/gin-gonic/gin"
 )
 
-func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationRequest, error) {
-	request := &dto.GeminiTextGenerationRequest{}
+func getAndValidateGeminiRequest(c *gin.Context) (*gemini.GeminiChatRequest, error) {
+	request := &gemini.GeminiChatRequest{}
 	err := common.UnmarshalBodyReusable(c, request)
 	if err != nil {
 		return nil, err
@@ -29,7 +30,19 @@ func getAndValidateGeminiRequest(c *gin.Context) (*dto.GeminiTextGenerationReque
 	return request, nil
 }
 
-func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) ([]string, error) {
+// stream mode
+// /v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx
+func checkGeminiStreamMode(c *gin.Context, relayInfo *relaycommon.RelayInfo) {
+	if c.Query("alt") == "sse" {
+		relayInfo.IsStream = true
+	}
+
+	// if strings.Contains(c.Request.URL.Path, "streamGenerateContent") {
+	// 	relayInfo.IsStream = true
+	// }
+}
+
+func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string, error) {
 	var inputTexts []string
 	for _, content := range textRequest.Contents {
 		for _, part := range content.Parts {
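
checkGeminiStreamMode keys streaming off the alt=sse query parameter rather than the :streamGenerateContent path suffix (that check is left commented out), so a streamGenerateContent call without alt=sse still goes through the blocking handler. A standalone sketch of the same predicate, using only gin's test helpers (not code from the repo):

package main

import (
	"fmt"
	"net/http/httptest"

	"github.com/gin-gonic/gin"
)

func main() {
	gin.SetMode(gin.TestMode)

	// the streaming URL shape from the comment in the hunk above
	req := httptest.NewRequest("POST",
		"/v1beta/models/gemini-2.0-flash:streamGenerateContent?alt=sse&key=xxx", nil)

	c, _ := gin.CreateTestContext(httptest.NewRecorder())
	c.Request = req

	// the same predicate checkGeminiStreamMode applies
	fmt.Println(c.Query("alt") == "sse") // true
}
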
@@ -46,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *dto.GeminiTextGenerationRequest, inf
 	return sensitiveWords, err
 }
 
-func getGeminiInputTokens(req *dto.GeminiTextGenerationRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
 	// count the input tokens
 	var inputTexts []string
 	for _, content := range req.Contents {
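
getGeminiInputTokens now takes *gemini.GeminiChatRequest; the body elided by the diff gathers the text of every part and runs the token counter over the joined string. A sketch of that shape (the structs and counter here are hypothetical stand-ins for the project's gemini types and service-layer tokenizer):

package main

import (
	"fmt"
	"strings"
)

// Hypothetical stand-ins for the project's gemini request types.
type part struct{ Text string }
type content struct{ Parts []part }
type chatRequest struct{ Contents []content }

// countTextTokens is a placeholder for the service-layer tokenizer.
func countTextTokens(text string) int {
	return len([]rune(text)) / 4 // crude length heuristic, illustration only
}

// getInputTokens mirrors the loop shape shown in the hunk above.
func getInputTokens(req *chatRequest) int {
	var inputTexts []string
	for _, c := range req.Contents {
		for _, p := range c.Parts {
			if p.Text != "" {
				inputTexts = append(inputTexts, p.Text)
			}
		}
	}
	return countTextTokens(strings.Join(inputTexts, "\n"))
}

func main() {
	req := &chatRequest{Contents: []content{
		{Parts: []part{{Text: "hello"}, {Text: "world"}}},
	}}
	fmt.Println(getInputTokens(req)) // 2
}
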
@@ -72,8 +85,11 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 
 	relayInfo := relaycommon.GenRelayInfo(c)
 
+	// check Gemini stream mode
+	checkGeminiStreamMode(c, relayInfo)
+
 	if setting.ShouldCheckPromptSensitive() {
-		sensitiveWords, err := checkGeminiInputSensitive(req, relayInfo)
+		sensitiveWords, err := checkGeminiInputSensitive(req)
 		if err != nil {
 			common.LogWarn(c, fmt.Sprintf("user sensitive words detected: %s", strings.Join(sensitiveWords, ", ")))
 			return service.OpenAIErrorWrapperLocal(err, "check_request_sensitive_error", http.StatusBadRequest)
@@ -97,7 +113,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
 		c.Set("prompt_tokens", promptTokens)
 	}
 
-	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, req.GenerationConfig.MaxOutputTokens)
+	priceData, err := helper.ModelPriceHelper(c, relayInfo, relayInfo.PromptTokens, int(req.GenerationConfig.MaxOutputTokens))
 	if err != nil {
 		return service.OpenAIErrorWrapperLocal(err, "model_price_error", http.StatusInternalServerError)
 	}