From c834694992f2fd68c9117d72d47b240bd2bc2e57 Mon Sep 17 00:00:00 2001 From: CaIon Date: Sat, 16 Aug 2025 19:11:15 +0800 Subject: [PATCH] fix: update token usage calculation --- model/log.go | 17 +++++++---------- model/usedata.go | 6 ------ relay/channel/openai/relay_responses.go | 10 ++++++---- 3 files changed, 13 insertions(+), 20 deletions(-) diff --git a/model/log.go b/model/log.go index e443516d..979cbe7b 100644 --- a/model/log.go +++ b/model/log.go @@ -5,6 +5,7 @@ import ( "fmt" "one-api/common" "one-api/logger" + "one-api/types" "os" "strings" "time" @@ -150,10 +151,10 @@ type RecordConsumeLogParams struct { } func RecordConsumeLog(c *gin.Context, userId int, params RecordConsumeLogParams) { - logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) if !common.LogConsumeEnabled { return } + logger.LogInfo(c, fmt.Sprintf("record consume log: userId=%d, params=%s", userId, common.GetJsonString(params))) username := c.GetString("username") otherStr := common.MapToJsonStr(params.Other) // 判断是否需要记录 IP @@ -236,26 +237,22 @@ func GetAllLogs(logType int, startTimestamp int64, endTimestamp int64, modelName return nil, 0, err } - channelIdsMap := make(map[int]struct{}) - channelMap := make(map[int]string) + channelIds := types.NewSet[int]() for _, log := range logs { if log.ChannelId != 0 { - channelIdsMap[log.ChannelId] = struct{}{} + channelIds.Add(log.ChannelId) } } - channelIds := make([]int, 0, len(channelIdsMap)) - for channelId := range channelIdsMap { - channelIds = append(channelIds, channelId) - } - if len(channelIds) > 0 { + if channelIds.Len() > 0 { var channels []struct { Id int `gorm:"column:id"` Name string `gorm:"column:name"` } - if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds).Find(&channels).Error; err != nil { + if err = DB.Table("channels").Select("id, name").Where("id IN ?", channelIds.Items()).Find(&channels).Error; err != nil { return logs, total, err } + channelMap := make(map[int]string, len(channels)) for _, channel := range channels { channelMap[channel.Id] = channel.Name } diff --git a/model/usedata.go b/model/usedata.go index 1255b0be..7e525d2e 100644 --- a/model/usedata.go +++ b/model/usedata.go @@ -21,12 +21,6 @@ type QuotaData struct { } func UpdateQuotaData() { - // recover - defer func() { - if r := recover(); r != nil { - common.SysLog(fmt.Sprintf("UpdateQuotaData panic: %s", r)) - } - }() for { if common.DataExportEnabled { common.SysLog("正在更新数据看板数据...") diff --git a/relay/channel/openai/relay_responses.go b/relay/channel/openai/relay_responses.go index c5ff6d24..ab2aa8a4 100644 --- a/relay/channel/openai/relay_responses.go +++ b/relay/channel/openai/relay_responses.go @@ -103,12 +103,14 @@ func OaiResponsesStreamHandler(c *gin.Context, info *relaycommon.RelayInfo, resp // 非正常结束,使用输出文本的 token 数量 completionTokens := service.CountTextToken(tempStr, info.UpstreamModelName) usage.CompletionTokens = completionTokens - - if usage.PromptTokens == 0 { - usage.PromptTokens = info.PromptTokens - } } } + if usage.PromptTokens == 0 && usage.CompletionTokens != 0 { + usage.PromptTokens = usage.CompletionTokens + } else { + usage.TotalTokens = usage.PromptTokens + usage.CompletionTokens + } + return usage, nil }