diff --git a/common/model-ratio.go b/common/model-ratio.go index ffeda83d..4b64c79f 100644 --- a/common/model-ratio.go +++ b/common/model-ratio.go @@ -233,7 +233,11 @@ var ( modelRatioMapMutex = sync.RWMutex{} ) -var CompletionRatio map[string]float64 = nil +var ( + CompletionRatio map[string]float64 = nil + CompletionRatioMutex = sync.RWMutex{} +) + var defaultCompletionRatio = map[string]float64{ "gpt-4-gizmo-*": 2, "gpt-4o-gizmo-*": 3, @@ -334,10 +338,17 @@ func GetDefaultModelRatioMap() map[string]float64 { return defaultModelRatio } -func CompletionRatio2JSONString() string { +func GetCompletionRatioMap() map[string]float64 { + CompletionRatioMutex.Lock() + defer CompletionRatioMutex.Unlock() if CompletionRatio == nil { CompletionRatio = defaultCompletionRatio } + return CompletionRatio +} + +func CompletionRatio2JSONString() string { + GetCompletionRatioMap() jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { SysError("error marshalling completion ratio: " + err.Error()) @@ -345,12 +356,9 @@ func CompletionRatio2JSONString() string { return string(jsonBytes) } -func UpdateCompletionRatioByJSONString(jsonStr string) error { - CompletionRatio = make(map[string]float64) - return json.Unmarshal([]byte(jsonStr), &CompletionRatio) -} - func GetCompletionRatio(name string) float64 { + GetCompletionRatioMap() + if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { return ratio @@ -476,10 +484,3 @@ func GetAudioCompletionRatio(name string) float64 { } return 2 } - -func GetCompletionRatioMap() map[string]float64 { - if CompletionRatio == nil { - CompletionRatio = defaultCompletionRatio - } - return CompletionRatio -}