feat: Initialize model settings and improve concurrency control in operation settings
This commit is contained in:
4
main.go
4
main.go
@@ -12,6 +12,7 @@ import (
|
|||||||
"one-api/model"
|
"one-api/model"
|
||||||
"one-api/router"
|
"one-api/router"
|
||||||
"one-api/service"
|
"one-api/service"
|
||||||
|
"one-api/setting/operation_setting"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
@@ -73,6 +74,9 @@ func main() {
|
|||||||
constant.InitEnv()
|
constant.InitEnv()
|
||||||
// Initialize options
|
// Initialize options
|
||||||
model.InitOptionMap()
|
model.InitOptionMap()
|
||||||
|
// Initialize model settings
|
||||||
|
operation_setting.InitModelSettings()
|
||||||
|
|
||||||
if common.RedisEnabled {
|
if common.RedisEnabled {
|
||||||
// for compatibility with old versions
|
// for compatibility with old versions
|
||||||
common.MemoryCacheEnabled = true
|
common.MemoryCacheEnabled = true
|
||||||
|
|||||||
@@ -56,17 +56,15 @@ var cacheRatioMapMutex sync.RWMutex
|
|||||||
|
|
||||||
// GetCacheRatioMap returns the cache ratio map
|
// GetCacheRatioMap returns the cache ratio map
|
||||||
func GetCacheRatioMap() map[string]float64 {
|
func GetCacheRatioMap() map[string]float64 {
|
||||||
cacheRatioMapMutex.Lock()
|
cacheRatioMapMutex.RLock()
|
||||||
defer cacheRatioMapMutex.Unlock()
|
defer cacheRatioMapMutex.RUnlock()
|
||||||
if cacheRatioMap == nil {
|
|
||||||
cacheRatioMap = defaultCacheRatio
|
|
||||||
}
|
|
||||||
return cacheRatioMap
|
return cacheRatioMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// CacheRatio2JSONString converts the cache ratio map to a JSON string
|
// CacheRatio2JSONString converts the cache ratio map to a JSON string
|
||||||
func CacheRatio2JSONString() string {
|
func CacheRatio2JSONString() string {
|
||||||
GetCacheRatioMap()
|
cacheRatioMapMutex.RLock()
|
||||||
|
defer cacheRatioMapMutex.RUnlock()
|
||||||
jsonBytes, err := json.Marshal(cacheRatioMap)
|
jsonBytes, err := json.Marshal(cacheRatioMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error marshalling cache ratio: " + err.Error())
|
common.SysError("error marshalling cache ratio: " + err.Error())
|
||||||
@@ -84,10 +82,11 @@ func UpdateCacheRatioByJSONString(jsonStr string) error {
|
|||||||
|
|
||||||
// GetCacheRatio returns the cache ratio for a model
|
// GetCacheRatio returns the cache ratio for a model
|
||||||
func GetCacheRatio(name string) (float64, bool) {
|
func GetCacheRatio(name string) (float64, bool) {
|
||||||
GetCacheRatioMap()
|
cacheRatioMapMutex.RLock()
|
||||||
|
defer cacheRatioMapMutex.RUnlock()
|
||||||
ratio, ok := cacheRatioMap[name]
|
ratio, ok := cacheRatioMap[name]
|
||||||
if !ok {
|
if !ok {
|
||||||
return 1, false // Default to 0.5 if not found
|
return 1, false // Default to 1 if not found
|
||||||
}
|
}
|
||||||
return ratio, true
|
return ratio, true
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -245,17 +245,41 @@ var defaultCompletionRatio = map[string]float64{
|
|||||||
"gpt-4-all": 2,
|
"gpt-4-all": 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelPriceMap() map[string]float64 {
|
// InitModelSettings initializes all model related settings maps
|
||||||
|
func InitModelSettings() {
|
||||||
|
// Initialize modelPriceMap
|
||||||
modelPriceMapMutex.Lock()
|
modelPriceMapMutex.Lock()
|
||||||
defer modelPriceMapMutex.Unlock()
|
modelPriceMap = defaultModelPrice
|
||||||
if modelPriceMap == nil {
|
modelPriceMapMutex.Unlock()
|
||||||
modelPriceMap = defaultModelPrice
|
|
||||||
}
|
// Initialize modelRatioMap
|
||||||
|
modelRatioMapMutex.Lock()
|
||||||
|
modelRatioMap = defaultModelRatio
|
||||||
|
modelRatioMapMutex.Unlock()
|
||||||
|
|
||||||
|
// Initialize CompletionRatio
|
||||||
|
CompletionRatioMutex.Lock()
|
||||||
|
CompletionRatio = defaultCompletionRatio
|
||||||
|
CompletionRatioMutex.Unlock()
|
||||||
|
|
||||||
|
// Initialize cacheRatioMap
|
||||||
|
cacheRatioMapMutex.Lock()
|
||||||
|
cacheRatioMap = defaultCacheRatio
|
||||||
|
cacheRatioMapMutex.Unlock()
|
||||||
|
|
||||||
|
common.SysLog("model settings initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetModelPriceMap() map[string]float64 {
|
||||||
|
modelPriceMapMutex.RLock()
|
||||||
|
defer modelPriceMapMutex.RUnlock()
|
||||||
return modelPriceMap
|
return modelPriceMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModelPrice2JSONString() string {
|
func ModelPrice2JSONString() string {
|
||||||
GetModelPriceMap()
|
modelPriceMapMutex.RLock()
|
||||||
|
defer modelPriceMapMutex.RUnlock()
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(modelPriceMap)
|
jsonBytes, err := json.Marshal(modelPriceMap)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error marshalling model price: " + err.Error())
|
common.SysError("error marshalling model price: " + err.Error())
|
||||||
@@ -272,7 +296,9 @@ func UpdateModelPriceByJSONString(jsonStr string) error {
|
|||||||
|
|
||||||
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
// GetModelPrice 返回模型的价格,如果模型不存在则返回-1,false
|
||||||
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
func GetModelPrice(name string, printErr bool) (float64, bool) {
|
||||||
GetModelPriceMap()
|
modelPriceMapMutex.RLock()
|
||||||
|
defer modelPriceMapMutex.RUnlock()
|
||||||
|
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
@@ -289,24 +315,6 @@ func GetModelPrice(name string, printErr bool) (float64, bool) {
|
|||||||
return price, true
|
return price, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetModelRatioMap() map[string]float64 {
|
|
||||||
modelRatioMapMutex.Lock()
|
|
||||||
defer modelRatioMapMutex.Unlock()
|
|
||||||
if modelRatioMap == nil {
|
|
||||||
modelRatioMap = defaultModelRatio
|
|
||||||
}
|
|
||||||
return modelRatioMap
|
|
||||||
}
|
|
||||||
|
|
||||||
func ModelRatio2JSONString() string {
|
|
||||||
GetModelRatioMap()
|
|
||||||
jsonBytes, err := json.Marshal(modelRatioMap)
|
|
||||||
if err != nil {
|
|
||||||
common.SysError("error marshalling model ratio: " + err.Error())
|
|
||||||
}
|
|
||||||
return string(jsonBytes)
|
|
||||||
}
|
|
||||||
|
|
||||||
func UpdateModelRatioByJSONString(jsonStr string) error {
|
func UpdateModelRatioByJSONString(jsonStr string) error {
|
||||||
modelRatioMapMutex.Lock()
|
modelRatioMapMutex.Lock()
|
||||||
defer modelRatioMapMutex.Unlock()
|
defer modelRatioMapMutex.Unlock()
|
||||||
@@ -315,7 +323,9 @@ func UpdateModelRatioByJSONString(jsonStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetModelRatio(name string) (float64, bool) {
|
func GetModelRatio(name string) (float64, bool) {
|
||||||
GetModelRatioMap()
|
modelRatioMapMutex.RLock()
|
||||||
|
defer modelRatioMapMutex.RUnlock()
|
||||||
|
|
||||||
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
if strings.HasPrefix(name, "gpt-4-gizmo") {
|
||||||
name = "gpt-4-gizmo-*"
|
name = "gpt-4-gizmo-*"
|
||||||
}
|
}
|
||||||
@@ -339,16 +349,15 @@ func GetDefaultModelRatioMap() map[string]float64 {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetCompletionRatioMap() map[string]float64 {
|
func GetCompletionRatioMap() map[string]float64 {
|
||||||
CompletionRatioMutex.Lock()
|
CompletionRatioMutex.RLock()
|
||||||
defer CompletionRatioMutex.Unlock()
|
defer CompletionRatioMutex.RUnlock()
|
||||||
if CompletionRatio == nil {
|
|
||||||
CompletionRatio = defaultCompletionRatio
|
|
||||||
}
|
|
||||||
return CompletionRatio
|
return CompletionRatio
|
||||||
}
|
}
|
||||||
|
|
||||||
func CompletionRatio2JSONString() string {
|
func CompletionRatio2JSONString() string {
|
||||||
GetCompletionRatioMap()
|
CompletionRatioMutex.RLock()
|
||||||
|
defer CompletionRatioMutex.RUnlock()
|
||||||
|
|
||||||
jsonBytes, err := json.Marshal(CompletionRatio)
|
jsonBytes, err := json.Marshal(CompletionRatio)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
common.SysError("error marshalling completion ratio: " + err.Error())
|
common.SysError("error marshalling completion ratio: " + err.Error())
|
||||||
@@ -364,7 +373,8 @@ func UpdateCompletionRatioByJSONString(jsonStr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func GetCompletionRatio(name string) float64 {
|
func GetCompletionRatio(name string) float64 {
|
||||||
GetCompletionRatioMap()
|
CompletionRatioMutex.RLock()
|
||||||
|
defer CompletionRatioMutex.RUnlock()
|
||||||
|
|
||||||
if strings.Contains(name, "/") {
|
if strings.Contains(name, "/") {
|
||||||
if ratio, ok := CompletionRatio[name]; ok {
|
if ratio, ok := CompletionRatio[name]; ok {
|
||||||
@@ -511,3 +521,14 @@ func GetAudioCompletionRatio(name string) float64 {
|
|||||||
}
|
}
|
||||||
return 2
|
return 2
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func ModelRatio2JSONString() string {
|
||||||
|
modelRatioMapMutex.RLock()
|
||||||
|
defer modelRatioMapMutex.RUnlock()
|
||||||
|
|
||||||
|
jsonBytes, err := json.Marshal(modelRatioMap)
|
||||||
|
if err != nil {
|
||||||
|
common.SysError("error marshalling model ratio: " + err.Error())
|
||||||
|
}
|
||||||
|
return string(jsonBytes)
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user