181 lines
5.5 KiB
Go
181 lines
5.5 KiB
Go
package gemini
|
||
|
||
import (
|
||
"io"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/dto"
|
||
relaycommon "one-api/relay/common"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/types"
|
||
"strings"
|
||
|
||
"github.com/pkg/errors"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||
defer common.CloseResponseBodyGracefully(resp)
|
||
|
||
// 读取响应体
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
if common.DebugEnabled {
|
||
println(string(responseBody))
|
||
}
|
||
|
||
// 解析为 Gemini 原生响应格式
|
||
var geminiResponse dto.GeminiChatResponse
|
||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
// 计算使用量(基于 UsageMetadata)
|
||
usage := dto.Usage{
|
||
PromptTokens: geminiResponse.UsageMetadata.PromptTokenCount,
|
||
CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount,
|
||
TotalTokens: geminiResponse.UsageMetadata.TotalTokenCount,
|
||
}
|
||
|
||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
|
||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||
if detail.Modality == "AUDIO" {
|
||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||
} else if detail.Modality == "TEXT" {
|
||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||
}
|
||
}
|
||
|
||
// 直接返回 Gemini 原生格式的 JSON 响应
|
||
jsonResponse, err := common.Marshal(geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
common.IOCopyBytesGracefully(c, resp, jsonResponse)
|
||
|
||
return &usage, nil
|
||
}
|
||
|
||
func NativeGeminiEmbeddingHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.Usage, *types.NewAPIError) {
|
||
defer common.CloseResponseBodyGracefully(resp)
|
||
|
||
responseBody, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
|
||
if common.DebugEnabled {
|
||
println(string(responseBody))
|
||
}
|
||
|
||
usage := &dto.Usage{
|
||
PromptTokens: info.PromptTokens,
|
||
TotalTokens: info.PromptTokens,
|
||
}
|
||
|
||
if info.IsGeminiBatchEmbedding {
|
||
var geminiResponse dto.GeminiBatchEmbeddingResponse
|
||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
} else {
|
||
var geminiResponse dto.GeminiEmbeddingResponse
|
||
err = common.Unmarshal(responseBody, &geminiResponse)
|
||
if err != nil {
|
||
return nil, types.NewOpenAIError(err, types.ErrorCodeBadResponseBody, http.StatusInternalServerError)
|
||
}
|
||
}
|
||
|
||
common.IOCopyBytesGracefully(c, resp, responseBody)
|
||
|
||
return usage, nil
|
||
}
|
||
|
||
func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp *http.Response) (*dto.Usage, *types.NewAPIError) {
|
||
var usage = &dto.Usage{}
|
||
var imageCount int
|
||
|
||
helper.SetEventStreamHeaders(c)
|
||
|
||
responseText := strings.Builder{}
|
||
|
||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||
var geminiResponse dto.GeminiChatResponse
|
||
err := common.UnmarshalJsonStr(data, &geminiResponse)
|
||
if err != nil {
|
||
common.LogError(c, "error unmarshalling stream response: "+err.Error())
|
||
return false
|
||
}
|
||
|
||
// 统计图片数量
|
||
for _, candidate := range geminiResponse.Candidates {
|
||
for _, part := range candidate.Content.Parts {
|
||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||
imageCount++
|
||
}
|
||
if part.Text != "" {
|
||
responseText.WriteString(part.Text)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 更新使用量统计
|
||
if geminiResponse.UsageMetadata.TotalTokenCount != 0 {
|
||
usage.PromptTokens = geminiResponse.UsageMetadata.PromptTokenCount
|
||
usage.CompletionTokens = geminiResponse.UsageMetadata.CandidatesTokenCount + geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
usage.TotalTokens = geminiResponse.UsageMetadata.TotalTokenCount
|
||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||
if detail.Modality == "AUDIO" {
|
||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||
} else if detail.Modality == "TEXT" {
|
||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||
}
|
||
}
|
||
}
|
||
|
||
// 直接发送 GeminiChatResponse 响应
|
||
err = helper.StringData(c, data)
|
||
if err != nil {
|
||
common.LogError(c, err.Error())
|
||
}
|
||
info.SendResponseCount++
|
||
return true
|
||
})
|
||
|
||
if info.SendResponseCount == 0 {
|
||
return nil, types.NewOpenAIError(errors.New("no response received from Gemini API"), types.ErrorCodeEmptyResponse, http.StatusInternalServerError)
|
||
}
|
||
|
||
if imageCount != 0 {
|
||
if usage.CompletionTokens == 0 {
|
||
usage.CompletionTokens = imageCount * 258
|
||
}
|
||
}
|
||
|
||
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||
if usage.CompletionTokens == 0 {
|
||
str := responseText.String()
|
||
if len(str) > 0 {
|
||
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||
} else {
|
||
// 空补全,不需要使用量
|
||
usage = &dto.Usage{}
|
||
}
|
||
}
|
||
|
||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||
//helper.Done(c)
|
||
|
||
return usage, nil
|
||
}
|