feat: Initialize model settings and improve concurrency control in operation settings

This commit is contained in:
CaIon
2025-04-07 22:20:47 +08:00
parent fb9f595044
commit aa34c3035a
3 changed files with 65 additions and 41 deletions

View File

@@ -12,6 +12,7 @@ import (
"one-api/model" "one-api/model"
"one-api/router" "one-api/router"
"one-api/service" "one-api/service"
"one-api/setting/operation_setting"
"os" "os"
"strconv" "strconv"
@@ -73,6 +74,9 @@ func main() {
constant.InitEnv() constant.InitEnv()
// Initialize options // Initialize options
model.InitOptionMap() model.InitOptionMap()
// Initialize model settings
operation_setting.InitModelSettings()
if common.RedisEnabled { if common.RedisEnabled {
// for compatibility with old versions // for compatibility with old versions
common.MemoryCacheEnabled = true common.MemoryCacheEnabled = true

View File

@@ -56,17 +56,15 @@ var cacheRatioMapMutex sync.RWMutex
// GetCacheRatioMap returns the cache ratio map // GetCacheRatioMap returns the cache ratio map
func GetCacheRatioMap() map[string]float64 { func GetCacheRatioMap() map[string]float64 {
cacheRatioMapMutex.Lock() cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.Unlock() defer cacheRatioMapMutex.RUnlock()
if cacheRatioMap == nil {
cacheRatioMap = defaultCacheRatio
}
return cacheRatioMap return cacheRatioMap
} }
// CacheRatio2JSONString converts the cache ratio map to a JSON string // CacheRatio2JSONString converts the cache ratio map to a JSON string
func CacheRatio2JSONString() string { func CacheRatio2JSONString() string {
GetCacheRatioMap() cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(cacheRatioMap) jsonBytes, err := json.Marshal(cacheRatioMap)
if err != nil { if err != nil {
common.SysError("error marshalling cache ratio: " + err.Error()) common.SysError("error marshalling cache ratio: " + err.Error())
@@ -84,10 +82,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
// GetCacheRatio returns the cache ratio for a model // GetCacheRatio returns the cache ratio for a model
func GetCacheRatio(name string) (float64, bool) { func GetCacheRatio(name string) (float64, bool) {
GetCacheRatioMap() cacheRatioMapMutex.RLock()
defer cacheRatioMapMutex.RUnlock()
ratio, ok := cacheRatioMap[name] ratio, ok := cacheRatioMap[name]
if !ok { if !ok {
return 1, false // Default to 0.5 if not found return 1, false // Default to 1 if not found
} }
return ratio, true return ratio, true
} }

View File

@@ -245,17 +245,41 @@ var defaultCompletionRatio = map[string]float64{
"gpt-4-all": 2, "gpt-4-all": 2,
} }
func GetModelPriceMap() map[string]float64 { // InitModelSettings initializes all model related settings maps
func InitModelSettings() {
// Initialize modelPriceMap
modelPriceMapMutex.Lock() modelPriceMapMutex.Lock()
defer modelPriceMapMutex.Unlock() modelPriceMap = defaultModelPrice
if modelPriceMap == nil { modelPriceMapMutex.Unlock()
modelPriceMap = defaultModelPrice
} // Initialize modelRatioMap
modelRatioMapMutex.Lock()
modelRatioMap = defaultModelRatio
modelRatioMapMutex.Unlock()
// Initialize CompletionRatio
CompletionRatioMutex.Lock()
CompletionRatio = defaultCompletionRatio
CompletionRatioMutex.Unlock()
// Initialize cacheRatioMap
cacheRatioMapMutex.Lock()
cacheRatioMap = defaultCacheRatio
cacheRatioMapMutex.Unlock()
common.SysLog("model settings initialized")
}
func GetModelPriceMap() map[string]float64 {
modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
return modelPriceMap return modelPriceMap
} }
func ModelPrice2JSONString() string { func ModelPrice2JSONString() string {
GetModelPriceMap() modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
jsonBytes, err := json.Marshal(modelPriceMap) jsonBytes, err := json.Marshal(modelPriceMap)
if err != nil { if err != nil {
common.SysError("error marshalling model price: " + err.Error()) common.SysError("error marshalling model price: " + err.Error())
@@ -272,7 +296,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1false // GetModelPrice 返回模型的价格,如果模型不存在则返回-1false
func GetModelPrice(name string, printErr bool) (float64, bool) { func GetModelPrice(name string, printErr bool) (float64, bool) {
GetModelPriceMap() modelPriceMapMutex.RLock()
defer modelPriceMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} }
@@ -289,24 +315,6 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
return price, true return price, true
} }
func GetModelRatioMap() map[string]float64 {
modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock()
if modelRatioMap == nil {
modelRatioMap = defaultModelRatio
}
return modelRatioMap
}
func ModelRatio2JSONString() string {
GetModelRatioMap()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}
func UpdateModelRatioByJSONString(jsonStr string) error { func UpdateModelRatioByJSONString(jsonStr string) error {
modelRatioMapMutex.Lock() modelRatioMapMutex.Lock()
defer modelRatioMapMutex.Unlock() defer modelRatioMapMutex.Unlock()
@@ -315,7 +323,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
} }
func GetModelRatio(name string) (float64, bool) { func GetModelRatio(name string) (float64, bool) {
GetModelRatioMap() modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
if strings.HasPrefix(name, "gpt-4-gizmo") { if strings.HasPrefix(name, "gpt-4-gizmo") {
name = "gpt-4-gizmo-*" name = "gpt-4-gizmo-*"
} }
@@ -339,16 +349,15 @@ func GetDefaultModelRatioMap() map[string]float64 {
} }
func GetCompletionRatioMap() map[string]float64 { func GetCompletionRatioMap() map[string]float64 {
CompletionRatioMutex.Lock() CompletionRatioMutex.RLock()
defer CompletionRatioMutex.Unlock() defer CompletionRatioMutex.RUnlock()
if CompletionRatio == nil {
CompletionRatio = defaultCompletionRatio
}
return CompletionRatio return CompletionRatio
} }
func CompletionRatio2JSONString() string { func CompletionRatio2JSONString() string {
GetCompletionRatioMap() CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
jsonBytes, err := json.Marshal(CompletionRatio) jsonBytes, err := json.Marshal(CompletionRatio)
if err != nil { if err != nil {
common.SysError("error marshalling completion ratio: " + err.Error()) common.SysError("error marshalling completion ratio: " + err.Error())
@@ -364,7 +373,8 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
} }
func GetCompletionRatio(name string) float64 { func GetCompletionRatio(name string) float64 {
GetCompletionRatioMap() CompletionRatioMutex.RLock()
defer CompletionRatioMutex.RUnlock()
if strings.Contains(name, "/") { if strings.Contains(name, "/") {
if ratio, ok := CompletionRatio[name]; ok { if ratio, ok := CompletionRatio[name]; ok {
@@ -511,3 +521,14 @@ func GetAudioCompletionRatio(name string) float64 {
} }
return 2 return 2
} }
func ModelRatio2JSONString() string {
modelRatioMapMutex.RLock()
defer modelRatioMapMutex.RUnlock()
jsonBytes, err := json.Marshal(modelRatioMap)
if err != nil {
common.SysError("error marshalling model ratio: " + err.Error())
}
return string(jsonBytes)
}