Merge remote-tracking branch 'origin/alpha' into fix/openrouter-custom-ratio-billing
This commit is contained in:
@@ -85,7 +85,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
CloseResponseBodyGracefully(resp)
|
||||
var errResponse dto.GeneralErrorResponse
|
||||
|
||||
err = common.Unmarshal(responseBody, &errResponse)
|
||||
|
||||
59
service/http.go
Normal file
59
service/http.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func CloseResponseBodyGracefully(httpResponse *http.Response) {
|
||||
if httpResponse == nil || httpResponse.Body == nil {
|
||||
return
|
||||
}
|
||||
err := httpResponse.Body.Close()
|
||||
if err != nil {
|
||||
common.SysError("failed to close response body: " + err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
|
||||
if c.Writer == nil {
|
||||
return
|
||||
}
|
||||
|
||||
body := io.NopCloser(bytes.NewBuffer(data))
|
||||
|
||||
// We shouldn't set the header before we parse the response body, because the parse part may fail.
|
||||
// And then we will have to send an error response, but in this case, the header has already been set.
|
||||
// So the httpClient will be confused by the response.
|
||||
// For example, Postman will report error, and we cannot check the response at all.
|
||||
if src != nil {
|
||||
for k, v := range src.Header {
|
||||
// avoid setting Content-Length
|
||||
if k == "Content-Length" {
|
||||
continue
|
||||
}
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
}
|
||||
|
||||
// set Content-Length header manually BEFORE calling WriteHeader
|
||||
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
|
||||
|
||||
// Write header with status code (this sends the headers)
|
||||
if src != nil {
|
||||
c.Writer.WriteHeader(src.StatusCode)
|
||||
} else {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
_, err := io.Copy(c.Writer, body)
|
||||
if err != nil {
|
||||
logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -78,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
return info
|
||||
}
|
||||
|
||||
func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
|
||||
func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
|
||||
other := make(map[string]interface{})
|
||||
other["model_price"] = priceData.ModelPrice
|
||||
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio
|
||||
|
||||
@@ -212,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
defer cancel()
|
||||
resp, err := GetHttpClient().Do(req)
|
||||
if err != nil {
|
||||
common.SysError("do request failed: " + err.Error())
|
||||
common.SysLog("do request failed: " + err.Error())
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
||||
}
|
||||
statusCode := resp.StatusCode
|
||||
@@ -233,7 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
|
||||
if err != nil {
|
||||
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
||||
}
|
||||
common.CloseResponseBodyGracefully(resp)
|
||||
CloseResponseBodyGracefully(resp)
|
||||
respStr := string(responseBody)
|
||||
log.Printf("respStr: %s", respStr)
|
||||
if respStr == "" {
|
||||
|
||||
79
service/pre_consume_quota.go
Normal file
79
service/pre_consume_quota.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
|
||||
"github.com/bytedance/gopkg/util/gopool"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
|
||||
if preConsumedQuota != 0 {
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
|
||||
gopool.Go(func() {
|
||||
relayInfoCopy := *relayInfo
|
||||
|
||||
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
|
||||
if err != nil {
|
||||
common.SysLog("error return pre-consumed quota: " + err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// PreConsumeQuota checks if the user has enough quota to pre-consume.
|
||||
// It returns the pre-consumed quota if successful, or an error if not.
|
||||
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
|
||||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
if userQuota <= 0 {
|
||||
return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
if userQuota-preConsumedQuota < 0 {
|
||||
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
|
||||
trustQuota := common.GetTrustQuota()
|
||||
|
||||
relayInfo.UserQuota = userQuota
|
||||
if userQuota > trustQuota {
|
||||
// 用户额度充足,判断令牌额度是否充足
|
||||
if !relayInfo.TokenUnlimited {
|
||||
// 非无限令牌,判断令牌额度是否充足
|
||||
tokenQuota := c.GetInt("token_quota")
|
||||
if tokenQuota > trustQuota {
|
||||
// 令牌额度充足,信任令牌
|
||||
preConsumedQuota = 0
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
|
||||
}
|
||||
} else {
|
||||
// in this case, we do not pre-consume quota
|
||||
// because the user has enough quota
|
||||
preConsumedQuota = 0
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
|
||||
}
|
||||
}
|
||||
|
||||
if preConsumedQuota > 0 {
|
||||
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
|
||||
}
|
||||
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
|
||||
if err != nil {
|
||||
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
|
||||
}
|
||||
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
|
||||
}
|
||||
relayInfo.FinalPreConsumedQuota = preConsumedQuota
|
||||
return preConsumedQuota, nil
|
||||
}
|
||||
122
service/quota.go
122
service/quota.go
@@ -8,11 +8,12 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/relay/helper"
|
||||
"one-api/setting"
|
||||
"one-api/setting/ratio_setting"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -137,23 +138,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
|
||||
quota := calculateAudioQuota(quotaInfo)
|
||||
|
||||
if userQuota < quota {
|
||||
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
|
||||
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
|
||||
}
|
||||
|
||||
if !token.UnlimitedQuota && token.RemainQuota < quota {
|
||||
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
|
||||
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
|
||||
}
|
||||
|
||||
err = PostConsumeQuota(relayInfo, quota, 0, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
|
||||
logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
|
||||
return nil
|
||||
}
|
||||
|
||||
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
|
||||
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
usage *dto.RealtimeUsage, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.InputTokenDetails.TextTokens
|
||||
@@ -167,10 +168,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
usePrice := relayInfo.PriceData.UsePrice
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -204,8 +205,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
@@ -216,7 +217,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: usage.InputTokens,
|
||||
@@ -226,7 +227,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
@@ -234,8 +234,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
|
||||
})
|
||||
}
|
||||
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
promptTokens := usage.PromptTokens
|
||||
@@ -243,21 +242,21 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
modelName := relayInfo.OriginModelName
|
||||
|
||||
tokenName := ctx.GetString("token_name")
|
||||
completionRatio := priceData.CompletionRatio
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
cacheRatio := priceData.CacheRatio
|
||||
completionRatio := relayInfo.PriceData.CompletionRatio
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
cacheRatio := relayInfo.PriceData.CacheRatio
|
||||
cacheTokens := usage.PromptTokensDetails.CachedTokens
|
||||
|
||||
cacheCreationRatio := priceData.CacheCreationRatio
|
||||
cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
|
||||
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
|
||||
|
||||
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
|
||||
promptTokens -= cacheTokens
|
||||
isUsingCustomSettings := priceData.UsePrice || hasCustomModelRatio(modelName, priceData.ModelRatio)
|
||||
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
|
||||
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
|
||||
if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
|
||||
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
|
||||
if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
|
||||
cacheCreationTokens = maybeCacheCreationTokens
|
||||
}
|
||||
@@ -266,7 +265,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
}
|
||||
|
||||
calculateQuota := 0.0
|
||||
if !priceData.UsePrice {
|
||||
if !relayInfo.PriceData.UsePrice {
|
||||
calculateQuota = float64(promptTokens)
|
||||
calculateQuota += float64(cacheTokens) * cacheRatio
|
||||
calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
|
||||
@@ -291,23 +290,38 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游出错)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
} else if quotaDelta < 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(-quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
}
|
||||
|
||||
if quotaDelta != 0 {
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: promptTokens,
|
||||
@@ -317,7 +331,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
@@ -326,7 +339,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
|
||||
}
|
||||
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
|
||||
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
|
||||
if priceData.CacheCreationRatio == 1 {
|
||||
return 0
|
||||
}
|
||||
@@ -347,8 +360,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData
|
||||
(promptCacheCreatePrice - quotaPrice)))
|
||||
}
|
||||
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
|
||||
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
|
||||
|
||||
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
|
||||
textInputTokens := usage.PromptTokensDetails.TextTokens
|
||||
@@ -362,10 +374,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
|
||||
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
|
||||
|
||||
modelRatio := priceData.ModelRatio
|
||||
groupRatio := priceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := priceData.ModelPrice
|
||||
usePrice := priceData.UsePrice
|
||||
modelRatio := relayInfo.PriceData.ModelRatio
|
||||
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
|
||||
modelPrice := relayInfo.PriceData.ModelPrice
|
||||
usePrice := relayInfo.PriceData.UsePrice
|
||||
|
||||
quotaInfo := QuotaInfo{
|
||||
InputDetails: TokenDetails{
|
||||
@@ -399,18 +411,33 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
logContent += fmt.Sprintf("(可能是上游超时)")
|
||||
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
|
||||
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
|
||||
"tokenId %d, model %s, pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
|
||||
} else {
|
||||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
||||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
||||
}
|
||||
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
|
||||
|
||||
if quotaDelta > 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
} else if quotaDelta < 0 {
|
||||
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s(实际消耗:%s,预扣费:%s)",
|
||||
logger.FormatQuota(-quotaDelta),
|
||||
logger.FormatQuota(quota),
|
||||
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
|
||||
))
|
||||
}
|
||||
|
||||
if quotaDelta != 0 {
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
|
||||
err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -419,7 +446,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
logContent += ", " + extraContent
|
||||
}
|
||||
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
|
||||
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
|
||||
ChannelId: relayInfo.ChannelId,
|
||||
PromptTokens: usage.PromptTokens,
|
||||
@@ -429,7 +456,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
|
||||
Quota: quota,
|
||||
Content: logContent,
|
||||
TokenId: relayInfo.TokenId,
|
||||
UserQuota: userQuota,
|
||||
UseTimeSeconds: int(useTimeSeconds),
|
||||
IsStream: relayInfo.IsStream,
|
||||
Group: relayInfo.UsingGroup,
|
||||
@@ -452,7 +478,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
|
||||
return err
|
||||
}
|
||||
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
|
||||
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
|
||||
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
|
||||
}
|
||||
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
|
||||
if err != nil {
|
||||
@@ -510,7 +536,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
|
||||
prompt := "您的额度即将用尽"
|
||||
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
|
||||
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
|
||||
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
|
||||
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
|
||||
}
|
||||
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/dto"
|
||||
"one-api/setting"
|
||||
"strings"
|
||||
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func CheckSensitiveText(text string) ([]string, error) {
|
||||
if ok, words := SensitiveWordContains(text); ok {
|
||||
return words, errors.New("sensitive words detected")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func CheckSensitiveInput(input any) ([]string, error) {
|
||||
switch v := input.(type) {
|
||||
case string:
|
||||
return CheckSensitiveText(v)
|
||||
case []string:
|
||||
var builder strings.Builder
|
||||
for _, s := range v {
|
||||
builder.WriteString(s)
|
||||
}
|
||||
return CheckSensitiveText(builder.String())
|
||||
}
|
||||
return CheckSensitiveText(fmt.Sprintf("%v", input))
|
||||
func CheckSensitiveText(text string) (bool, []string) {
|
||||
return SensitiveWordContains(text)
|
||||
}
|
||||
|
||||
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
"image"
|
||||
"log"
|
||||
"math"
|
||||
@@ -13,9 +11,14 @@ import (
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
@@ -72,52 +75,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
return tkm
|
||||
}
|
||||
|
||||
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
if imageUrl == nil {
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
}
|
||||
|
||||
// Defaults for 4o/4.1/4.5 family unless overridden below
|
||||
baseTokens := 85
|
||||
if model == "glm-4v" {
|
||||
tileTokens := 170
|
||||
|
||||
// Model classification
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// Special cases from existing behavior
|
||||
if strings.HasPrefix(lowerModel, "glm-4") {
|
||||
return 1047, nil
|
||||
}
|
||||
if imageUrl.Detail == "low" {
|
||||
|
||||
// Patch-based models (32x32 patches, capped at 1536, with multiplier)
|
||||
isPatchBased := false
|
||||
multiplier := 1.0
|
||||
switch {
|
||||
case strings.Contains(lowerModel, "gpt-4.1-mini"):
|
||||
isPatchBased = true
|
||||
multiplier = 1.62
|
||||
case strings.Contains(lowerModel, "gpt-4.1-nano"):
|
||||
isPatchBased = true
|
||||
multiplier = 2.46
|
||||
case strings.HasPrefix(lowerModel, "o4-mini"):
|
||||
isPatchBased = true
|
||||
multiplier = 1.72
|
||||
case strings.HasPrefix(lowerModel, "gpt-5-mini"):
|
||||
isPatchBased = true
|
||||
multiplier = 1.62
|
||||
case strings.HasPrefix(lowerModel, "gpt-5-nano"):
|
||||
isPatchBased = true
|
||||
multiplier = 2.46
|
||||
}
|
||||
|
||||
// Tile-based model tokens and bases per doc
|
||||
if !isPatchBased {
|
||||
if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
|
||||
baseTokens = 2833
|
||||
tileTokens = 5667
|
||||
} else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
|
||||
baseTokens = 70
|
||||
tileTokens = 140
|
||||
} else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
|
||||
baseTokens = 75
|
||||
tileTokens = 150
|
||||
} else if strings.Contains(lowerModel, "computer-use-preview") {
|
||||
baseTokens = 65
|
||||
tileTokens = 129
|
||||
} else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
|
||||
baseTokens = 85
|
||||
tileTokens = 170
|
||||
}
|
||||
}
|
||||
|
||||
// Respect existing feature flags/short-circuits
|
||||
if fileMeta.Detail == "low" && !isPatchBased {
|
||||
return baseTokens, nil
|
||||
}
|
||||
if !constant.GetMediaTokenNotStream && !stream {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// 同步One API的图片计费逻辑
|
||||
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
||||
imageUrl.Detail = "high"
|
||||
// Normalize detail
|
||||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
|
||||
tileTokens := 170
|
||||
if strings.HasPrefix(model, "gpt-4o-mini") {
|
||||
tileTokens = 5667
|
||||
baseTokens = 2833
|
||||
}
|
||||
// 是否统计图片token
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
var b64str string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
config, format, err = DecodeUrlImageData(imageUrl.Url)
|
||||
if strings.HasPrefix(fileMeta.Data, "http") {
|
||||
config, format, err = DecodeUrlImageData(fileMeta.Data)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
imageUrl.MimeType = format
|
||||
fileMeta.MimeType = format
|
||||
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
// not an image
|
||||
@@ -125,60 +171,155 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
|
||||
// file type
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
|
||||
}
|
||||
|
||||
shortSide := config.Width
|
||||
otherSide := config.Height
|
||||
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
|
||||
// 缩放倍数
|
||||
scale := 1.0
|
||||
if config.Height < shortSide {
|
||||
shortSide = config.Height
|
||||
otherSide = config.Width
|
||||
width := config.Width
|
||||
height := config.Height
|
||||
log.Printf("format: %s, width: %d, height: %d", format, width, height)
|
||||
|
||||
if isPatchBased {
|
||||
// 32x32 patch-based calculation with 1536 cap and model multiplier
|
||||
ceilDiv := func(a, b int) int { return (a + b - 1) / b }
|
||||
rawPatchesW := ceilDiv(width, 32)
|
||||
rawPatchesH := ceilDiv(height, 32)
|
||||
rawPatches := rawPatchesW * rawPatchesH
|
||||
if rawPatches > 1536 {
|
||||
// scale down
|
||||
area := float64(width * height)
|
||||
r := math.Sqrt(float64(32*32*1536) / area)
|
||||
wScaled := float64(width) * r
|
||||
hScaled := float64(height) * r
|
||||
// adjust to fit whole number of patches after scaling
|
||||
adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
|
||||
adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
|
||||
adj := math.Min(adjW, adjH)
|
||||
if !math.IsNaN(adj) && adj > 0 {
|
||||
r = r * adj
|
||||
}
|
||||
wScaled = float64(width) * r
|
||||
hScaled = float64(height) * r
|
||||
patchesW := math.Ceil(wScaled / 32.0)
|
||||
patchesH := math.Ceil(hScaled / 32.0)
|
||||
imageTokens := int(patchesW * patchesH)
|
||||
if imageTokens > 1536 {
|
||||
imageTokens = 1536
|
||||
}
|
||||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||||
}
|
||||
// below cap
|
||||
imageTokens := rawPatches
|
||||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||||
}
|
||||
|
||||
// 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
|
||||
if shortSide > 768 {
|
||||
scale = float64(shortSide) / 768
|
||||
shortSide = 768
|
||||
// Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
|
||||
// Step 1: fit within 2048x2048 square
|
||||
maxSide := math.Max(float64(width), float64(height))
|
||||
fitScale := 1.0
|
||||
if maxSide > 2048 {
|
||||
fitScale = maxSide / 2048.0
|
||||
}
|
||||
// 将另一边按照相同的比例缩小,向上取整
|
||||
otherSide = int(math.Ceil(float64(otherSide) / scale))
|
||||
log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
|
||||
// 计算图片的token数量(边的长度除以512,向上取整)
|
||||
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
|
||||
log.Printf("tiles: %d", tiles)
|
||||
fitW := int(math.Round(float64(width) / fitScale))
|
||||
fitH := int(math.Round(float64(height) / fitScale))
|
||||
|
||||
// Step 2: scale so that shortest side is exactly 768
|
||||
minSide := math.Min(float64(fitW), float64(fitH))
|
||||
if minSide == 0 {
|
||||
return baseTokens, nil
|
||||
}
|
||||
shortScale := 768.0 / minSide
|
||||
finalW := int(math.Round(float64(fitW) * shortScale))
|
||||
finalH := int(math.Round(float64(fitH) * shortScale))
|
||||
|
||||
// Count 512px tiles
|
||||
tilesW := (finalW + 512 - 1) / 512
|
||||
tilesH := (finalH + 512 - 1) / 512
|
||||
tiles := tilesW * tilesH
|
||||
|
||||
if common.DebugEnabled {
|
||||
log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
|
||||
}
|
||||
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
tkm := 0
|
||||
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tkm += msgTokens
|
||||
if request.Tools != nil {
|
||||
openaiTools := request.Tools
|
||||
countStr := ""
|
||||
for _, tool := range openaiTools {
|
||||
countStr = tool.Function.Name
|
||||
if tool.Function.Description != "" {
|
||||
countStr += tool.Function.Description
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
tkm += 8
|
||||
tkm += toolTokens
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
if meta == nil {
|
||||
return 0, errors.New("token count meta is nil")
|
||||
}
|
||||
|
||||
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||
tkm := 0
|
||||
|
||||
if meta.TokenType == types.TokenTypeTextNumber {
|
||||
tkm += utf8.RuneCountInString(meta.CombineText)
|
||||
} else {
|
||||
tkm += CountTextToken(meta.CombineText, model)
|
||||
}
|
||||
|
||||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
tkm += meta.ToolsCount * 8
|
||||
tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
|
||||
tkm += meta.NameCount * 3
|
||||
tkm += 3
|
||||
}
|
||||
|
||||
for _, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 240
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token: %v", err)
|
||||
}
|
||||
tkm += token
|
||||
}
|
||||
case types.FileTypeAudio:
|
||||
tkm += 100
|
||||
case types.FileTypeVideo:
|
||||
tkm += 5000
|
||||
case types.FileTypeFile:
|
||||
tkm += 5000
|
||||
}
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
// tkm := 0
|
||||
// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// tkm += msgTokens
|
||||
// if request.Tools != nil {
|
||||
// openaiTools := request.Tools
|
||||
// countStr := ""
|
||||
// for _, tool := range openaiTools {
|
||||
// countStr = tool.Function.Name
|
||||
// if tool.Function.Description != "" {
|
||||
// countStr += tool.Function.Description
|
||||
// }
|
||||
// if tool.Function.Parameters != nil {
|
||||
// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
// }
|
||||
// }
|
||||
// toolTokens := CountTokenInput(countStr, request.Model)
|
||||
// tkm += 8
|
||||
// tkm += toolTokens
|
||||
// }
|
||||
//
|
||||
// return tkm, nil
|
||||
//}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
@@ -338,58 +479,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Content != nil {
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := m.GetImageMedia()
|
||||
imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tokenNum += imageTokenNum
|
||||
log.Printf("image token num: %d", imageTokenNum)
|
||||
} else if m.Type == dto.ContentTypeInputAudio {
|
||||
// TODO: 音频token数量计算
|
||||
tokenNum += 100
|
||||
} else if m.Type == dto.ContentTypeFile {
|
||||
tokenNum += 5000
|
||||
} else if m.Type == dto.ContentTypeVideoUrl {
|
||||
tokenNum += 5000
|
||||
} else {
|
||||
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum, nil
|
||||
}
|
||||
//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
// //recover when panic
|
||||
// tokenEncoder := getTokenEncoder(model)
|
||||
// // Reference:
|
||||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
// //
|
||||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
// var tokensPerMessage int
|
||||
// var tokensPerName int
|
||||
//
|
||||
// tokensPerMessage = 3
|
||||
// tokensPerName = 1
|
||||
//
|
||||
// tokenNum := 0
|
||||
// for _, message := range messages {
|
||||
// tokenNum += tokensPerMessage
|
||||
// tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
// if message.Content != nil {
|
||||
// if message.Name != nil {
|
||||
// tokenNum += tokensPerName
|
||||
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
// }
|
||||
// arrayContent := message.ParseContent()
|
||||
// for _, m := range arrayContent {
|
||||
// if m.Type == dto.ContentTypeImageURL {
|
||||
// imageUrl := m.GetImageMedia()
|
||||
// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// tokenNum += imageTokenNum
|
||||
// log.Printf("image token num: %d", imageTokenNum)
|
||||
// } else if m.Type == dto.ContentTypeInputAudio {
|
||||
// // TODO: 音频token数量计算
|
||||
// tokenNum += 100
|
||||
// } else if m.Type == dto.ContentTypeFile {
|
||||
// tokenNum += 5000
|
||||
// } else if m.Type == dto.ContentTypeVideoUrl {
|
||||
// tokenNum += 5000
|
||||
// } else {
|
||||
// tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
// return tokenNum, nil
|
||||
//}
|
||||
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
|
||||
@@ -12,7 +12,7 @@ func NotifyRootUser(t string, subject string, content string) {
|
||||
user := model.GetRootUser().ToBaseUser()
|
||||
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
|
||||
common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
|
||||
// Check notification limit
|
||||
canSend, err := CheckNotificationLimit(userId, data.Type)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
||||
common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
|
||||
return err
|
||||
}
|
||||
if !canSend {
|
||||
@@ -44,7 +44,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
|
||||
case dto.NotifyTypeWebhook:
|
||||
webhookURLStr := userSetting.WebhookUrl
|
||||
if webhookURLStr == "" {
|
||||
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
|
||||
common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user