diff --git a/controller/channel-test.go b/controller/channel-test.go index 970c1768..f9c7bf7b 100644 --- a/controller/channel-test.go +++ b/controller/channel-test.go @@ -271,6 +271,13 @@ func testAllChannels(notify bool) error { disableThreshold = 10000000 // a impossible value } gopool.Go(func() { + // 使用 defer 确保无论如何都会重置运行状态,防止死锁 + defer func() { + testAllChannelsLock.Lock() + testAllChannelsRunning = false + testAllChannelsLock.Unlock() + }() + for _, channel := range channels { isChannelEnabled := channel.Status == common.ChannelStatusEnabled tik := time.Now() @@ -305,9 +312,7 @@ func testAllChannels(notify bool) error { channel.UpdateResponseTime(milliseconds) time.Sleep(common.RequestInterval) } - testAllChannelsLock.Lock() - testAllChannelsRunning = false - testAllChannelsLock.Unlock() + if notify { service.NotifyRootUser(dto.NotifyTypeChannelTest, "通道测试完成", "所有通道测试已完成") } diff --git a/controller/misc.go b/controller/misc.go index 69398b11..8fa8e8f6 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -6,6 +6,7 @@ import ( "net/http" "one-api/common" "one-api/constant" + "one-api/middleware" "one-api/model" "one-api/setting" "one-api/setting/operation_setting" @@ -24,14 +25,18 @@ func TestStatus(c *gin.Context) { }) return } + // 获取HTTP统计信息 + httpStats := middleware.GetStats() c.JSON(http.StatusOK, gin.H{ - "success": true, - "message": "Server is running", + "success": true, + "message": "Server is running", + "http_stats": httpStats, }) return } func GetStatus(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ "success": true, "message": "", diff --git a/go.mod b/go.mod index ce768bf3..9479ba55 100644 --- a/go.mod +++ b/go.mod @@ -11,7 +11,6 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.17.11 github.com/aws/aws-sdk-go-v2/service/bedrockruntime v1.7.4 github.com/bytedance/gopkg v0.0.0-20220118071334-3db87571198b - github.com/bytedance/sonic v1.11.6 github.com/gin-contrib/cors v1.7.2 github.com/gin-contrib/gzip v0.0.6 github.com/gin-contrib/sessions v0.0.5 @@ -25,10 +24,10 @@ require ( github.com/gorilla/websocket v1.5.0 github.com/joho/godotenv v1.5.1 github.com/pkg/errors v0.9.1 - github.com/pkoukk/tiktoken-go v0.1.7 github.com/samber/lo v1.39.0 github.com/shirou/gopsutil v3.21.11+incompatible github.com/shopspring/decimal v1.4.0 + github.com/tiktoken-go/tokenizer v0.6.2 golang.org/x/crypto v0.35.0 golang.org/x/image v0.23.0 golang.org/x/net v0.35.0 @@ -43,12 +42,13 @@ require ( github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.5 // indirect github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.5 // indirect github.com/aws/smithy-go v1.20.2 // indirect + github.com/bytedance/sonic v1.11.6 // indirect github.com/bytedance/sonic/loader v0.1.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dlclark/regexp2 v1.11.0 // indirect + github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect github.com/gin-contrib/sse v0.1.0 // indirect diff --git a/go.sum b/go.sum index 2bd81fa3..71dd83c2 100644 --- a/go.sum +++ b/go.sum @@ -38,8 +38,8 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= -github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= +github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= +github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= @@ -167,8 +167,6 @@ github.com/pelletier/go-toml/v2 v2.2.1/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= -github.com/pkoukk/tiktoken-go v0.1.7 h1:qOBHXX4PHtvIvmOtyg1EeKlwFRiMKAcoMp4Q+bLQDmw= -github.com/pkoukk/tiktoken-go v0.1.7/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= @@ -197,6 +195,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tiktoken-go/tokenizer v0.6.2 h1:t0GN2DvcUZSFWT/62YOgoqb10y7gSXBGs0A+4VCQK+g= +github.com/tiktoken-go/tokenizer v0.6.2/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w= github.com/tklauser/go-sysconf v0.3.12 h1:0QaGUFOdQaIVdPgfITYzaTegZvdCjmYO52cSFAEVmqU= github.com/tklauser/go-sysconf v0.3.12/go.mod h1:Ho14jnntGE1fpdOqQEEaiKRpvIavV0hSfmBq8nJbHYI= github.com/tklauser/numcpus v0.6.1 h1:ng9scYS7az0Bk4OZLvrNXNSAO2Pxr1XXRAPyjhIx+Fk= diff --git a/middleware/stats.go b/middleware/stats.go new file mode 100644 index 00000000..1c97983f --- /dev/null +++ b/middleware/stats.go @@ -0,0 +1,41 @@ +package middleware + +import ( + "sync/atomic" + + "github.com/gin-gonic/gin" +) + +// HTTPStats 存储HTTP统计信息 +type HTTPStats struct { + activeConnections int64 +} + +var globalStats = &HTTPStats{} + +// StatsMiddleware 统计中间件 +func StatsMiddleware() gin.HandlerFunc { + return func(c *gin.Context) { + // 增加活跃连接数 + atomic.AddInt64(&globalStats.activeConnections, 1) + + // 确保在请求结束时减少连接数 + defer func() { + atomic.AddInt64(&globalStats.activeConnections, -1) + }() + + c.Next() + } +} + +// StatsInfo 统计信息结构 +type StatsInfo struct { + ActiveConnections int64 `json:"active_connections"` +} + +// GetStats 获取统计信息 +func GetStats() StatsInfo { + return StatsInfo{ + ActiveConnections: atomic.LoadInt64(&globalStats.activeConnections), + } +} \ No newline at end of file diff --git a/router/relay-router.go b/router/relay-router.go index 1115a491..aa7f27a8 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -11,6 +11,7 @@ import ( func SetRelayRouter(router *gin.Engine) { router.Use(middleware.CORS()) router.Use(middleware.DecompressRequestMiddleware()) + router.Use(middleware.StatsMiddleware()) // https://platform.openai.com/docs/api-reference/introduction modelsRouter := router.Group("/v1/models") modelsRouter.Use(middleware.TokenAuth()) diff --git a/service/token_counter.go b/service/token_counter.go index e1722013..d27bb5ea 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -4,6 +4,8 @@ import ( "encoding/json" "errors" "fmt" + "github.com/tiktoken-go/tokenizer" + "github.com/tiktoken-go/tokenizer/codec" "image" "log" "math" @@ -11,78 +13,63 @@ import ( "one-api/constant" "one-api/dto" relaycommon "one-api/relay/common" - "one-api/setting/operation_setting" "strings" + "sync" "unicode/utf8" - - "github.com/pkoukk/tiktoken-go" ) // tokenEncoderMap won't grow after initialization -var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} -var defaultTokenEncoder *tiktoken.Tiktoken -var o200kTokenEncoder *tiktoken.Tiktoken +var defaultTokenEncoder tokenizer.Codec + +// tokenEncoderMap is used to store token encoders for different models +var tokenEncoderMap = make(map[string]tokenizer.Codec) + +// tokenEncoderMutex protects tokenEncoderMap for concurrent access +var tokenEncoderMutex sync.RWMutex func InitTokenEncoders() { common.SysLog("initializing token encoders") - cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE) - if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) - } - defaultTokenEncoder = cl100TokenEncoder - o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE) - if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) - } - for model, _ := range operation_setting.GetDefaultModelRatioMap() { - if strings.HasPrefix(model, "gpt-3.5") { - tokenEncoderMap[model] = cl100TokenEncoder - } else if strings.HasPrefix(model, "gpt-4") { - if strings.HasPrefix(model, "gpt-4o") { - tokenEncoderMap[model] = o200kTokenEncoder - } else { - tokenEncoderMap[model] = defaultTokenEncoder - } - } else if strings.HasPrefix(model, "o") { - tokenEncoderMap[model] = o200kTokenEncoder - } else { - tokenEncoderMap[model] = defaultTokenEncoder - } - } + defaultTokenEncoder = codec.NewCl100kBase() common.SysLog("token encoders initialized") } -func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken { - if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") { - return o200kTokenEncoder +func getTokenEncoder(model string) tokenizer.Codec { + // First, try to get the encoder from cache with read lock + tokenEncoderMutex.RLock() + if encoder, exists := tokenEncoderMap[model]; exists { + tokenEncoderMutex.RUnlock() + return encoder } - return defaultTokenEncoder + tokenEncoderMutex.RUnlock() + + // If not in cache, create new encoder with write lock + tokenEncoderMutex.Lock() + defer tokenEncoderMutex.Unlock() + + // Double-check if another goroutine already created the encoder + if encoder, exists := tokenEncoderMap[model]; exists { + return encoder + } + + // Create new encoder + modelCodec, err := tokenizer.ForModel(tokenizer.Model(model)) + if err != nil { + // Cache the default encoder for this model to avoid repeated failures + tokenEncoderMap[model] = defaultTokenEncoder + return defaultTokenEncoder + } + + // Cache the new encoder + tokenEncoderMap[model] = modelCodec + return modelCodec } -func getTokenEncoder(model string) *tiktoken.Tiktoken { - tokenEncoder, ok := tokenEncoderMap[model] - if ok && tokenEncoder != nil { - return tokenEncoder - } - // 如果ok(即model在tokenEncoderMap中),但是tokenEncoder为nil,说明可能是自定义模型 - if ok { - tokenEncoder, err := tiktoken.EncodingForModel(model) - if err != nil { - common.SysError(fmt.Sprintf("failed to get token encoder for model %s: %s, using encoder for gpt-3.5-turbo", model, err.Error())) - tokenEncoder = getModelDefaultTokenEncoder(model) - } - tokenEncoderMap[model] = tokenEncoder - return tokenEncoder - } - // 如果model不在tokenEncoderMap中,直接返回默认的tokenEncoder - return getModelDefaultTokenEncoder(model) -} - -func getTokenNum(tokenEncoder *tiktoken.Tiktoken, text string) int { +func getTokenNum(tokenEncoder tokenizer.Codec, text string) int { if text == "" { return 0 } - return len(tokenEncoder.Encode(text, nil, nil)) + ids, _, _ := tokenEncoder.Encode(text) + return len(ids) } func getImageToken(info *relaycommon.RelayInfo, imageUrl *dto.MessageImageUrl, model string, stream bool) (int, error) {