@@ -5,6 +5,7 @@ import (
     "encoding/json"
     "fmt"
     "io"
+    "log"
     "net/http"
     "one-api/common"
     "one-api/constant"
@@ -162,8 +163,12 @@ func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *dto.Ch
     return &response
 }
 
-func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, string) {
+func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
     responseText := ""
+    responseJson := ""
+    id := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
+    createAt := common.GetTimestamp()
+    var usage = &dto.Usage{}
     dataChan := make(chan string, 5)
     stopChan := make(chan bool, 2)
     scanner := bufio.NewScanner(resp.Body)
@@ -182,6 +187,7 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
     go func() {
         for scanner.Scan() {
             data := scanner.Text()
+            responseJson += data
             data = strings.TrimSpace(data)
             if !strings.HasPrefix(data, "\"text\": \"") {
                 continue
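
Why `responseJson += data` is enough to recover usage later: Gemini's `streamGenerateContent` REST endpoint (without `alt=sse`) emits one pretty-printed JSON array, so re-joining the raw scanned lines reproduces a single document that unmarshals into `[]GeminiChatResponse`. A minimal sketch of that property, using a stand-in type rather than the repo's real one:

```go
package main

import (
	"bufio"
	"encoding/json"
	"fmt"
	"strings"
)

// Stand-in for the fields the handler reads back out; the repo's
// GeminiChatResponse also carries candidates, safety ratings, etc.
type chunk struct {
	UsageMetadata struct {
		PromptTokenCount     int `json:"promptTokenCount"`
		CandidatesTokenCount int `json:"candidatesTokenCount"`
	} `json:"usageMetadata"`
}

func main() {
	// Two array elements, pretty-printed across lines the way the
	// REST endpoint streams them.
	body := `[
  {
    "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 2}
  },
  {
    "usageMetadata": {"promptTokenCount": 5, "candidatesTokenCount": 9}
  }
]`
	responseJson := ""
	scanner := bufio.NewScanner(strings.NewReader(body))
	for scanner.Scan() {
		responseJson += scanner.Text() // same accumulation as the handler
	}

	var chunks []chunk
	if err := json.Unmarshal([]byte(responseJson), &chunks); err != nil {
		panic(err)
	}
	fmt.Println(chunks[len(chunks)-1].UsageMetadata.CandidatesTokenCount) // 9
}
```

The last array element carries the final token counts, which is why the handler's loop below can simply overwrite `usage` on each element.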
@@ -216,10 +222,10 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
             var choice dto.ChatCompletionsStreamResponseChoice
             choice.Delta.SetContentString(dummy.Content)
             response := dto.ChatCompletionsStreamResponse{
-                Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
+                Id:      id,
                 Object:  "chat.completion.chunk",
-                Created: common.GetTimestamp(),
-                Model:   "gemini-pro",
+                Created: createAt,
+                Model:   info.UpstreamModelName,
                 Choices: []dto.ChatCompletionsStreamResponseChoice{choice},
             }
             jsonResponse, err := json.Marshal(response)
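
The identity fix in this hunk matters to clients: OpenAI streaming consumers treat `id`, `created`, and `model` as constant across all chunks of one completion, while the old code minted a fresh UUID and timestamp per chunk and hard-coded `gemini-pro`. A small sketch of the invariant, with a stand-in for `dto.ChatCompletionsStreamResponse`:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Stand-in for dto.ChatCompletionsStreamResponse, trimmed to the
// identity fields this hunk changes.
type streamChunk struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
}

func main() {
	// Computed once per request, as the handler now does via
	// common.GetUUID() and common.GetTimestamp().
	id := "chatcmpl-example"
	createAt := int64(1700000000)
	model := "gemini-1.5-pro" // info.UpstreamModelName in the real code

	for i := 0; i < 2; i++ {
		b, _ := json.Marshal(streamChunk{Id: id, Object: "chat.completion.chunk", Created: createAt, Model: model})
		fmt.Println("data: " + string(b)) // every chunk shares id/created/model
	}
}
```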
@@ -230,15 +236,34 @@ func geminiChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycom
             c.Render(-1, common.CustomEvent{Data: "data: " + string(jsonResponse)})
             return true
         case <-stopChan:
-            c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
             return false
         }
     })
-    err := resp.Body.Close()
+    var geminiChatResponses []GeminiChatResponse
+    err := json.Unmarshal([]byte(responseJson), &geminiChatResponses)
     if err != nil {
-        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
+        log.Printf("cannot get gemini usage: %s", err.Error())
+        usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+    } else {
+        for _, response := range geminiChatResponses {
+            usage.PromptTokens = response.UsageMetadata.PromptTokenCount
+            usage.CompletionTokens = response.UsageMetadata.CandidatesTokenCount
+        }
+        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
     }
-    return nil, responseText
+    if info.ShouldIncludeUsage {
+        response := service.GenerateFinalUsageResponse(id, createAt, info.UpstreamModelName, *usage)
+        err := service.ObjectData(c, response)
+        if err != nil {
+            common.SysError("send final response failed: " + err.Error())
+        }
+    }
+    service.Done(c)
+    err = resp.Body.Close()
+    if err != nil {
+        return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), usage
+    }
+    return nil, usage
 }
 
 func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
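
`service.GenerateFinalUsageResponse` and `service.Done` presumably mirror OpenAI's `stream_options: {"include_usage": true}` contract: one last chunk with an empty `choices` array and the aggregated `usage`, followed by a single `data: [DONE]` — which is also why the `stopChan` case above no longer renders `[DONE]` itself. A sketch of that final frame under those assumptions, with stand-in types:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Stand-ins for dto.Usage and the final chunk dto, trimmed to what
// GenerateFinalUsageResponse is assumed to populate.
type usage struct {
	PromptTokens     int `json:"prompt_tokens"`
	CompletionTokens int `json:"completion_tokens"`
	TotalTokens      int `json:"total_tokens"`
}

type finalChunk struct {
	Id      string `json:"id"`
	Object  string `json:"object"`
	Created int64  `json:"created"`
	Model   string `json:"model"`
	Choices []any  `json:"choices"`
	Usage   usage  `json:"usage"`
}

func main() {
	f := finalChunk{
		Id:      "chatcmpl-example", // same id the content chunks used
		Object:  "chat.completion.chunk",
		Created: 1700000000,
		Model:   "gemini-1.5-pro",
		Choices: []any{}, // empty: all content already streamed
		Usage:   usage{PromptTokens: 5, CompletionTokens: 9, TotalTokens: 14},
	}
	b, _ := json.Marshal(f)
	fmt.Println("data: " + string(b))
	fmt.Println("data: [DONE]") // emitted once, by service.Done
}
```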
@@ -267,11 +292,10 @@ func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, mo
         }, nil
     }
     fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
-    completionTokens, _ := service.CountTokenText(geminiResponse.GetResponseText(), model)
     usage := dto.Usage{
-        PromptTokens:     promptTokens,
-        CompletionTokens: completionTokens,
-        TotalTokens:      promptTokens + completionTokens,
+        PromptTokens:     geminiResponse.UsageMetadata.PromptTokenCount,
+        CompletionTokens: geminiResponse.UsageMetadata.CandidatesTokenCount,
+        TotalTokens:      geminiResponse.UsageMetadata.TotalTokenCount,
     }
     fullTextResponse.Usage = usage
     jsonResponse, err := json.Marshal(fullTextResponse)
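
The non-stream handler now takes token counts straight from Gemini's own `usageMetadata` instead of re-running `service.CountTokenText`, whose OpenAI-style tokenizer only approximates Gemini's; upstream counts also keep the stream and non-stream paths billing consistently. This assumes `GeminiChatResponse.UsageMetadata` is shaped like the REST field, roughly:

```go
// Rough shape of generateContent's usageMetadata; the repo's actual
// struct may name fields differently.
type GeminiUsageMetadata struct {
	PromptTokenCount     int `json:"promptTokenCount"`
	CandidatesTokenCount int `json:"candidatesTokenCount"`
	TotalTokenCount      int `json:"totalTokenCount"`
}
```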