Merge remote-tracking branch 'origin/alpha' into fix/openrouter-custom-ratio-billing

This commit is contained in:
yyhhyyyyyy
2025-08-15 14:58:42 +08:00
140 changed files with 3384 additions and 2647 deletions

View File

@@ -85,7 +85,7 @@ func RelayErrorHandler(resp *http.Response, showBodyWhenFail bool) (newApiErr *t
if err != nil {
return
}
common.CloseResponseBodyGracefully(resp)
CloseResponseBodyGracefully(resp)
var errResponse dto.GeneralErrorResponse
err = common.Unmarshal(responseBody, &errResponse)

59
service/http.go Normal file
View File

@@ -0,0 +1,59 @@
package service
import (
"bytes"
"fmt"
"io"
"net/http"
"one-api/common"
"one-api/logger"
"github.com/gin-gonic/gin"
)
func CloseResponseBodyGracefully(httpResponse *http.Response) {
if httpResponse == nil || httpResponse.Body == nil {
return
}
err := httpResponse.Body.Close()
if err != nil {
common.SysError("failed to close response body: " + err.Error())
}
}
func IOCopyBytesGracefully(c *gin.Context, src *http.Response, data []byte) {
if c.Writer == nil {
return
}
body := io.NopCloser(bytes.NewBuffer(data))
// We shouldn't set the header before we parse the response body, because the parse part may fail.
// And then we will have to send an error response, but in this case, the header has already been set.
// So the httpClient will be confused by the response.
// For example, Postman will report error, and we cannot check the response at all.
if src != nil {
for k, v := range src.Header {
// avoid setting Content-Length
if k == "Content-Length" {
continue
}
c.Writer.Header().Set(k, v[0])
}
}
// set Content-Length header manually BEFORE calling WriteHeader
c.Writer.Header().Set("Content-Length", fmt.Sprintf("%d", len(data)))
// Write header with status code (this sends the headers)
if src != nil {
c.Writer.WriteHeader(src.StatusCode)
} else {
c.Writer.WriteHeader(http.StatusOK)
}
_, err := io.Copy(c.Writer, body)
if err != nil {
logger.LogError(c, fmt.Sprintf("failed to copy response body: %s", err.Error()))
}
}

View File

@@ -5,7 +5,7 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/types"
"github.com/gin-gonic/gin"
)
@@ -78,7 +78,7 @@ func GenerateClaudeOtherInfo(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
return info
}
func GenerateMjOtherInfo(priceData helper.PerCallPriceData) map[string]interface{} {
func GenerateMjOtherInfo(priceData types.PerCallPriceData) map[string]interface{} {
other := make(map[string]interface{})
other["model_price"] = priceData.ModelPrice
other["group_ratio"] = priceData.GroupRatioInfo.GroupRatio

View File

@@ -212,7 +212,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
defer cancel()
resp, err := GetHttpClient().Do(req)
if err != nil {
common.SysError("do request failed: " + err.Error())
common.SysLog("do request failed: " + err.Error())
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
}
statusCode := resp.StatusCode
@@ -233,7 +233,7 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
}
common.CloseResponseBodyGracefully(resp)
CloseResponseBodyGracefully(resp)
respStr := string(responseBody)
log.Printf("respStr: %s", respStr)
if respStr == "" {

View File

@@ -0,0 +1,79 @@
package service
import (
"errors"
"fmt"
"net/http"
"one-api/common"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/types"
"github.com/bytedance/gopkg/util/gopool"
"github.com/gin-gonic/gin"
)
func ReturnPreConsumedQuota(c *gin.Context, relayInfo *relaycommon.RelayInfo, preConsumedQuota int) {
if preConsumedQuota != 0 {
logger.LogInfo(c, fmt.Sprintf("用户 %d 请求失败, 返还预扣费额度 %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota)))
gopool.Go(func() {
relayInfoCopy := *relayInfo
err := PostConsumeQuota(&relayInfoCopy, -preConsumedQuota, 0, false)
if err != nil {
common.SysLog("error return pre-consumed quota: " + err.Error())
}
})
}
}
// PreConsumeQuota checks if the user has enough quota to pre-consume.
// It returns the pre-consumed quota if successful, or an error if not.
func PreConsumeQuota(c *gin.Context, preConsumedQuota int, relayInfo *relaycommon.RelayInfo) (int, *types.NewAPIError) {
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
if err != nil {
return 0, types.NewError(err, types.ErrorCodeQueryDataError, types.ErrOptionWithSkipRetry())
}
if userQuota <= 0 {
return 0, types.NewErrorWithStatusCode(errors.New("user quota is not enough"), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
if userQuota-preConsumedQuota < 0 {
return 0, types.NewErrorWithStatusCode(fmt.Errorf("预扣费额度失败, 用户剩余额度: %s, 需要预扣费额度: %s", logger.FormatQuota(userQuota), logger.FormatQuota(preConsumedQuota)), types.ErrorCodeInsufficientUserQuota, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
trustQuota := common.GetTrustQuota()
relayInfo.UserQuota = userQuota
if userQuota > trustQuota {
// 用户额度充足,判断令牌额度是否充足
if !relayInfo.TokenUnlimited {
// 非无限令牌,判断令牌额度是否充足
tokenQuota := c.GetInt("token_quota")
if tokenQuota > trustQuota {
// 令牌额度充足,信任令牌
preConsumedQuota = 0
logger.LogInfo(c, fmt.Sprintf("用户 %d 剩余额度 %s 且令牌 %d 额度 %d 充足, 信任且不需要预扣费", relayInfo.UserId, logger.FormatQuota(userQuota), relayInfo.TokenId, tokenQuota))
}
} else {
// in this case, we do not pre-consume quota
// because the user has enough quota
preConsumedQuota = 0
logger.LogInfo(c, fmt.Sprintf("用户 %d 额度充足且为无限额度令牌, 信任且不需要预扣费", relayInfo.UserId))
}
}
if preConsumedQuota > 0 {
err := PreConsumeTokenQuota(relayInfo, preConsumedQuota)
if err != nil {
return 0, types.NewErrorWithStatusCode(err, types.ErrorCodePreConsumeTokenQuotaFailed, http.StatusForbidden, types.ErrOptionWithSkipRetry(), types.ErrOptionWithNoRecordErrorLog())
}
err = model.DecreaseUserQuota(relayInfo.UserId, preConsumedQuota)
if err != nil {
return 0, types.NewError(err, types.ErrorCodeUpdateDataError, types.ErrOptionWithSkipRetry())
}
logger.LogInfo(c, fmt.Sprintf("用户 %d 预扣费 %s, 预扣费后剩余额度: %s", relayInfo.UserId, logger.FormatQuota(preConsumedQuota), logger.FormatQuota(userQuota-preConsumedQuota)))
}
relayInfo.FinalPreConsumedQuota = preConsumedQuota
return preConsumedQuota, nil
}

View File

@@ -8,11 +8,12 @@ import (
"one-api/common"
"one-api/constant"
"one-api/dto"
"one-api/logger"
"one-api/model"
relaycommon "one-api/relay/common"
"one-api/relay/helper"
"one-api/setting"
"one-api/setting/ratio_setting"
"one-api/types"
"strings"
"time"
@@ -137,23 +138,23 @@ func PreWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usag
quota := calculateAudioQuota(quotaInfo)
if userQuota < quota {
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", common.FormatQuota(userQuota), common.FormatQuota(quota))
return fmt.Errorf("user quota is not enough, user quota: %s, need quota: %s", logger.FormatQuota(userQuota), logger.FormatQuota(quota))
}
if !token.UnlimitedQuota && token.RemainQuota < quota {
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
}
err = PostConsumeQuota(relayInfo, quota, 0, false)
if err != nil {
return err
}
common.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
logger.LogInfo(ctx, "realtime streaming consume quota success, quota: "+fmt.Sprintf("%d", quota))
return nil
}
func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, modelName string,
usage *dto.RealtimeUsage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
usage *dto.RealtimeUsage, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.InputTokenDetails.TextTokens
@@ -167,10 +168,10 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(modelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice
usePrice := relayInfo.PriceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -204,8 +205,8 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
@@ -216,7 +217,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
logContent += ", " + extraContent
}
other := GenerateWssOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.InputTokens,
@@ -226,7 +227,6 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
@@ -234,8 +234,7 @@ func PostWssConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, mod
})
}
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
promptTokens := usage.PromptTokens
@@ -243,21 +242,21 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
modelName := relayInfo.OriginModelName
tokenName := ctx.GetString("token_name")
completionRatio := priceData.CompletionRatio
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
cacheRatio := priceData.CacheRatio
completionRatio := relayInfo.PriceData.CompletionRatio
modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice
cacheRatio := relayInfo.PriceData.CacheRatio
cacheTokens := usage.PromptTokensDetails.CachedTokens
cacheCreationRatio := priceData.CacheCreationRatio
cacheCreationRatio := relayInfo.PriceData.CacheCreationRatio
cacheCreationTokens := usage.PromptTokensDetails.CachedCreationTokens
if relayInfo.ChannelType == constant.ChannelTypeOpenRouter {
promptTokens -= cacheTokens
isUsingCustomSettings := priceData.UsePrice || hasCustomModelRatio(modelName, priceData.ModelRatio)
if cacheCreationTokens == 0 && priceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, priceData)
isUsingCustomSettings := relayInfo.PriceData.UsePrice || hasCustomModelRatio(modelName, relayInfo.PriceData.ModelRatio)
if cacheCreationTokens == 0 && relayInfo.PriceData.CacheCreationRatio != 1 && usage.Cost != 0 && !isUsingCustomSettings {
maybeCacheCreationTokens := CalcOpenRouterCacheCreateTokens(*usage, relayInfo.PriceData)
if maybeCacheCreationTokens >= 0 && promptTokens >= maybeCacheCreationTokens {
cacheCreationTokens = maybeCacheCreationTokens
}
@@ -266,7 +265,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
calculateQuota := 0.0
if !priceData.UsePrice {
if !relayInfo.PriceData.UsePrice {
calculateQuota = float64(promptTokens)
calculateQuota += float64(cacheTokens) * cacheRatio
calculateQuota += float64(cacheCreationTokens) * cacheCreationRatio
@@ -291,23 +290,38 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游出错)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, preConsumedQuota))
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, modelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
if quotaDelta > 0 {
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s实际消耗%s预扣费%s",
logger.FormatQuota(quotaDelta),
logger.FormatQuota(quota),
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
))
} else if quotaDelta < 0 {
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s实际消耗%s预扣费%s",
logger.FormatQuota(-quotaDelta),
logger.FormatQuota(quota),
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
))
}
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
other := GenerateClaudeOtherInfo(ctx, relayInfo, modelRatio, groupRatio, completionRatio,
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
cacheTokens, cacheRatio, cacheCreationTokens, cacheCreationRatio, modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: promptTokens,
@@ -317,7 +331,6 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
@@ -326,7 +339,7 @@ func PostClaudeConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
}
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData) int {
func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData types.PriceData) int {
if priceData.CacheCreationRatio == 1 {
return 0
}
@@ -347,8 +360,7 @@ func CalcOpenRouterCacheCreateTokens(usage dto.Usage, priceData helper.PriceData
(promptCacheCreatePrice - quotaPrice)))
}
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
usage *dto.Usage, preConsumedQuota int, userQuota int, priceData helper.PriceData, extraContent string) {
func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo, usage *dto.Usage, extraContent string) {
useTimeSeconds := time.Now().Unix() - relayInfo.StartTime.Unix()
textInputTokens := usage.PromptTokensDetails.TextTokens
@@ -362,10 +374,10 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
audioRatio := decimal.NewFromFloat(ratio_setting.GetAudioRatio(relayInfo.OriginModelName))
audioCompletionRatio := decimal.NewFromFloat(ratio_setting.GetAudioCompletionRatio(relayInfo.OriginModelName))
modelRatio := priceData.ModelRatio
groupRatio := priceData.GroupRatioInfo.GroupRatio
modelPrice := priceData.ModelPrice
usePrice := priceData.UsePrice
modelRatio := relayInfo.PriceData.ModelRatio
groupRatio := relayInfo.PriceData.GroupRatioInfo.GroupRatio
modelPrice := relayInfo.PriceData.ModelPrice
usePrice := relayInfo.PriceData.UsePrice
quotaInfo := QuotaInfo{
InputDetails: TokenDetails{
@@ -399,18 +411,33 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
logContent += fmt.Sprintf("(可能是上游超时)")
common.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, preConsumedQuota))
logger.LogError(ctx, fmt.Sprintf("total tokens is 0, cannot consume quota, userId %d, channelId %d, "+
"tokenId %d, model %s pre-consumed quota %d", relayInfo.UserId, relayInfo.ChannelId, relayInfo.TokenId, relayInfo.OriginModelName, relayInfo.FinalPreConsumedQuota))
} else {
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
}
quotaDelta := quota - preConsumedQuota
quotaDelta := quota - relayInfo.FinalPreConsumedQuota
if quotaDelta > 0 {
logger.LogInfo(ctx, fmt.Sprintf("预扣费后补扣费:%s实际消耗%s预扣费%s",
logger.FormatQuota(quotaDelta),
logger.FormatQuota(quota),
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
))
} else if quotaDelta < 0 {
logger.LogInfo(ctx, fmt.Sprintf("预扣费后返还扣费:%s实际消耗%s预扣费%s",
logger.FormatQuota(-quotaDelta),
logger.FormatQuota(quota),
logger.FormatQuota(relayInfo.FinalPreConsumedQuota),
))
}
if quotaDelta != 0 {
err := PostConsumeQuota(relayInfo, quotaDelta, preConsumedQuota, true)
err := PostConsumeQuota(relayInfo, quotaDelta, relayInfo.FinalPreConsumedQuota, true)
if err != nil {
common.LogError(ctx, "error consuming token remain quota: "+err.Error())
logger.LogError(ctx, "error consuming token remain quota: "+err.Error())
}
}
@@ -419,7 +446,7 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
logContent += ", " + extraContent
}
other := GenerateAudioOtherInfo(ctx, relayInfo, usage, modelRatio, groupRatio,
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, priceData.GroupRatioInfo.GroupSpecialRatio)
completionRatio.InexactFloat64(), audioRatio.InexactFloat64(), audioCompletionRatio.InexactFloat64(), modelPrice, relayInfo.PriceData.GroupRatioInfo.GroupSpecialRatio)
model.RecordConsumeLog(ctx, relayInfo.UserId, model.RecordConsumeLogParams{
ChannelId: relayInfo.ChannelId,
PromptTokens: usage.PromptTokens,
@@ -429,7 +456,6 @@ func PostAudioConsumeQuota(ctx *gin.Context, relayInfo *relaycommon.RelayInfo,
Quota: quota,
Content: logContent,
TokenId: relayInfo.TokenId,
UserQuota: userQuota,
UseTimeSeconds: int(useTimeSeconds),
IsStream: relayInfo.IsStream,
Group: relayInfo.UsingGroup,
@@ -452,7 +478,7 @@ func PreConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, quota int) error {
return err
}
if !relayInfo.TokenUnlimited && token.RemainQuota < quota {
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", common.FormatQuota(token.RemainQuota), common.FormatQuota(quota))
return fmt.Errorf("token quota is not enough, token remain quota: %s, need quota: %s", logger.FormatQuota(token.RemainQuota), logger.FormatQuota(quota))
}
err = model.DecreaseTokenQuota(relayInfo.TokenId, relayInfo.TokenKey, quota)
if err != nil {
@@ -510,7 +536,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon
prompt := "您的额度即将用尽"
topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
content := "{{value}},当前剩余额度为 {{value}},为了不影响您的使用,请及时充值。<br/>充值链接:<a href='{{value}}'>{{value}}</a>"
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, common.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
err := NotifyUser(relayInfo.UserId, relayInfo.UserEmail, relayInfo.UserSetting, dto.NewNotify(dto.NotifyTypeQuotaExceed, prompt, content, []interface{}{prompt, logger.FormatQuota(relayInfo.UserQuota), topUpLink, topUpLink}))
if err != nil {
common.SysError(fmt.Sprintf("failed to send quota notify to user %d: %s", relayInfo.UserId, err.Error()))
}

View File

@@ -2,7 +2,6 @@ package service
import (
"errors"
"fmt"
"one-api/dto"
"one-api/setting"
"strings"
@@ -32,25 +31,8 @@ func CheckSensitiveMessages(messages []dto.Message) ([]string, error) {
return nil, nil
}
func CheckSensitiveText(text string) ([]string, error) {
if ok, words := SensitiveWordContains(text); ok {
return words, errors.New("sensitive words detected")
}
return nil, nil
}
func CheckSensitiveInput(input any) ([]string, error) {
switch v := input.(type) {
case string:
return CheckSensitiveText(v)
case []string:
var builder strings.Builder
for _, s := range v {
builder.WriteString(s)
}
return CheckSensitiveText(builder.String())
}
return CheckSensitiveText(fmt.Sprintf("%v", input))
func CheckSensitiveText(text string) (bool, []string) {
return SensitiveWordContains(text)
}
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表

View File

@@ -4,8 +4,6 @@ import (
"encoding/json"
"errors"
"fmt"
"github.com/tiktoken-go/tokenizer"
"github.com/tiktoken-go/tokenizer/codec"
"image"
"log"
"math"
@@ -13,9 +11,14 @@ import (
"one-api/constant"
"one-api/dto"
relaycommon "one-api/relay/common"
"one-api/types"
"strings"
"sync"
"unicode/utf8"
"github.com/gin-gonic/gin"
"github.com/tiktoken-go/tokenizer"
"github.com/tiktoken-go/tokenizer/codec"
)
// tokenEncoderMap won't grow after initialization
@@ -72,52 +75,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
return tkm
}
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
if imageUrl == nil {
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
if fileMeta == nil {
return 0, fmt.Errorf("image_url_is_nil")
}
// Defaults for 4o/4.1/4.5 family unless overridden below
baseTokens := 85
if model == "glm-4v" {
tileTokens := 170
// Model classification
lowerModel := strings.ToLower(model)
// Special cases from existing behavior
if strings.HasPrefix(lowerModel, "glm-4") {
return 1047, nil
}
if imageUrl.Detail == "low" {
// Patch-based models (32x32 patches, capped at 1536, with multiplier)
isPatchBased := false
multiplier := 1.0
switch {
case strings.Contains(lowerModel, "gpt-4.1-mini"):
isPatchBased = true
multiplier = 1.62
case strings.Contains(lowerModel, "gpt-4.1-nano"):
isPatchBased = true
multiplier = 2.46
case strings.HasPrefix(lowerModel, "o4-mini"):
isPatchBased = true
multiplier = 1.72
case strings.HasPrefix(lowerModel, "gpt-5-mini"):
isPatchBased = true
multiplier = 1.62
case strings.HasPrefix(lowerModel, "gpt-5-nano"):
isPatchBased = true
multiplier = 2.46
}
// Tile-based model tokens and bases per doc
if !isPatchBased {
if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
baseTokens = 2833
tileTokens = 5667
} else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
baseTokens = 70
tileTokens = 140
} else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
baseTokens = 75
tileTokens = 150
} else if strings.Contains(lowerModel, "computer-use-preview") {
baseTokens = 65
tileTokens = 129
} else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
baseTokens = 85
tileTokens = 170
}
}
// Respect existing feature flags/short-circuits
if fileMeta.Detail == "low" && !isPatchBased {
return baseTokens, nil
}
if !constant.GetMediaTokenNotStream && !stream {
return 3 * baseTokens, nil
}
// 同步One API的图片计费逻辑
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
imageUrl.Detail = "high"
// Normalize detail
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
fileMeta.Detail = "high"
}
tileTokens := 170
if strings.HasPrefix(model, "gpt-4o-mini") {
tileTokens = 5667
baseTokens = 2833
}
// 是否统计图片token
// Whether to count image tokens at all
if !constant.GetMediaToken {
return 3 * baseTokens, nil
}
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
return 3 * baseTokens, nil
}
// Decode image to get dimensions
var config image.Config
var err error
var format string
var b64str string
if strings.HasPrefix(imageUrl.Url, "http") {
config, format, err = DecodeUrlImageData(imageUrl.Url)
if strings.HasPrefix(fileMeta.Data, "http") {
config, format, err = DecodeUrlImageData(fileMeta.Data)
} else {
common.SysLog(fmt.Sprintf("decoding image"))
config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
}
if err != nil {
return 0, err
}
imageUrl.MimeType = format
fileMeta.MimeType = format
if config.Width == 0 || config.Height == 0 {
// not an image
@@ -125,60 +171,155 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
// file type
return 3 * baseTokens, nil
}
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
}
shortSide := config.Width
otherSide := config.Height
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
// 缩放倍数
scale := 1.0
if config.Height < shortSide {
shortSide = config.Height
otherSide = config.Width
width := config.Width
height := config.Height
log.Printf("format: %s, width: %d, height: %d", format, width, height)
if isPatchBased {
// 32x32 patch-based calculation with 1536 cap and model multiplier
ceilDiv := func(a, b int) int { return (a + b - 1) / b }
rawPatchesW := ceilDiv(width, 32)
rawPatchesH := ceilDiv(height, 32)
rawPatches := rawPatchesW * rawPatchesH
if rawPatches > 1536 {
// scale down
area := float64(width * height)
r := math.Sqrt(float64(32*32*1536) / area)
wScaled := float64(width) * r
hScaled := float64(height) * r
// adjust to fit whole number of patches after scaling
adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
adj := math.Min(adjW, adjH)
if !math.IsNaN(adj) && adj > 0 {
r = r * adj
}
wScaled = float64(width) * r
hScaled = float64(height) * r
patchesW := math.Ceil(wScaled / 32.0)
patchesH := math.Ceil(hScaled / 32.0)
imageTokens := int(patchesW * patchesH)
if imageTokens > 1536 {
imageTokens = 1536
}
return int(math.Round(float64(imageTokens) * multiplier)), nil
}
// below cap
imageTokens := rawPatches
return int(math.Round(float64(imageTokens) * multiplier)), nil
}
// 将最小变的尺寸缩小到768以下如果大于768则缩放到768
if shortSide > 768 {
scale = float64(shortSide) / 768
shortSide = 768
// Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
// Step 1: fit within 2048x2048 square
maxSide := math.Max(float64(width), float64(height))
fitScale := 1.0
if maxSide > 2048 {
fitScale = maxSide / 2048.0
}
// 将另一边按照相同的比例缩小,向上取整
otherSide = int(math.Ceil(float64(otherSide) / scale))
log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
// 计算图片的token数量(边的长度除以512向上取整)
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
log.Printf("tiles: %d", tiles)
fitW := int(math.Round(float64(width) / fitScale))
fitH := int(math.Round(float64(height) / fitScale))
// Step 2: scale so that shortest side is exactly 768
minSide := math.Min(float64(fitW), float64(fitH))
if minSide == 0 {
return baseTokens, nil
}
shortScale := 768.0 / minSide
finalW := int(math.Round(float64(fitW) * shortScale))
finalH := int(math.Round(float64(fitH) * shortScale))
// Count 512px tiles
tilesW := (finalW + 512 - 1) / 512
tilesH := (finalH + 512 - 1) / 512
tiles := tilesW * tilesH
if common.DebugEnabled {
log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
}
return tiles*tileTokens + baseTokens, nil
}
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
tkm := 0
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
if err != nil {
return 0, err
}
tkm += msgTokens
if request.Tools != nil {
openaiTools := request.Tools
countStr := ""
for _, tool := range openaiTools {
countStr = tool.Function.Name
if tool.Function.Description != "" {
countStr += tool.Function.Description
}
if tool.Function.Parameters != nil {
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
}
}
toolTokens := CountTokenInput(countStr, request.Model)
tkm += 8
tkm += toolTokens
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
if meta == nil {
return 0, errors.New("token count meta is nil")
}
if info.RelayFormat == types.RelayFormatOpenAIRealtime {
return 0, nil
}
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
tkm := 0
if meta.TokenType == types.TokenTypeTextNumber {
tkm += utf8.RuneCountInString(meta.CombineText)
} else {
tkm += CountTextToken(meta.CombineText, model)
}
if info.RelayFormat == types.RelayFormatOpenAI {
tkm += meta.ToolsCount * 8
tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
tkm += meta.NameCount * 3
tkm += 3
}
for _, file := range meta.Files {
switch file.FileType {
case types.FileTypeImage:
if info.RelayFormat == types.RelayFormatGemini {
tkm += 240
} else {
token, err := getImageToken(file, model, info.IsStream)
if err != nil {
return 0, fmt.Errorf("error counting image token: %v", err)
}
tkm += token
}
case types.FileTypeAudio:
tkm += 100
case types.FileTypeVideo:
tkm += 5000
case types.FileTypeFile:
tkm += 5000
}
}
common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
return tkm, nil
}
//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
// tkm := 0
// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
// if err != nil {
// return 0, err
// }
// tkm += msgTokens
// if request.Tools != nil {
// openaiTools := request.Tools
// countStr := ""
// for _, tool := range openaiTools {
// countStr = tool.Function.Name
// if tool.Function.Description != "" {
// countStr += tool.Function.Description
// }
// if tool.Function.Parameters != nil {
// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
// }
// }
// toolTokens := CountTokenInput(countStr, request.Model)
// tkm += 8
// tkm += toolTokens
// }
//
// return tkm, nil
//}
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
tkm := 0
@@ -338,58 +479,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
return textToken, audioToken, nil
}
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
//recover when panic
tokenEncoder := getTokenEncoder(model)
// Reference:
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// https://github.com/pkoukk/tiktoken-go/issues/6
//
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
var tokensPerMessage int
var tokensPerName int
if model == "gpt-3.5-turbo-0301" {
tokensPerMessage = 4
tokensPerName = -1 // If there's a name, the role is omitted
} else {
tokensPerMessage = 3
tokensPerName = 1
}
tokenNum := 0
for _, message := range messages {
tokenNum += tokensPerMessage
tokenNum += getTokenNum(tokenEncoder, message.Role)
if message.Content != nil {
if message.Name != nil {
tokenNum += tokensPerName
tokenNum += getTokenNum(tokenEncoder, *message.Name)
}
arrayContent := message.ParseContent()
for _, m := range arrayContent {
if m.Type == dto.ContentTypeImageURL {
imageUrl := m.GetImageMedia()
imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
if err != nil {
return 0, err
}
tokenNum += imageTokenNum
log.Printf("image token num: %d", imageTokenNum)
} else if m.Type == dto.ContentTypeInputAudio {
// TODO: 音频token数量计算
tokenNum += 100
} else if m.Type == dto.ContentTypeFile {
tokenNum += 5000
} else if m.Type == dto.ContentTypeVideoUrl {
tokenNum += 5000
} else {
tokenNum += getTokenNum(tokenEncoder, m.Text)
}
}
}
}
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
return tokenNum, nil
}
//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
// //recover when panic
// tokenEncoder := getTokenEncoder(model)
// // Reference:
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
// // https://github.com/pkoukk/tiktoken-go/issues/6
// //
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
// var tokensPerMessage int
// var tokensPerName int
//
// tokensPerMessage = 3
// tokensPerName = 1
//
// tokenNum := 0
// for _, message := range messages {
// tokenNum += tokensPerMessage
// tokenNum += getTokenNum(tokenEncoder, message.Role)
// if message.Content != nil {
// if message.Name != nil {
// tokenNum += tokensPerName
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
// }
// arrayContent := message.ParseContent()
// for _, m := range arrayContent {
// if m.Type == dto.ContentTypeImageURL {
// imageUrl := m.GetImageMedia()
// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
// if err != nil {
// return 0, err
// }
// tokenNum += imageTokenNum
// log.Printf("image token num: %d", imageTokenNum)
// } else if m.Type == dto.ContentTypeInputAudio {
// // TODO: 音频token数量计算
// tokenNum += 100
// } else if m.Type == dto.ContentTypeFile {
// tokenNum += 5000
// } else if m.Type == dto.ContentTypeVideoUrl {
// tokenNum += 5000
// } else {
// tokenNum += getTokenNum(tokenEncoder, m.Text)
// }
// }
// }
// }
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
// return tokenNum, nil
//}
func CountTokenInput(input any, model string) int {
switch v := input.(type) {

View File

@@ -12,7 +12,7 @@ func NotifyRootUser(t string, subject string, content string) {
user := model.GetRootUser().ToBaseUser()
err := NotifyUser(user.Id, user.Email, user.GetSetting(), dto.NewNotify(t, subject, content, nil))
if err != nil {
common.SysError(fmt.Sprintf("failed to notify root user: %s", err.Error()))
common.SysLog(fmt.Sprintf("failed to notify root user: %s", err.Error()))
}
}
@@ -25,7 +25,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
// Check notification limit
canSend, err := CheckNotificationLimit(userId, data.Type)
if err != nil {
common.SysError(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
common.SysLog(fmt.Sprintf("failed to check notification limit: %s", err.Error()))
return err
}
if !canSend {
@@ -44,7 +44,7 @@ func NotifyUser(userId int, userEmail string, userSetting dto.UserSetting, data
case dto.NotifyTypeWebhook:
webhookURLStr := userSetting.WebhookUrl
if webhookURLStr == "" {
common.SysError(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
common.SysLog(fmt.Sprintf("user %d has no webhook url, skip sending webhook", userId))
return nil
}