refactor: token counter logic
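This commit drops the error return from the token-counting helpers in service (CountTextToken, CountTTSToken, CountTokenInput, and ResponseText2Usage), changing their signatures from (int, error) or (*dto.Usage, error) to plain values. Per the old doc comment on CountTextToken, the error was only ever produced for sensitive-word hits, and nearly every call site discarded it with "_". Alongside the signature change, CountTextToken gains an empty-input fast path, and the Gemini text-generation stream handler switches from counting tokens chunk by chunk to accumulating the streamed text in a strings.Builder and deriving usage once at end of stream via ResponseText2Usage. The hunks below update the call sites across the relay adaptors accordingly.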
@@ -66,7 +66,7 @@ func AudioHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
     promptTokens := 0
     preConsumedTokens := common.PreConsumedQuota
     if relayInfo.RelayMode == relayconstant.RelayModeAudioSpeech {
-        promptTokens, err = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
+        promptTokens = service.CountTTSToken(audioRequest.Input, audioRequest.Model)
         if err != nil {
             return service.OpenAIErrorWrapper(err, "count_audio_token_failed", http.StatusInternalServerError)
         }

@@ -549,7 +549,7 @@ func HandleStreamResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
 func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, claudeInfo *ClaudeResponseInfo, requestMode int) {

     if requestMode == RequestModeCompletion {
-        claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
+        claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, info.PromptTokens)
     } else {
         if claudeInfo.Usage.PromptTokens == 0 {
             // upstream error

@@ -558,7 +558,7 @@ func HandleStreamFinalResponse(c *gin.Context, info *relaycommon.RelayInfo, clau
             if common.DebugEnabled {
                 common.SysError("claude response usage is not complete, maybe upstream error")
             }
-            claudeInfo.Usage, _ = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
+            claudeInfo.Usage = service.ResponseText2Usage(claudeInfo.ResponseText.String(), info.UpstreamModelName, claudeInfo.Usage.PromptTokens)
         }
     }
 }

@@ -618,7 +618,7 @@ func HandleClaudeResponseData(c *gin.Context, info *relaycommon.RelayInfo, claud
         }
     }
     if requestMode == RequestModeCompletion {
-        completionTokens, err := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
+        completionTokens := service.CountTextToken(claudeResponse.Completion, info.OriginModelName)
         if err != nil {
             return service.OpenAIErrorWrapper(err, "count_token_text_failed", http.StatusInternalServerError)
         }

@@ -71,7 +71,7 @@ func cfStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rela
     if err := scanner.Err(); err != nil {
         common.LogError(c, "error_scanning_stream_response: "+err.Error())
     }
-    usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+    usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
     if info.ShouldIncludeUsage {
         response := helper.GenerateFinalUsageResponse(id, info.StartTime.Unix(), info.UpstreamModelName, *usage)
         err := helper.ObjectData(c, response)

@@ -108,7 +108,7 @@ func cfHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo)
     for _, choice := range response.Choices {
         responseText += choice.Message.StringContent()
     }
-    usage, _ := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+    usage := service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
     response.Usage = *usage
     response.Id = helper.GetResponseID(c)
     jsonResponse, err := json.Marshal(response)

@@ -150,7 +150,7 @@ func cfSTTHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayIn

     usage := &dto.Usage{}
     usage.PromptTokens = info.PromptTokens
-    usage.CompletionTokens, _ = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
+    usage.CompletionTokens = service.CountTextToken(cfResp.Result.Text, info.UpstreamModelName)
     usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens

     return nil, usage

@@ -162,7 +162,7 @@ func cohereStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.
         }
     })
     if usage.PromptTokens == 0 {
-        usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+        usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
     }
     return nil, usage
 }

@@ -144,7 +144,7 @@ func cozeChatStreamHandler(c *gin.Context, resp *http.Response, info *relaycommo

     if usage.TotalTokens == 0 {
         usage.PromptTokens = info.PromptTokens
-        usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+        usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
         usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
     }

@@ -250,7 +250,7 @@ func difyStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Re
     }
     if usage.TotalTokens == 0 {
         usage.PromptTokens = info.PromptTokens
-        usage.CompletionTokens, _ = service.CountTextToken("gpt-3.5-turbo", responseText)
+        usage.CompletionTokens = service.CountTextToken("gpt-3.5-turbo", responseText)
         usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
     }
     usage.CompletionTokens += nodeToken

@@ -9,6 +9,7 @@ import (
     relaycommon "one-api/relay/common"
     "one-api/relay/helper"
     "one-api/service"
+    "strings"

     "github.com/gin-gonic/gin"
 )

@@ -75,8 +76,7 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info

     helper.SetEventStreamHeaders(c)

-    // locally counted completion tokens
-    localCompletionTokens := 0
+    responseText := strings.Builder{}

     helper.StreamScannerHandler(c, resp, info, func(data string) bool {
         var geminiResponse GeminiChatResponse

@@ -92,12 +92,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info
                 if part.InlineData != nil && part.InlineData.MimeType != "" {
                     imageCount++
                 }
-                // locally count completion tokens
-                textTokens, err := service.CountTextToken(part.Text, info.UpstreamModelName)
-                if err != nil {
-                    common.LogError(c, "error counting text token: "+err.Error())
+                if part.Text != "" {
+                    responseText.WriteString(part.Text)
                 }
-                localCompletionTokens += textTokens
             }
         }

@@ -133,13 +130,9 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, resp *http.Response, info

     // if usage.CompletionTokens is 0, use the locally counted completion tokens
     if usage.CompletionTokens == 0 {
-        usage.CompletionTokens = localCompletionTokens
-        usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
+        usage = service.ResponseText2Usage(responseText.String(), info.UpstreamModelName, info.PromptTokens)
     }

-    // compute the final usage
-    // usage.CompletionTokens = usage.TotalTokens - usage.PromptTokens
-
     // remove the [Done] at the end of the stream response, since the Gemini API never sends Done
     //helper.Done(c)
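The Gemini hunks above replace one CountTextToken call per streamed part (a tokenizer run per chunk, plus error logging) with a strings.Builder that accumulates the text and a single count at end of stream. A minimal, self-contained sketch of that accumulate-then-count pattern follows; the chunk slice and the countTokens stand-in are illustrative, not the repo's actual types:

    package main

    import (
        "fmt"
        "strings"
        "unicode/utf8"
    )

    // countTokens stands in for the repo's CountTextToken; the point is
    // that it runs once over the full text, not once per streamed chunk.
    func countTokens(text string) int {
        return utf8.RuneCountInString(text) // placeholder metric
    }

    func main() {
        chunks := []string{"Hel", "lo, ", "wor", "ld"} // simulated stream parts

        // Accumulate streamed parts; strings.Builder avoids the quadratic
        // copying that repeated string concatenation would cause.
        var responseText strings.Builder
        for _, part := range chunks {
            if part != "" {
                responseText.WriteString(part)
            }
        }

        // Count once over the whole response instead of per chunk.
        completionTokens := countTokens(responseText.String())
        fmt.Println("completion tokens:", completionTokens)
    }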
@@ -8,7 +8,6 @@ import (
     "math"
     "mime/multipart"
     "net/http"
-    "path/filepath"
     "one-api/common"
     "one-api/constant"
     "one-api/dto"

@@ -16,6 +15,7 @@ import (
     "one-api/relay/helper"
     "one-api/service"
     "os"
+    "path/filepath"
     "strings"

     "github.com/bytedance/gopkg/util/gopool"

@@ -181,7 +181,7 @@ func OaiStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
     }

     if !containStreamUsage {
-        usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+        usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
         usage.CompletionTokens += toolCount * 7
     } else {
         if info.ChannelType == common.ChannelTypeDeepSeek {

@@ -216,7 +216,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
             StatusCode: resp.StatusCode,
         }, nil
     }

     forceFormat := false
     if forceFmt, ok := info.ChannelSetting[constant.ForceFormat].(bool); ok {
         forceFormat = forceFmt

@@ -225,7 +225,7 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
     if simpleResponse.Usage.TotalTokens == 0 || (simpleResponse.Usage.PromptTokens == 0 && simpleResponse.Usage.CompletionTokens == 0) {
         completionTokens := 0
         for _, choice := range simpleResponse.Choices {
-            ctkm, _ := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
+            ctkm := service.CountTextToken(choice.Message.StringContent()+choice.Message.ReasoningContent+choice.Message.Reasoning, info.UpstreamModelName)
             completionTokens += ctkm
         }
         simpleResponse.Usage = dto.Usage{

@@ -276,9 +276,9 @@ func OpenaiHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayI
 func OpenaiTTSHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (*dto.OpenAIErrorWithStatusCode, *dto.Usage) {
     // the status code has been judged before, if there is a body reading failure,
     // it should be regarded as a non-recoverable error, so it should not return err for external retry.
     // Analogous to nginx's load balancing, it will only retry if it can't be requested or
     // if the upstream returns a specific status code, once the upstream has already written the header,
     // the subsequent failure of the response body should be regarded as a non-recoverable error,
     // and can be terminated directly.
     defer resp.Body.Close()
     usage := &dto.Usage{}

@@ -346,12 +346,12 @@ func countAudioTokens(c *gin.Context) (int, error) {
     if err = c.ShouldBind(&reqBody); err != nil {
         return 0, errors.WithStack(err)
     }
     ext := filepath.Ext(reqBody.File.Filename) // get the file extension
     reqFp, err := reqBody.File.Open()
     if err != nil {
         return 0, errors.WithStack(err)
     }
     defer reqFp.Close()

     tmpFp, err := os.CreateTemp("", "audio-*"+ext)
     if err != nil {

@@ -110,7 +110,7 @@ func OaiResponsesStreamHandler(c *gin.Context, resp *http.Response, info *relayc
         tempStr := responseTextBuilder.String()
         if len(tempStr) > 0 {
             // abnormal end of stream, use the token count of the output text
-            completionTokens, _ := service.CountTextToken(tempStr, info.UpstreamModelName)
+            completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName)
             usage.CompletionTokens = completionTokens
         }
     }

@@ -74,7 +74,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
     if info.IsStream {
         var responseText string
         err, responseText = palmStreamHandler(c, resp)
-        usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+        usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
     } else {
         err, usage = palmHandler(c, resp, info.PromptTokens, info.UpstreamModelName)
     }

@@ -155,7 +155,7 @@ func palmHandler(c *gin.Context, resp *http.Response, promptTokens int, model st
         }, nil
     }
     fullTextResponse := responsePaLM2OpenAI(&palmResponse)
-    completionTokens, _ := service.CountTextToken(palmResponse.Candidates[0].Content, model)
+    completionTokens := service.CountTextToken(palmResponse.Candidates[0].Content, model)
     usage := dto.Usage{
         PromptTokens:     promptTokens,
         CompletionTokens: completionTokens,

@@ -98,7 +98,7 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycom
     if info.IsStream {
         var responseText string
         err, responseText = tencentStreamHandler(c, resp)
-        usage, _ = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
+        usage = service.ResponseText2Usage(responseText, info.UpstreamModelName, info.PromptTokens)
     } else {
         err, usage = tencentHandler(c, resp)
     }

@@ -68,7 +68,7 @@ func xAIStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
     })

     if !containStreamUsage {
-        usage, _ = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
+        usage = service.ResponseText2Usage(responseTextBuilder.String(), info.UpstreamModelName, info.PromptTokens)
         usage.CompletionTokens += toolCount * 7
     }

@@ -15,7 +15,7 @@ import (
 )

 func getEmbeddingPromptToken(embeddingRequest dto.EmbeddingRequest) int {
-    token, _ := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
+    token := service.CountTokenInput(embeddingRequest.Input, embeddingRequest.Model)
     return token
 }

@@ -59,7 +59,7 @@ func checkGeminiInputSensitive(textRequest *gemini.GeminiChatRequest) ([]string,
     return sensitiveWords, err
 }

-func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) (int, error) {
+func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.RelayInfo) int {
     // count the input tokens
     var inputTexts []string
     for _, content := range req.Contents {

@@ -71,9 +71,9 @@ func getGeminiInputTokens(req *gemini.GeminiChatRequest, info *relaycommon.Relay
     }

     inputText := strings.Join(inputTexts, "\n")
-    inputTokens, err := service.CountTokenInput(inputText, info.UpstreamModelName)
+    inputTokens := service.CountTokenInput(inputText, info.UpstreamModelName)
     info.PromptTokens = inputTokens
-    return inputTokens, err
+    return inputTokens
 }

 func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {

@@ -106,7 +106,7 @@ func GeminiHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
         promptTokens := value.(int)
         relayInfo.SetPromptTokens(promptTokens)
     } else {
-        promptTokens, err := getGeminiInputTokens(req, relayInfo)
+        promptTokens := getGeminiInputTokens(req, relayInfo)
         if err != nil {
             return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
         }

@@ -251,11 +251,11 @@ func getPromptTokens(textRequest *dto.GeneralOpenAIRequest, info *relaycommon.Re
     case relayconstant.RelayModeChatCompletions:
         promptTokens, err = service.CountTokenChatRequest(info, *textRequest)
     case relayconstant.RelayModeCompletions:
-        promptTokens, err = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
+        promptTokens = service.CountTokenInput(textRequest.Prompt, textRequest.Model)
     case relayconstant.RelayModeModerations:
-        promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+        promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
     case relayconstant.RelayModeEmbeddings:
-        promptTokens, err = service.CountTokenInput(textRequest.Input, textRequest.Model)
+        promptTokens = service.CountTokenInput(textRequest.Input, textRequest.Model)
     default:
         err = errors.New("unknown relay mode")
         promptTokens = 0
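Note the asymmetry the switch above preserves: CountTokenChatRequest still returns an error, so the ChatCompletions branch keeps the promptTokens, err = form, while the CountTokenInput branches become plain assignments and err is only set by the default case. A reduced sketch of that mixed pattern, with hypothetical stand-in counters:

    package main

    import (
        "errors"
        "fmt"
    )

    // countSimple mirrors the now-infallible CountTokenInput path.
    func countSimple(s string) int { return len(s) / 4 }

    // countChat mirrors CountTokenChatRequest, which is still fallible.
    func countChat(s string) (int, error) {
        if s == "" {
            return 0, errors.New("empty chat request")
        }
        return len(s) / 4, nil
    }

    // promptTokensFor shows the mixed switch: only the fallible branch
    // and the default case can set err.
    func promptTokensFor(mode, payload string) (int, error) {
        promptTokens := 0
        var err error
        switch mode {
        case "chat":
            promptTokens, err = countChat(payload)
        case "completions", "moderations", "embeddings":
            promptTokens = countSimple(payload)
        default:
            err = errors.New("unknown relay mode")
        }
        return promptTokens, err
    }

    func main() {
        n, err := promptTokensFor("completions", "some prompt text")
        fmt.Println(n, err)
    }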
@@ -14,12 +14,10 @@ import (
 )

 func getRerankPromptToken(rerankRequest dto.RerankRequest) int {
-    token, _ := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
+    token := service.CountTokenInput(rerankRequest.Query, rerankRequest.Model)
     for _, document := range rerankRequest.Documents {
-        tkm, err := service.CountTokenInput(document, rerankRequest.Model)
-        if err == nil {
-            token += tkm
-        }
+        tkm := service.CountTokenInput(document, rerankRequest.Model)
+        token += tkm
     }
     return token
 }

@@ -40,10 +40,10 @@ func checkInputSensitive(textRequest *dto.OpenAIResponsesRequest, info *relaycom
     return sensitiveWords, err
 }

-func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) (int, error) {
-    inputTokens, err := service.CountTokenInput(req.Input, req.Model)
+func getInputTokens(req *dto.OpenAIResponsesRequest, info *relaycommon.RelayInfo) int {
+    inputTokens := service.CountTokenInput(req.Input, req.Model)
     info.PromptTokens = inputTokens
-    return inputTokens, err
+    return inputTokens
 }

 func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {

@@ -72,7 +72,7 @@ func ResponsesHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode)
         promptTokens := value.(int)
         relayInfo.SetPromptTokens(promptTokens)
     } else {
-        promptTokens, err := getInputTokens(req, relayInfo)
+        promptTokens := getInputTokens(req, relayInfo)
         if err != nil {
             return service.OpenAIErrorWrapperLocal(err, "count_input_tokens_error", http.StatusBadRequest)
         }

@@ -171,7 +171,7 @@ func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenA
             countStr += fmt.Sprintf("%v", tool.Function.Parameters)
         }
     }
-    toolTokens, err := CountTokenInput(countStr, request.Model)
+    toolTokens := CountTokenInput(countStr, request.Model)
     if err != nil {
         return 0, err
     }

@@ -194,7 +194,7 @@ func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, erro

     // Count tokens in system message
     if request.System != "" {
-        systemTokens, err := CountTokenInput(request.System, model)
+        systemTokens := CountTokenInput(request.System, model)
         if err != nil {
             return 0, err
         }

@@ -296,10 +296,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, 
     switch request.Type {
     case dto.RealtimeEventTypeSessionUpdate:
         if request.Session != nil {
-            msgTokens, err := CountTextToken(request.Session.Instructions, model)
-            if err != nil {
-                return 0, 0, err
-            }
+            msgTokens := CountTextToken(request.Session.Instructions, model)
             textToken += msgTokens
         }
     case dto.RealtimeEventResponseAudioDelta:

@@ -311,10 +308,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, 
         audioToken += atk
     case dto.RealtimeEventResponseAudioTranscriptionDelta, dto.RealtimeEventResponseFunctionCallArgumentsDelta:
         // count text token
-        tkm, err := CountTextToken(request.Delta, model)
-        if err != nil {
-            return 0, 0, fmt.Errorf("error counting text token: %v", err)
-        }
+        tkm := CountTextToken(request.Delta, model)
         textToken += tkm
     case dto.RealtimeEventInputAudioBufferAppend:
         // count audio token

@@ -329,10 +323,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, 
     case "message":
         for _, content := range request.Item.Content {
             if content.Type == "input_text" {
-                tokens, err := CountTextToken(content.Text, model)
-                if err != nil {
-                    return 0, 0, err
-                }
+                tokens := CountTextToken(content.Text, model)
                 textToken += tokens
             }
         }

@@ -343,10 +334,7 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent, 
     if !info.IsFirstRequest {
         if info.RealtimeTools != nil && len(info.RealtimeTools) > 0 {
             for _, tool := range info.RealtimeTools {
-                toolTokens, err := CountTokenInput(tool, model)
-                if err != nil {
-                    return 0, 0, err
-                }
+                toolTokens := CountTokenInput(tool, model)
                 textToken += 8
                 textToken += toolTokens
             }

@@ -409,7 +397,7 @@ func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, mod
     return tokenNum, nil
 }

-func CountTokenInput(input any, model string) (int, error) {
+func CountTokenInput(input any, model string) int {
     switch v := input.(type) {
     case string:
         return CountTextToken(v, model)

@@ -432,13 +420,13 @@ func CountTokenInput(input any, model string) (int, error) {
 func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice, model string) int {
     tokens := 0
     for _, message := range messages {
-        tkm, _ := CountTokenInput(message.Delta.GetContentString(), model)
+        tkm := CountTokenInput(message.Delta.GetContentString(), model)
         tokens += tkm
         if message.Delta.ToolCalls != nil {
             for _, tool := range message.Delta.ToolCalls {
-                tkm, _ := CountTokenInput(tool.Function.Name, model)
+                tkm := CountTokenInput(tool.Function.Name, model)
                 tokens += tkm
-                tkm, _ = CountTokenInput(tool.Function.Arguments, model)
+                tkm = CountTokenInput(tool.Function.Arguments, model)
                 tokens += tkm
             }
         }

@@ -446,9 +434,9 @@ func CountTokenStreamChoices(messages []dto.ChatCompletionsStreamResponseChoice,
     return tokens
 }

-func CountTTSToken(text string, model string) (int, error) {
+func CountTTSToken(text string, model string) int {
     if strings.HasPrefix(model, "tts") {
-        return utf8.RuneCountInString(text), nil
+        return utf8.RuneCountInString(text)
     } else {
         return CountTextToken(text, model)
     }

@@ -483,8 +471,10 @@ func CountAudioTokenOutput(audioBase64 string, audioFormat string) (int, error)
 //}

 // CountTextToken counts the number of tokens in a text; only when the text contains sensitive words does it return an error, along with the token count
-func CountTextToken(text string, model string) (int, error) {
-    var err error
+func CountTextToken(text string, model string) int {
+    if text == "" {
+        return 0
+    }
     tokenEncoder := getTokenEncoder(model)
-    return getTokenNum(tokenEncoder, text), err
+    return getTokenNum(tokenEncoder, text)
 }
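Two details of the counter itself are worth calling out: CountTTSToken bills tts-prefixed models per character (utf8.RuneCountInString) rather than per BPE token, and CountTextToken now returns 0 for empty input before touching the encoder. A self-contained sketch of that dispatch; encodeLen is a stand-in for the repo's getTokenEncoder/getTokenNum pair:

    package main

    import (
        "fmt"
        "strings"
        "unicode/utf8"
    )

    // encodeLen is a stand-in for getTokenNum(getTokenEncoder(model), text).
    func encodeLen(text string) int { return len(strings.Fields(text)) }

    // countTextToken mirrors the new CountTextToken: empty input
    // short-circuits to 0 before any encoder work.
    func countTextToken(text string) int {
        if text == "" {
            return 0
        }
        return encodeLen(text)
    }

    // countTTSToken mirrors the new CountTTSToken: tts-prefixed models
    // are billed per character (rune), everything else per token.
    func countTTSToken(text, model string) int {
        if strings.HasPrefix(model, "tts") {
            return utf8.RuneCountInString(text)
        }
        return countTextToken(text)
    }

    func main() {
        fmt.Println(countTTSToken("héllo", "tts-1"))       // 5: runes, not bytes
        fmt.Println(countTTSToken("hello world", "gpt-4")) // tokenized path
    }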
@@ -16,13 +16,13 @@ import (
 // return 0, errors.New("unknown relay mode")
 //}

-func ResponseText2Usage(responseText string, modeName string, promptTokens int) (*dto.Usage, error) {
+func ResponseText2Usage(responseText string, modeName string, promptTokens int) *dto.Usage {
     usage := &dto.Usage{}
     usage.PromptTokens = promptTokens
-    ctkm, err := CountTextToken(responseText, modeName)
+    ctkm := CountTextToken(responseText, modeName)
     usage.CompletionTokens = ctkm
     usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens
-    return usage, err
+    return usage
 }

 func ValidUsage(usage *dto.Usage) bool {
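With the error gone, every call site in this diff collapses from usage, _ = service.ResponseText2Usage(...) to a single assignment. A sketch of the new helper's shape and the resulting caller pattern; the Usage struct is trimmed to the three fields this diff touches and the token count is a stand-in:

    package main

    import "fmt"

    // Usage is trimmed to the three dto.Usage fields this diff touches.
    type Usage struct {
        PromptTokens     int
        CompletionTokens int
        TotalTokens      int
    }

    // responseText2Usage mirrors the new service.ResponseText2Usage:
    // build a Usage from the locally counted completion text, no error.
    func responseText2Usage(responseText string, promptTokens int) *Usage {
        u := &Usage{PromptTokens: promptTokens}
        u.CompletionTokens = len(responseText) / 4 // stand-in for CountTextToken
        u.TotalTokens = u.PromptTokens + u.CompletionTokens
        return u
    }

    func main() {
        // old: usage, _ = service.ResponseText2Usage(text, model, promptTokens)
        // new: a single assignment, no discarded error.
        usage := responseText2Usage("the streamed response text", 12)
        fmt.Printf("%+v\n", *usage)
    }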