feat: update o1 default token encoder

This commit is contained in:
CalciumIon
2024-12-27 15:03:10 +08:00
parent 62ae46b552
commit d2297d2723

View File

@@ -19,42 +19,40 @@ import (
// tokenEncoderMap won't grow after initialization // tokenEncoderMap won't grow after initialization
var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var tokenEncoderMap = map[string]*tiktoken.Tiktoken{}
var defaultTokenEncoder *tiktoken.Tiktoken var defaultTokenEncoder *tiktoken.Tiktoken
var cl200kTokenEncoder *tiktoken.Tiktoken var o200kTokenEncoder *tiktoken.Tiktoken
func InitTokenEncoders() { func InitTokenEncoders() {
common.SysLog("initializing token encoders") common.SysLog("initializing token encoders")
gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE)
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error()))
} }
defaultTokenEncoder = gpt35TokenEncoder defaultTokenEncoder = cl100TokenEncoder
gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE)
if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error()))
}
cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o")
if err != nil { if err != nil {
common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error()))
} }
for model, _ := range common.GetDefaultModelRatioMap() { for model, _ := range common.GetDefaultModelRatioMap() {
if strings.HasPrefix(model, "gpt-3.5") { if strings.HasPrefix(model, "gpt-3.5") {
tokenEncoderMap[model] = gpt35TokenEncoder tokenEncoderMap[model] = cl100TokenEncoder
} else if strings.HasPrefix(model, "gpt-4") { } else if strings.HasPrefix(model, "gpt-4") {
if strings.HasPrefix(model, "gpt-4o") { if strings.HasPrefix(model, "gpt-4o") {
tokenEncoderMap[model] = cl200kTokenEncoder tokenEncoderMap[model] = o200kTokenEncoder
} else { } else {
tokenEncoderMap[model] = gpt4TokenEncoder tokenEncoderMap[model] = defaultTokenEncoder
} }
} else if strings.HasPrefix(model, "o1") {
tokenEncoderMap[model] = o200kTokenEncoder
} else { } else {
tokenEncoderMap[model] = nil tokenEncoderMap[model] = defaultTokenEncoder
} }
} }
common.SysLog("token encoders initialized") common.SysLog("token encoders initialized")
} }
func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken { func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken {
if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") { if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") {
return cl200kTokenEncoder return o200kTokenEncoder
} }
return defaultTokenEncoder return defaultTokenEncoder
} }