This update replaces instances of DecodeJson and DecodeJsonStr with UnmarshalJson and UnmarshalJsonStr in various relay handlers, enhancing code consistency and clarity in JSON processing. The changes improve maintainability and align with recent refactoring efforts in the codebase.
138 lines
4.2 KiB
Go
138 lines
4.2 KiB
Go
package gemini
|
||
|
||
import (
|
||
"io"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/dto"
|
||
relaycommon "one-api/relay/common"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"strings"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func GeminiTextGenerationHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||
defer common.CloseResponseBodyGracefully(resp)
|
||
|
||
// 读取响应体
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, service.OpenAIErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||
}
|
||
|
||
if common.DebugEnabled {
|
||
println(string(responseBody))
|
||
}
|
||
|
||
// 解析为 Gemini 原生响应格式
|
||
var geminiResponse GeminiChatResponse
|
||
err = common.UnmarshalJson(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, service.OpenAIErrorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError)
|
||
}
|
||
|
||
// 计算使用量(基于 UsageMetadata)
|
||
usage := dto.Usage{
|
||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||
}
|
||
|
||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
|
||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||
if detail.Modality == "AUDIO" {
|
||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||
} else if detail.Modality == "TEXT" {
|
||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||
}
|
||
}
|
||
|
||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||
jsonResponse, err := common.EncodeJson(geminiResponse)
|
||
if err != nil {
|
||
return nil, service.OpenAIErrorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError)
|
||
}
|
||
|
||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||
|
||
return &usage, nil
|
||
}
|
||
|
||
func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *dto.OpenAIErrorWithStatusCode) {
|
||
var usage = &dto.Usage{}
|
||
var imageCount int
|
||
|
||
helper.SetEventStreamHeaders(c)
|
||
|
||
responseText := strings.Builder{}
|
||
|
||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||
var geminiResponse GeminiChatResponse
|
||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||
if err != nil {
|
||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||
return false
|
||
}
|
||
|
||
// 统计图片数量
|
||
for _, candidate := range geminiResponse.Candidates {
|
||
for _, part := range candidate.Content.Parts {
|
||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||
imageCount++
|
||
}
|
||
if part.Text != "" {
|
||
responseText.WriteString(part.Text)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 更新使用量统计
|
||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||
if detail.Modality == "AUDIO" {
|
||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||
} else if detail.Modality == "TEXT" {
|
||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||
}
|
||
}
|
||
}
|
||
|
||
// 直接发送 GeminiChatResponse 响应
|
||
err = helper.StringData(c, data)
|
||
if err != nil {
|
||
common.LogError(c, err.Error())
|
||
}
|
||
|
||
return true
|
||
})
|
||
|
||
if imageCount != 0 {
|
||
if usage.CompletionTokens == 0 {
|
||
usage.CompletionTokens = imageCount * 258
|
||
}
|
||
}
|
||
|
||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||
if usage.CompletionTokens == 0 {
|
||
str := responseText.String()
|
||
if len(str) > 0 {
|
||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||
} else {
|
||
// 空补全,不需要使用量
|
||
usage = &dto.Usage{}
|
||
}
|
||
}
|
||
|
||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||
//helper.Done(c)
|
||
|
||
return usage, nil
|
||
}
|