diff --git a/service/token_counter.go b/service/token_counter.go index c93d2776..e82da5cc 100644 --- a/service/token_counter.go +++ b/service/token_counter.go @@ -19,42 +19,40 @@ import ( // tokenEncoderMap won't grow after initialization var tokenEncoderMap = map[string]*tiktoken.Tiktoken{} var defaultTokenEncoder *tiktoken.Tiktoken -var cl200kTokenEncoder *tiktoken.Tiktoken +var o200kTokenEncoder *tiktoken.Tiktoken func InitTokenEncoders() { common.SysLog("initializing token encoders") - gpt35TokenEncoder, err := tiktoken.EncodingForModel("gpt-3.5-turbo") + cl100TokenEncoder, err := tiktoken.GetEncoding(tiktoken.MODEL_CL100K_BASE) if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-3.5-turbo token encoder: %s", err.Error())) } - defaultTokenEncoder = gpt35TokenEncoder - gpt4TokenEncoder, err := tiktoken.EncodingForModel("gpt-4") - if err != nil { - common.FatalLog(fmt.Sprintf("failed to get gpt-4 token encoder: %s", err.Error())) - } - cl200kTokenEncoder, err = tiktoken.EncodingForModel("gpt-4o") + defaultTokenEncoder = cl100TokenEncoder + o200kTokenEncoder, err = tiktoken.GetEncoding(tiktoken.MODEL_O200K_BASE) if err != nil { common.FatalLog(fmt.Sprintf("failed to get gpt-4o token encoder: %s", err.Error())) } for model, _ := range common.GetDefaultModelRatioMap() { if strings.HasPrefix(model, "gpt-3.5") { - tokenEncoderMap[model] = gpt35TokenEncoder + tokenEncoderMap[model] = cl100TokenEncoder } else if strings.HasPrefix(model, "gpt-4") { if strings.HasPrefix(model, "gpt-4o") { - tokenEncoderMap[model] = cl200kTokenEncoder + tokenEncoderMap[model] = o200kTokenEncoder } else { - tokenEncoderMap[model] = gpt4TokenEncoder + tokenEncoderMap[model] = defaultTokenEncoder } + } else if strings.HasPrefix(model, "o1") { + tokenEncoderMap[model] = o200kTokenEncoder } else { - tokenEncoderMap[model] = nil + tokenEncoderMap[model] = defaultTokenEncoder } } common.SysLog("token encoders initialized") } func getModelDefaultTokenEncoder(model string) *tiktoken.Tiktoken { - if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") { - return cl200kTokenEncoder + if strings.HasPrefix(model, "gpt-4o") || strings.HasPrefix(model, "chatgpt-4o") || strings.HasPrefix(model, "o1") { + return o200kTokenEncoder } return defaultTokenEncoder }