Merge pull request #1272 from QuantumNous/gemini-stream-completion-count-fix
fix: gemini 原生格式流模式中断请求未计费
This commit is contained in:
@@ -66,10 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
promptTokens := 0
|
promptTokens := 0
|
||||||
preConsumedTokens := common.PreConsumedQuota
|
preConsumedTokens := common.PreConsumedQuota
|
||||||
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
|
||||||
promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
preConsumedTokens = promptTokens
|
preConsumedTokens = promptTokens
|
||||||
relayInfo.PromptTokens = promptTokens
|
relayInfo.PromptTokens = promptTokens
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {
|
||||||
|
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
if claudeInfo.Usage.PromptTokens == 0 {
|
if claudeInfo.Usage.PromptTokens == 0 {
|
||||||
//上游出错
|
//上游出错
|
||||||
@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
|
|||||||
if common.DebugEnabled {
|
if common.DebugEnabled {
|
||||||
common.SysError("claude response usage is not complete, maybe upstream error")
|
common.SysError("claude response usage is not complete, maybe upstream error")
|
||||||
}
|
}
|
||||||
claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -618,10 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if requestMode == RequestModeCompletion {
|
if requestMode == RequestModeCompletion {
|
||||||
completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
|
|
||||||
}
|
|
||||||
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
claudeInfo.Usage.PromptTokens = info.PromptTokens
|
||||||
claudeInfo.Usage.CompletionTokens = completionTokens
|
claudeInfo.Usage.CompletionTokens = completionTokens
|
||||||
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
claudeInfo.Usage.TotalTokens = info.PromptTokens + completionTokens
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
|
|||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
common.LogError(c, "error_scanning_stream_response: "+err.Error())
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
if info.ShouldIncludeUsage {
|
if info.ShouldIncludeUsage {
|
||||||
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
|
||||||
err := helper.ObjectData(c, response)
|
err := helper.ObjectData(c, response)
|
||||||
@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
|
|||||||
for _, choice := range response.Choices {
|
for _, choice := range response.Choices {
|
||||||
responseText += choice.Message.StringContent()
|
responseText += choice.Message.StringContent()
|
||||||
}
|
}
|
||||||
usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
response.Usage = *usage
|
response.Usage = *usage
|
||||||
response.Id = helper.GetResponseID(c)
|
response.Id = helper.GetResponseID(c)
|
||||||
jsonResponse, err := json.Marshal(response)
|
jsonResponse, err := json.Marshal(response)
|
||||||
@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn
|
|||||||
|
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage.PromptTokens = info.PromptTokens
|
||||||
usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
|
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
if usage.PromptTokens == 0 {
|
if usage.PromptTokens == 0 {
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
}
|
}
|
||||||
return nil, usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -106,7 +106,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
|
|
||||||
var currentEvent string
|
var currentEvent string
|
||||||
var currentData string
|
var currentData string
|
||||||
var usage dto.Usage
|
var usage = &dto.Usage{}
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
@@ -114,7 +114,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
if line == "" {
|
if line == "" {
|
||||||
if currentEvent != "" && currentData != "" {
|
if currentEvent != "" && currentData != "" {
|
||||||
// handle last event
|
// handle last event
|
||||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||||
currentEvent = ""
|
currentEvent = ""
|
||||||
currentData = ""
|
currentData = ""
|
||||||
}
|
}
|
||||||
@@ -134,7 +134,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
|
|
||||||
// Last event
|
// Last event
|
||||||
if currentEvent != "" && currentData != "" {
|
if currentEvent != "" && currentData != "" {
|
||||||
handleCozeEvent(c, currentEvent, currentData, &responseText, &usage, id, info)
|
handleCozeEvent(c, currentEvent, currentData, &responseText, usage, id, info)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := scanner.Err(); err != nil {
|
if err := scanner.Err(); err != nil {
|
||||||
@@ -143,12 +143,10 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo
|
|||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
|
|
||||||
if usage.TotalTokens == 0 {
|
if usage.TotalTokens == 0 {
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, c.GetInt("coze_input_count"))
|
||||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, &usage
|
return nil, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
func handleCozeEvent(c *gin.Context, event string, data string, responseText *string, usage *dto.Usage, id string, info *relaycommon.RelayInfo) {
|
||||||
|
|||||||
@@ -243,15 +243,8 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
|
|||||||
return true
|
return true
|
||||||
})
|
})
|
||||||
helper.Done(c)
|
helper.Done(c)
|
||||||
err := resp.Body.Close()
|
|
||||||
if err != nil {
|
|
||||||
// return service.OpenAIErrorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
|
|
||||||
common.SysError("close_response_body_failed: " + err.Error())
|
|
||||||
}
|
|
||||||
if usage.TotalTokens == 0 {
|
if usage.TotalTokens == 0 {
|
||||||
usage.PromptTokens = info.PromptTokens
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
|
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
|
||||||
}
|
}
|
||||||
usage.CompletionTokens += nodeToken
|
usage.CompletionTokens += nodeToken
|
||||||
return nil, usage
|
return nil, usage
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
relaycommon "one-api/relay/common"
|
relaycommon "one-api/relay/common"
|
||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
@@ -75,6 +76,8 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
|
|
||||||
helper.SetEventStreamHeaders(c)
|
helper.SetEventStreamHeaders(c)
|
||||||
|
|
||||||
|
responseText := strings.Builder{}
|
||||||
|
|
||||||
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
helper.StreamScannerHandler(c, resp, info, func(data string) bool {
|
||||||
var geminiResponse GeminiChatResponse
|
var geminiResponse GeminiChatResponse
|
||||||
err := common.DecodeJsonStr(data, &geminiResponse)
|
err := common.DecodeJsonStr(data, &geminiResponse)
|
||||||
@@ -89,6 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
if part.InlineData != nil && part.InlineData.MimeType != "" {
|
||||||
imageCount++
|
imageCount++
|
||||||
}
|
}
|
||||||
|
if part.Text != "" {
|
||||||
|
responseText.WriteString(part.Text)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -122,8 +128,16 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 计算最终使用量
|
// 如果usage.CompletionTokens为0,则使用本地统计的completion tokens
|
||||||
// usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
|
if usage.CompletionTokens == 0 {
|
||||||
|
str := responseText.String()
|
||||||
|
if len(str) > 0 {
|
||||||
|
usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
|
} else {
|
||||||
|
// 空补全,不需要使用量
|
||||||
|
usage = &dto.Usage{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
// 移除流式响应结尾的[Done],因为Gemini API没有发送Done的行为
|
||||||
//helper.Done(c)
|
//helper.Done(c)
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"math"
|
"math"
|
||||||
"mime/multipart"
|
"mime/multipart"
|
||||||
"net/http"
|
"net/http"
|
||||||
"path/filepath"
|
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
"one-api/constant"
|
"one-api/constant"
|
||||||
"one-api/dto"
|
"one-api/dto"
|
||||||
@@ -16,6 +15,7 @@ import (
|
|||||||
"one-api/relay/helper"
|
"one-api/relay/helper"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
"os"
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/bytedance/gopkg/util/gopool"
|
"github.com/bytedance/gopkg/util/gopool"
|
||||||
@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
} else {
|
} else {
|
||||||
if info.ChannelType == common.ChannelTypeDeepSeek {
|
if info.ChannelType == common.ChannelTypeDeepSeek {
|
||||||
@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
|
|||||||
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
|
||||||
completionTokens := 0
|
completionTokens := 0
|
||||||
for _, choice := range simpleResponse.Choices {
|
for _, choice := range simpleResponse.Choices {
|
||||||
ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
|
||||||
completionTokens += ctkm
|
completionTokens += ctkm
|
||||||
}
|
}
|
||||||
simpleResponse.Usage = dto.Usage{
|
simpleResponse.Usage = dto.Usage{
|
||||||
@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
|
|||||||
if err = c.ShouldBind(&reqBody); err != nil {
|
if err = c.ShouldBind(&reqBody); err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
ext := filepath.Ext(reqBody.File.Filename) // 获取文件扩展名
|
||||||
reqFp, err := reqBody.File.Open()
|
reqFp, err := reqBody.File.Open()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, errors.WithStack(err)
|
return 0, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
defer reqFp.Close()
|
defer reqFp.Close()
|
||||||
|
|
||||||
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
tmpFp, err := os.CreateTemp("", "audio-*"+ext)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
|
|||||||
tempStr := responseTextBuilder.String()
|
tempStr := responseTextBuilder.String()
|
||||||
if len(tempStr) > 0 {
|
if len(tempStr) > 0 {
|
||||||
// 非正常结束,使用输出文本的 token 数量
|
// 非正常结束,使用输出文本的 token 数量
|
||||||
completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
|
completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
|
||||||
usage.CompletionTokens = completionTokens
|
usage.CompletionTokens = completionTokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = palmStreamHandler(c, resp)
|
err, responseText = palmStreamHandler(c, resp)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
fullTextResponse := responsePaLM2OpenAI(&palmResponse)
|
||||||
completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
|
||||||
usage := dto.Usage{
|
usage := dto.Usage{
|
||||||
PromptTokens: promptTokens,
|
PromptTokens: promptTokens,
|
||||||
CompletionTokens: completionTokens,
|
CompletionTokens: completionTokens,
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
|
|||||||
if info.IsStream {
|
if info.IsStream {
|
||||||
var responseText string
|
var responseText string
|
||||||
err, responseText = tencentStreamHandler(c, resp)
|
err, responseText = tencentStreamHandler(c, resp)
|
||||||
usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
|
||||||
} else {
|
} else {
|
||||||
err, usage = tencentHandler(c, resp)
|
err, usage = tencentHandler(c, resp)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
})
|
})
|
||||||
|
|
||||||
if !containStreamUsage {
|
if !containStreamUsage {
|
||||||
usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
|
||||||
usage.CompletionTokens += toolCount * 7
|
usage.CompletionTokens += toolCount * 7
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
|
||||||
token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
|
|||||||
return sensitiveWords, err
|
return sensitiveWords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
|
func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
|
||||||
// 计算输入 token 数量
|
// 计算输入 token 数量
|
||||||
var inputTexts []string
|
var inputTexts []string
|
||||||
for _, content := range req.Contents {
|
for _, content := range req.Contents {
|
||||||
@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
|
|||||||
}
|
}
|
||||||
|
|
||||||
inputText := strings.Join(inputTexts, "\n")
|
inputText := strings.Join(inputTexts, "\n")
|
||||||
inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
|
inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
|
||||||
info.PromptTokens = inputTokens
|
info.PromptTokens = inputTokens
|
||||||
return inputTokens, err
|
return inputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
|||||||
promptTokens := value.(int)
|
promptTokens := value.(int)
|
||||||
relayInfo.SetPromptTokens(promptTokens)
|
relayInfo.SetPromptTokens(promptTokens)
|
||||||
} else {
|
} else {
|
||||||
promptTokens, err := getGeminiInputTokens(req, relayInfo)
|
promptTokens := getGeminiInputTokens(req, relayInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -251,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
|
|||||||
case relayconstant.RelayModeChatCompletions:
|
case relayconstant.RelayModeChatCompletions:
|
||||||
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
|
||||||
case relayconstant.RelayModeCompletions:
|
case relayconstant.RelayModeCompletions:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
|
||||||
case relayconstant.RelayModeModerations:
|
case relayconstant.RelayModeModerations:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||||
case relayconstant.RelayModeEmbeddings:
|
case relayconstant.RelayModeEmbeddings:
|
||||||
promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
|
||||||
default:
|
default:
|
||||||
err = errors.New("unknown relay mode")
|
err = errors.New("unknown relay mode")
|
||||||
promptTokens = 0
|
promptTokens = 0
|
||||||
|
|||||||
@@ -14,12 +14,10 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
|
||||||
token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
|
||||||
for _, document := range rerankRequest.Documents {
|
for _, document := range rerankRequest.Documents {
|
||||||
tkm, err := service.CountTokenInput(document, rerankRequest.Model)
|
tkm := service.CountTokenInput(document, rerankRequest.Model)
|
||||||
if err == nil {
|
token += tkm
|
||||||
token += tkm
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return token
|
return token
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
|
|||||||
return sensitiveWords, err
|
return sensitiveWords, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
|
func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
|
||||||
inputTokens, err := service.CountTokenInput(req.Input, req.Model)
|
inputTokens := service.CountTokenInput(req.Input, req.Model)
|
||||||
info.PromptTokens = inputTokens
|
info.PromptTokens = inputTokens
|
||||||
return inputTokens, err
|
return inputTokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
|
||||||
@@ -72,10 +72,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
|
|||||||
promptTokens := value.(int)
|
promptTokens := value.(int)
|
||||||
relayInfo.SetPromptTokens(promptTokens)
|
relayInfo.SetPromptTokens(promptTokens)
|
||||||
} else {
|
} else {
|
||||||
promptTokens, err := getInputTokens(req, relayInfo)
|
promptTokens := getInputTokens(req, relayInfo)
|
||||||
if err != nil {
|
|
||||||
return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
|
|
||||||
}
|
|
||||||
c.Set("prompt_tokens", promptTokens)
|
c.Set("prompt_tokens", promptTokens)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
|
|||||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
toolTokens, err := CountTokenInput(countStr, request.Model)
|
toolTokens := CountTokenInput(countStr, request.Model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro
|
|||||||
|
|
||||||
// Count tokens in system message
|
// Count tokens in system message
|
||||||
if request.System != "" {
|
if request.System != "" {
|
||||||
systemTokens, err := CountTokenInput(request.System, model)
|
systemTokens := CountTokenInput(request.System, model)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
switch request.Type {
|
switch request.Type {
|
||||||
case dto.RealtimeEventTypeSessionUpdate:
|
case dto.RealtimeEventTypeSessionUpdate:
|
||||||
if request.Session != nil {
|
if request.Session != nil {
|
||||||
msgTokens, err := CountTextToken(request.Session.Instructions, model)
|
msgTokens := CountTextToken(request.Session.Instructions, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
textToken += msgTokens
|
textToken += msgTokens
|
||||||
}
|
}
|
||||||
case dto.RealtimeEventResponseAudioDelta:
|
case dto.RealtimeEventResponseAudioDelta:
|
||||||
@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
audioToken += atk
|
audioToken += atk
|
||||||
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
|
||||||
// count text token
|
// count text token
|
||||||
tkm, err := CountTextToken(request.Delta, model)
|
tkm := CountTextToken(request.Delta, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, fmt.Errorf("error counting text token: %v", err)
|
|
||||||
}
|
|
||||||
textToken += tkm
|
textToken += tkm
|
||||||
case dto.RealtimeEventInputAudioBufferAppend:
|
case dto.RealtimeEventInputAudioBufferAppend:
|
||||||
// count audio token
|
// count audio token
|
||||||
@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
case "message":
|
case "message":
|
||||||
for _, content := range request.Item.Content {
|
for _, content := range request.Item.Content {
|
||||||
if content.Type == "input_text" {
|
if content.Type == "input_text" {
|
||||||
tokens, err := CountTextToken(content.Text, model)
|
tokens := CountTextToken(content.Text, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
textToken += tokens
|
textToken += tokens
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
|||||||
if !info.IsFirstRequest {
|
if !info.IsFirstRequest {
|
||||||
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
|
||||||
for _, tool := range info.RealtimeTools {
|
for _, tool := range info.RealtimeTools {
|
||||||
toolTokens, err := CountTokenInput(tool, model)
|
toolTokens := CountTokenInput(tool, model)
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
textToken += 8
|
textToken += 8
|
||||||
textToken += toolTokens
|
textToken += toolTokens
|
||||||
}
|
}
|
||||||
@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
|
|||||||
return tokenNum, nil
|
return tokenNum, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTokenInput(input any, model string) (int, error) {
|
func CountTokenInput(input any, model string) int {
|
||||||
switch v := input.(type) {
|
switch v := input.(type) {
|
||||||
case string:
|
case string:
|
||||||
return CountTextToken(v, model)
|
return CountTextToken(v, model)
|
||||||
@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
|
|||||||
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
|
||||||
tokens := 0
|
tokens := 0
|
||||||
for _, message := range messages {
|
for _, message := range messages {
|
||||||
tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
|
tkm := CountTokenInput(message.Delta.GetContentString(), model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
if message.Delta.ToolCalls != nil {
|
if message.Delta.ToolCalls != nil {
|
||||||
for _, tool := range message.Delta.ToolCalls {
|
for _, tool := range message.Delta.ToolCalls {
|
||||||
tkm, _ := CountTokenInput(tool.Function.Name, model)
|
tkm := CountTokenInput(tool.Function.Name, model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
tkm, _ = CountTokenInput(tool.Function.Arguments, model)
|
tkm = CountTokenInput(tool.Function.Arguments, model)
|
||||||
tokens += tkm
|
tokens += tkm
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
|
|||||||
return tokens
|
return tokens
|
||||||
}
|
}
|
||||||
|
|
||||||
func CountTTSToken(text string, model string) (int, error) {
|
func CountTTSToken(text string, model string) int {
|
||||||
if strings.HasPrefix(model, "tts") {
|
if strings.HasPrefix(model, "tts") {
|
||||||
return utf8.RuneCountInString(text), nil
|
return utf8.RuneCountInString(text)
|
||||||
} else {
|
} else {
|
||||||
return CountTextToken(text, model)
|
return CountTextToken(text, model)
|
||||||
}
|
}
|
||||||
@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
|
|||||||
//}
|
//}
|
||||||
|
|
||||||
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
// CountTextToken 统计文本的token数量,仅当文本包含敏感词,返回错误,同时返回token数量
|
||||||
func CountTextToken(text string, model string) (int, error) {
|
func CountTextToken(text string, model string) int {
|
||||||
var err error
|
if text == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
tokenEncoder := getTokenEncoder(model)
|
tokenEncoder := getTokenEncoder(model)
|
||||||
return getTokenNum(tokenEncoder, text), err
|
return getTokenNum(tokenEncoder, text)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -16,13 +16,13 @@ import (
|
|||||||
// return 0, errors.New("unknown relay mode")
|
// return 0, errors.New("unknown relay mode")
|
||||||
//}
|
//}
|
||||||
|
|
||||||
func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
|
func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
|
||||||
usage := &dto.Usage{}
|
usage := &dto.Usage{}
|
||||||
usage.PromptTokens = promptTokens
|
usage.PromptTokens = promptTokens
|
||||||
ctkm, err := CountTextToken(responseText, modeName)
|
ctkm := CountTextToken(responseText, modeName)
|
||||||
usage.CompletionTokens = ctkm
|
usage.CompletionTokens = ctkm
|
||||||
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
|
||||||
return usage, err
|
return usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func ValidUsage(usage *dto.Usage) bool {
|
func ValidUsage(usage *dto.Usage) bool {
|
||||||
|
|||||||
Reference in New Issue
Block a user