refactor: Introduce pre-consume quota and unify relay handlers
This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic.

Key changes:

- **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests.
- **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels.
- **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package.
- **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
@@ -4,18 +4,22 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
"image"
|
||||
"log"
|
||||
"math"
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
"sync"
|
||||
"unicode/utf8"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tiktoken-go/tokenizer"
|
||||
"github.com/tiktoken-go/tokenizer/codec"
|
||||
)
|
||||
|
||||
// tokenEncoderMap won't grow after initialization
|
||||
@@ -28,9 +32,9 @@ var tokenEncoderMap = make(map[string]tokenizer.Codec)
|
||||
var tokenEncoderMutex sync.RWMutex
|
||||
|
||||
func InitTokenEncoders() {
|
||||
common.SysLog("initializing token encoders")
|
||||
logger.SysLog("initializing token encoders")
|
||||
defaultTokenEncoder = codec.NewCl100kBase()
|
||||
common.SysLog("token encoders initialized")
|
||||
logger.SysLog("token encoders initialized")
|
||||
}
|
||||
|
||||
func getTokenEncoder(model string) tokenizer.Codec {
|
||||
@@ -72,52 +76,95 @@ func getTokenNum(tokenEncoder tokenizer.Codec, text string) int {
|
||||
return tkm
|
||||
}
|
||||
|
||||
func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {
|
||||
if imageUrl == nil {
|
||||
func getImageToken(fileMeta *types.FileMeta, model string, stream bool) (int, error) {
|
||||
if fileMeta == nil {
|
||||
return 0, fmt.Errorf("image_url_is_nil")
|
||||
}
|
||||
|
||||
// Defaults for 4o/4.1/4.5 family unless overridden below
|
||||
baseTokens := 85
|
||||
if model == "glm-4v" {
|
||||
tileTokens := 170
|
||||
|
||||
// Model classification
|
||||
lowerModel := strings.ToLower(model)
|
||||
|
||||
// Special cases from existing behavior
|
||||
if strings.HasPrefix(lowerModel, "glm-4") {
|
||||
return 1047, nil
|
||||
}
|
||||
if imageUrl.Detail == "low" {
|
||||
|
||||
// Patch-based models (32x32 patches, capped at 1536, with multiplier)
|
||||
isPatchBased := false
|
||||
multiplier := 1.0
|
||||
switch {
|
||||
case strings.Contains(lowerModel, "gpt-4.1-mini"):
|
||||
isPatchBased = true
|
||||
multiplier = 1.62
|
||||
case strings.Contains(lowerModel, "gpt-4.1-nano"):
|
||||
isPatchBased = true
|
||||
multiplier = 2.46
|
||||
case strings.HasPrefix(lowerModel, "o4-mini"):
|
||||
isPatchBased = true
|
||||
multiplier = 1.72
|
||||
case strings.HasPrefix(lowerModel, "gpt-5-mini"):
|
||||
isPatchBased = true
|
||||
multiplier = 1.62
|
||||
case strings.HasPrefix(lowerModel, "gpt-5-nano"):
|
||||
isPatchBased = true
|
||||
multiplier = 2.46
|
||||
}
|
||||
|
||||
// Tile-based model tokens and bases per doc
|
||||
if !isPatchBased {
|
||||
if strings.HasPrefix(lowerModel, "gpt-4o-mini") {
|
||||
baseTokens = 2833
|
||||
tileTokens = 5667
|
||||
} else if strings.HasPrefix(lowerModel, "gpt-5-chat-latest") || (strings.HasPrefix(lowerModel, "gpt-5") && !strings.Contains(lowerModel, "mini") && !strings.Contains(lowerModel, "nano")) {
|
||||
baseTokens = 70
|
||||
tileTokens = 140
|
||||
} else if strings.HasPrefix(lowerModel, "o1") || strings.HasPrefix(lowerModel, "o3") || strings.HasPrefix(lowerModel, "o1-pro") {
|
||||
baseTokens = 75
|
||||
tileTokens = 150
|
||||
} else if strings.Contains(lowerModel, "computer-use-preview") {
|
||||
baseTokens = 65
|
||||
tileTokens = 129
|
||||
} else if strings.Contains(lowerModel, "4.1") || strings.Contains(lowerModel, "4o") || strings.Contains(lowerModel, "4.5") {
|
||||
baseTokens = 85
|
||||
tileTokens = 170
|
||||
}
|
||||
}
|
||||
|
||||
// Respect existing feature flags/short-circuits
|
||||
if fileMeta.Detail == "low" && !isPatchBased {
|
||||
return baseTokens, nil
|
||||
}
|
||||
if !constant.GetMediaTokenNotStream && !stream {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// 同步One API的图片计费逻辑
|
||||
if imageUrl.Detail == "auto" || imageUrl.Detail == "" {
|
||||
imageUrl.Detail = "high"
|
||||
// Normalize detail
|
||||
if fileMeta.Detail == "auto" || fileMeta.Detail == "" {
|
||||
fileMeta.Detail = "high"
|
||||
}
|
||||
|
||||
tileTokens := 170
|
||||
if strings.HasPrefix(model, "gpt-4o-mini") {
|
||||
tileTokens = 5667
|
||||
baseTokens = 2833
|
||||
}
|
||||
// 是否统计图片token
|
||||
// Whether to count image tokens at all
|
||||
if !constant.GetMediaToken {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
if info.ChannelType == constant.ChannelTypeGemini || info.ChannelType == constant.ChannelTypeVertexAi || info.ChannelType == constant.ChannelTypeAnthropic {
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
|
||||
// Decode image to get dimensions
|
||||
var config image.Config
|
||||
var err error
|
||||
var format string
|
||||
var b64str string
|
||||
if strings.HasPrefix(imageUrl.Url, "http") {
|
||||
config, format, err = DecodeUrlImageData(imageUrl.Url)
|
||||
if strings.HasPrefix(fileMeta.Data, "http") {
|
||||
config, format, err = DecodeUrlImageData(fileMeta.Data)
|
||||
} else {
|
||||
common.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(imageUrl.Url)
|
||||
logger.SysLog(fmt.Sprintf("decoding image"))
|
||||
config, format, b64str, err = DecodeBase64ImageData(fileMeta.Data)
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
imageUrl.MimeType = format
|
||||
fileMeta.MimeType = format
|
||||
|
||||
if config.Width == 0 || config.Height == 0 {
|
||||
// not an image
|
||||
@@ -125,60 +172,144 @@ func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, m
|
||||
// file type
|
||||
return 3 * baseTokens, nil
|
||||
}
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", imageUrl.Url))
|
||||
return 0, errors.New(fmt.Sprintf("fail to decode base64 config: %s", fileMeta.Data))
|
||||
}
|
||||
|
||||
shortSide := config.Width
|
||||
otherSide := config.Height
|
||||
log.Printf("format: %s, width: %d, height: %d", format, config.Width, config.Height)
|
||||
// 缩放倍数
|
||||
scale := 1.0
|
||||
if config.Height < shortSide {
|
||||
shortSide = config.Height
|
||||
otherSide = config.Width
|
||||
width := config.Width
|
||||
height := config.Height
|
||||
log.Printf("format: %s, width: %d, height: %d", format, width, height)
|
||||
|
||||
if isPatchBased {
|
||||
// 32x32 patch-based calculation with 1536 cap and model multiplier
|
||||
ceilDiv := func(a, b int) int { return (a + b - 1) / b }
|
||||
rawPatchesW := ceilDiv(width, 32)
|
||||
rawPatchesH := ceilDiv(height, 32)
|
||||
rawPatches := rawPatchesW * rawPatchesH
|
||||
if rawPatches > 1536 {
|
||||
// scale down
|
||||
area := float64(width * height)
|
||||
r := math.Sqrt(float64(32*32*1536) / area)
|
||||
wScaled := float64(width) * r
|
||||
hScaled := float64(height) * r
|
||||
// adjust to fit whole number of patches after scaling
|
||||
adjW := math.Floor(wScaled/32.0) / (wScaled / 32.0)
|
||||
adjH := math.Floor(hScaled/32.0) / (hScaled / 32.0)
|
||||
adj := math.Min(adjW, adjH)
|
||||
if !math.IsNaN(adj) && adj > 0 {
|
||||
r = r * adj
|
||||
}
|
||||
wScaled = float64(width) * r
|
||||
hScaled = float64(height) * r
|
||||
patchesW := math.Ceil(wScaled / 32.0)
|
||||
patchesH := math.Ceil(hScaled / 32.0)
|
||||
imageTokens := int(patchesW * patchesH)
|
||||
if imageTokens > 1536 {
|
||||
imageTokens = 1536
|
||||
}
|
||||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||||
}
|
||||
// below cap
|
||||
imageTokens := rawPatches
|
||||
return int(math.Round(float64(imageTokens) * multiplier)), nil
|
||||
}
|
||||
|
||||
// 将最小变的尺寸缩小到768以下,如果大于768,则缩放到768
|
||||
if shortSide > 768 {
|
||||
scale = float64(shortSide) / 768
|
||||
shortSide = 768
|
||||
// Tile-based calculation for 4o/4.1/4.5/o1/o3/etc.
|
||||
// Step 1: fit within 2048x2048 square
|
||||
maxSide := math.Max(float64(width), float64(height))
|
||||
fitScale := 1.0
|
||||
if maxSide > 2048 {
|
||||
fitScale = maxSide / 2048.0
|
||||
}
|
||||
// 将另一边按照相同的比例缩小,向上取整
|
||||
otherSide = int(math.Ceil(float64(otherSide) / scale))
|
||||
log.Printf("shortSide: %d, otherSide: %d, scale: %f", shortSide, otherSide, scale)
|
||||
// 计算图片的token数量(边的长度除以512,向上取整)
|
||||
tiles := (shortSide + 511) / 512 * ((otherSide + 511) / 512)
|
||||
log.Printf("tiles: %d", tiles)
|
||||
fitW := int(math.Round(float64(width) / fitScale))
|
||||
fitH := int(math.Round(float64(height) / fitScale))
|
||||
|
||||
// Step 2: scale so that shortest side is exactly 768
|
||||
minSide := math.Min(float64(fitW), float64(fitH))
|
||||
if minSide == 0 {
|
||||
return baseTokens, nil
|
||||
}
|
||||
shortScale := 768.0 / minSide
|
||||
finalW := int(math.Round(float64(fitW) * shortScale))
|
||||
finalH := int(math.Round(float64(fitH) * shortScale))
|
||||
|
||||
// Count 512px tiles
|
||||
tilesW := (finalW + 512 - 1) / 512
|
||||
tilesH := (finalH + 512 - 1) / 512
|
||||
tiles := tilesW * tilesH
|
||||
|
||||
if common.DebugEnabled {
|
||||
log.Printf("scaled to: %dx%d, tiles: %d", finalW, finalH, tiles)
|
||||
}
|
||||
|
||||
return tiles*tileTokens + baseTokens, nil
|
||||
}
|
||||
|
||||
func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
tkm := 0
|
||||
msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
func CountRequestToken(c *gin.Context, meta *types.TokenCountMeta, info *relaycommon.RelayInfo) (int, error) {
|
||||
if meta == nil {
|
||||
return 0, errors.New("token count meta is nil")
|
||||
}
|
||||
tkm += msgTokens
|
||||
if request.Tools != nil {
|
||||
openaiTools := request.Tools
|
||||
countStr := ""
|
||||
for _, tool := range openaiTools {
|
||||
countStr = tool.Function.Name
|
||||
if tool.Function.Description != "" {
|
||||
countStr += tool.Function.Description
|
||||
}
|
||||
if tool.Function.Parameters != nil {
|
||||
countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
}
|
||||
}
|
||||
toolTokens := CountTokenInput(countStr, request.Model)
|
||||
tkm += 8
|
||||
tkm += toolTokens
|
||||
model := common.GetContextKeyString(c, constant.ContextKeyOriginalModel)
|
||||
tkm := CountTextToken(meta.CombineText, model)
|
||||
|
||||
if info.RelayFormat == types.RelayFormatOpenAI {
|
||||
tkm += meta.ToolsCount * 8
|
||||
tkm += meta.MessagesCount * 3 // 每条消息的格式化token数量
|
||||
tkm += meta.NameCount * 3
|
||||
tkm += 3
|
||||
}
|
||||
|
||||
for _, file := range meta.Files {
|
||||
switch file.FileType {
|
||||
case types.FileTypeImage:
|
||||
if info.RelayFormat == types.RelayFormatGemini {
|
||||
tkm += 240
|
||||
} else {
|
||||
token, err := getImageToken(file, model, info.IsStream)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error counting image token: %v", err)
|
||||
}
|
||||
tkm += token
|
||||
}
|
||||
case types.FileTypeAudio:
|
||||
tkm += 100
|
||||
case types.FileTypeVideo:
|
||||
tkm += 5000
|
||||
case types.FileTypeFile:
|
||||
tkm += 5000
|
||||
}
|
||||
}
|
||||
|
||||
common.SetContextKey(c, constant.ContextKeyPromptTokens, tkm)
|
||||
return tkm, nil
|
||||
}
|
||||
|
||||
//func CountTokenChatRequest(info *relaycommon.RelayInfo, request dto.GeneralOpenAIRequest) (int, error) {
|
||||
// tkm := 0
|
||||
// msgTokens, err := CountTokenMessages(info, request.Messages, request.Model, request.Stream)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// tkm += msgTokens
|
||||
// if request.Tools != nil {
|
||||
// openaiTools := request.Tools
|
||||
// countStr := ""
|
||||
// for _, tool := range openaiTools {
|
||||
// countStr = tool.Function.Name
|
||||
// if tool.Function.Description != "" {
|
||||
// countStr += tool.Function.Description
|
||||
// }
|
||||
// if tool.Function.Parameters != nil {
|
||||
// countStr += fmt.Sprintf("%v", tool.Function.Parameters)
|
||||
// }
|
||||
// }
|
||||
// toolTokens := CountTokenInput(countStr, request.Model)
|
||||
// tkm += 8
|
||||
// tkm += toolTokens
|
||||
// }
|
||||
//
|
||||
// return tkm, nil
|
||||
//}
|
||||
|
||||
func CountTokenClaudeRequest(request dto.ClaudeRequest, model string) (int, error) {
|
||||
tkm := 0
|
||||
|
||||
@@ -338,58 +469,55 @@ func CountTokenRealtime(info *relaycommon.RelayInfo, request dto.RealtimeEvent,
|
||||
return textToken, audioToken, nil
|
||||
}
|
||||
|
||||
func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
//recover when panic
|
||||
tokenEncoder := getTokenEncoder(model)
|
||||
// Reference:
|
||||
// https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
//
|
||||
// Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
var tokensPerMessage int
|
||||
var tokensPerName int
|
||||
if model == "gpt-3.5-turbo-0301" {
|
||||
tokensPerMessage = 4
|
||||
tokensPerName = -1 // If there's a name, the role is omitted
|
||||
} else {
|
||||
tokensPerMessage = 3
|
||||
tokensPerName = 1
|
||||
}
|
||||
tokenNum := 0
|
||||
for _, message := range messages {
|
||||
tokenNum += tokensPerMessage
|
||||
tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
if message.Content != nil {
|
||||
if message.Name != nil {
|
||||
tokenNum += tokensPerName
|
||||
tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
}
|
||||
arrayContent := message.ParseContent()
|
||||
for _, m := range arrayContent {
|
||||
if m.Type == dto.ContentTypeImageURL {
|
||||
imageUrl := m.GetImageMedia()
|
||||
imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
tokenNum += imageTokenNum
|
||||
log.Printf("image token num: %d", imageTokenNum)
|
||||
} else if m.Type == dto.ContentTypeInputAudio {
|
||||
// TODO: 音频token数量计算
|
||||
tokenNum += 100
|
||||
} else if m.Type == dto.ContentTypeFile {
|
||||
tokenNum += 5000
|
||||
} else if m.Type == dto.ContentTypeVideoUrl {
|
||||
tokenNum += 5000
|
||||
} else {
|
||||
tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
return tokenNum, nil
|
||||
}
|
||||
//func CountTokenMessages(info *relaycommon.RelayInfo, messages []dto.Message, model string, stream bool) (int, error) {
|
||||
// //recover when panic
|
||||
// tokenEncoder := getTokenEncoder(model)
|
||||
// // Reference:
|
||||
// // https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
|
||||
// // https://github.com/pkoukk/tiktoken-go/issues/6
|
||||
// //
|
||||
// // Every message follows <|start|>{role/name}\n{content}<|end|>\n
|
||||
// var tokensPerMessage int
|
||||
// var tokensPerName int
|
||||
//
|
||||
// tokensPerMessage = 3
|
||||
// tokensPerName = 1
|
||||
//
|
||||
// tokenNum := 0
|
||||
// for _, message := range messages {
|
||||
// tokenNum += tokensPerMessage
|
||||
// tokenNum += getTokenNum(tokenEncoder, message.Role)
|
||||
// if message.Content != nil {
|
||||
// if message.Name != nil {
|
||||
// tokenNum += tokensPerName
|
||||
// tokenNum += getTokenNum(tokenEncoder, *message.Name)
|
||||
// }
|
||||
// arrayContent := message.ParseContent()
|
||||
// for _, m := range arrayContent {
|
||||
// if m.Type == dto.ContentTypeImageURL {
|
||||
// imageUrl := m.GetImageMedia()
|
||||
// imageTokenNum, err := getImageToken(info, imageUrl, model, stream)
|
||||
// if err != nil {
|
||||
// return 0, err
|
||||
// }
|
||||
// tokenNum += imageTokenNum
|
||||
// log.Printf("image token num: %d", imageTokenNum)
|
||||
// } else if m.Type == dto.ContentTypeInputAudio {
|
||||
// // TODO: 音频token数量计算
|
||||
// tokenNum += 100
|
||||
// } else if m.Type == dto.ContentTypeFile {
|
||||
// tokenNum += 5000
|
||||
// } else if m.Type == dto.ContentTypeVideoUrl {
|
||||
// tokenNum += 5000
|
||||
// } else {
|
||||
// tokenNum += getTokenNum(tokenEncoder, m.Text)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// tokenNum += 3 // Every reply is primed with <|start|>assistant<|message|>
|
||||
// return tokenNum, nil
|
||||
//}
|
||||
|
||||
func CountTokenInput(input any, model string) int {
|
||||
switch v := input.(type) {
|
||||
|
||||
Reference in New Issue
Block a user