From aa34c3035a354fa109913fbf4d42cd6bf6218d54 Mon Sep 17 00:00:00 2001 From: CaIon <1808837298@qq.com> Date: Mon, 7 Apr 2025 22:20:47 +0800 Subject: [PATCH] feat: Initialize model settings and improve concurrency control in operation settings --- main.go | 4 ++ setting/operation_setting/cache_ratio.go | 15 ++-- setting/operation_setting/model-ratio.go | 87 +++++++++++++++--------- 3 files changed, 65 insertions(+), 41 deletions(-) diff --git a/main.go b/main.go index 495057cf..eefd808f 100644 --- a/main.go +++ b/main.go @@ -12,6 +12,7 @@ import ( "one-api/model" "one-api/router" "one-api/service" + "one-api/setting/operation_setting" "os" "strconv" @@ -73,6 +74,9 @@ func main() { constant.InitEnv() // Initialize options model.InitOptionMap() + // Initialize model settings + operation_setting.InitModelSettings() + if common.RedisEnabled { // for compatibility with old versions common.MemoryCacheEnabled = true diff --git a/setting/operation_setting/cache_ratio.go b/setting/operation_setting/cache_ratio.go index 89196523..dd29eac2 100644 --- a/setting/operation_setting/cache_ratio.go +++ b/setting/operation_setting/cache_ratio.go @@ -56,17 +56,15 @@ var cacheRatioMapMutex sync.RWMutex // GetCacheRatioMap returns the cache ratio map func GetCacheRatioMap() map[string]float64 { - cacheRatioMapMutex.Lock() - defer cacheRatioMapMutex.Unlock() - if cacheRatioMap == nil { - cacheRatioMap = defaultCacheRatio - } + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() return cacheRatioMap } // CacheRatio2JSONString converts the cache ratio map to a JSON string func CacheRatio2JSONString() string { - GetCacheRatioMap() + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() jsonBytes, err := json.Marshal(cacheRatioMap) if err != nil { common.SysError("error marshalling cache ratio: " + err.Error()) @@ -84,10 +82,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error { // GetCacheRatio returns the cache ratio for a model func GetCacheRatio(name string) (float64, bool) { - GetCacheRatioMap() + cacheRatioMapMutex.RLock() + defer cacheRatioMapMutex.RUnlock() ratio, ok := cacheRatioMap[name] if !ok { - return 1, false // Default to 0.5 if not found + return 1, false // Default to 1 if not found } return ratio, true } diff --git a/setting/operation_setting/model-ratio.go b/setting/operation_setting/model-ratio.go index 1412f614..bbaa6d9c 100644 --- a/setting/operation_setting/model-ratio.go +++ b/setting/operation_setting/model-ratio.go @@ -245,17 +245,41 @@ var defaultCompletionRatio = map[string]float64{ "gpt-4-all": 2, } -func GetModelPriceMap() map[string]float64 { +// InitModelSettings initializes all model related settings maps +func InitModelSettings() { + // Initialize modelPriceMap modelPriceMapMutex.Lock() - defer modelPriceMapMutex.Unlock() - if modelPriceMap == nil { - modelPriceMap = defaultModelPrice - } + modelPriceMap = defaultModelPrice + modelPriceMapMutex.Unlock() + + // Initialize modelRatioMap + modelRatioMapMutex.Lock() + modelRatioMap = defaultModelRatio + modelRatioMapMutex.Unlock() + + // Initialize CompletionRatio + CompletionRatioMutex.Lock() + CompletionRatio = defaultCompletionRatio + CompletionRatioMutex.Unlock() + + // Initialize cacheRatioMap + cacheRatioMapMutex.Lock() + cacheRatioMap = defaultCacheRatio + cacheRatioMapMutex.Unlock() + + common.SysLog("model settings initialized") +} + +func GetModelPriceMap() map[string]float64 { + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() return modelPriceMap } func ModelPrice2JSONString() string { - GetModelPriceMap() + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + jsonBytes, err := json.Marshal(modelPriceMap) if err != nil { common.SysError("error marshalling model price: " + err.Error()) @@ -272,7 +296,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error { // GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false func GetModelPrice(name string, printErr bool) (float64, bool) { - GetModelPriceMap() + modelPriceMapMutex.RLock() + defer modelPriceMapMutex.RUnlock() + if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } @@ -289,24 +315,6 @@ func GetModelPrice(name string, printErr bool) (float64, bool) { return price, true } -func GetModelRatioMap() map[string]float64 { - modelRatioMapMutex.Lock() - defer modelRatioMapMutex.Unlock() - if modelRatioMap == nil { - modelRatioMap = defaultModelRatio - } - return modelRatioMap -} - -func ModelRatio2JSONString() string { - GetModelRatioMap() - jsonBytes, err := json.Marshal(modelRatioMap) - if err != nil { - common.SysError("error marshalling model ratio: " + err.Error()) - } - return string(jsonBytes) -} - func UpdateModelRatioByJSONString(jsonStr string) error { modelRatioMapMutex.Lock() defer modelRatioMapMutex.Unlock() @@ -315,7 +323,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error { } func GetModelRatio(name string) (float64, bool) { - GetModelRatioMap() + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + if strings.HasPrefix(name, "gpt-4-gizmo") { name = "gpt-4-gizmo-*" } @@ -339,16 +349,15 @@ func GetDefaultModelRatioMap() map[string]float64 { } func GetCompletionRatioMap() map[string]float64 { - CompletionRatioMutex.Lock() - defer CompletionRatioMutex.Unlock() - if CompletionRatio == nil { - CompletionRatio = defaultCompletionRatio - } + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() return CompletionRatio } func CompletionRatio2JSONString() string { - GetCompletionRatioMap() + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() + jsonBytes, err := json.Marshal(CompletionRatio) if err != nil { common.SysError("error marshalling completion ratio: " + err.Error()) @@ -364,7 +373,8 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error { } func GetCompletionRatio(name string) float64 { - GetCompletionRatioMap() + CompletionRatioMutex.RLock() + defer CompletionRatioMutex.RUnlock() if strings.Contains(name, "/") { if ratio, ok := CompletionRatio[name]; ok { @@ -511,3 +521,14 @@ func GetAudioCompletionRatio(name string) float64 { } return 2 } + +func ModelRatio2JSONString() string { + modelRatioMapMutex.RLock() + defer modelRatioMapMutex.RUnlock() + + jsonBytes, err := json.Marshal(modelRatioMap) + if err != nil { + common.SysError("error marshalling model ratio: " + err.Error()) + } + return string(jsonBytes) +}