diff --git a/common/constants.go b/common/constants.go index 1bca1418..7454f57a 100644 --- a/common/constants.go +++ b/common/constants.go @@ -103,14 +103,14 @@ var IsMasterNode = os.Getenv("NODE_TYPE") != "slave" var requestInterval, _ = strconv.Atoi(os.Getenv("POLLING_INTERVAL")) var RequestInterval = time.Duration(requestInterval) * time.Second -var SyncFrequency = GetOrDefault("SYNC_FREQUENCY", 60) // unit is second +var SyncFrequency = GetEnvOrDefault("SYNC_FREQUENCY", 60) // unit is second var BatchUpdateEnabled = false -var BatchUpdateInterval = GetOrDefault("BATCH_UPDATE_INTERVAL", 5) +var BatchUpdateInterval = GetEnvOrDefault("BATCH_UPDATE_INTERVAL", 5) -var RelayTimeout = GetOrDefault("RELAY_TIMEOUT", 0) // unit is second +var RelayTimeout = GetEnvOrDefault("RELAY_TIMEOUT", 0) // unit is second -var GeminiSafetySetting = GetOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") +var GeminiSafetySetting = GetEnvOrDefaultString("GEMINI_SAFETY_SETTING", "BLOCK_NONE") const ( RequestIdKey = "X-Oneapi-Request-Id" @@ -133,10 +133,10 @@ var ( // All duration's unit is seconds // Shouldn't larger then RateLimitKeyExpirationDuration var ( - GlobalApiRateLimitNum = GetOrDefault("GLOBAL_API_RATE_LIMIT", 180) + GlobalApiRateLimitNum = GetEnvOrDefault("GLOBAL_API_RATE_LIMIT", 180) GlobalApiRateLimitDuration int64 = 3 * 60 - GlobalWebRateLimitNum = GetOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) + GlobalWebRateLimitNum = GetEnvOrDefault("GLOBAL_WEB_RATE_LIMIT", 60) GlobalWebRateLimitDuration int64 = 3 * 60 UploadRateLimitNum = 10 diff --git a/common/env.go b/common/env.go new file mode 100644 index 00000000..856fa61e --- /dev/null +++ b/common/env.go @@ -0,0 +1,26 @@ +package common + +import ( + "fmt" + "os" + "strconv" +) + +func GetEnvOrDefault(env string, defaultValue int) int { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + num, err := strconv.Atoi(os.Getenv(env)) + if err != nil { + SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) + return defaultValue + } + return num +} + +func GetEnvOrDefaultString(env string, defaultValue string) string { + if env == "" || os.Getenv(env) == "" { + return defaultValue + } + return os.Getenv(env) +} diff --git a/common/group-ratio.go b/common/group-ratio.go index 1ec73c78..416ba037 100644 --- a/common/group-ratio.go +++ b/common/group-ratio.go @@ -1,6 +1,8 @@ package common -import "encoding/json" +import ( + "encoding/json" +) var GroupRatio = map[string]float64{ "default": 1, diff --git a/common/topup-ratio.go b/common/topup-ratio.go index e045d9b7..8f03395d 100644 --- a/common/topup-ratio.go +++ b/common/topup-ratio.go @@ -1,6 +1,8 @@ package common -import "encoding/json" +import ( + "encoding/json" +) var TopupGroupRatio = map[string]float64{ "default": 1, diff --git a/common/utils.go b/common/utils.go index 6c89d410..3e047c44 100644 --- a/common/utils.go +++ b/common/utils.go @@ -8,7 +8,6 @@ import ( "log" "math/rand" "net" - "os" "os/exec" "runtime" "strconv" @@ -191,25 +190,6 @@ func Max(a int, b int) int { } } -func GetOrDefault(env string, defaultValue int) int { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - num, err := strconv.Atoi(os.Getenv(env)) - if err != nil { - SysError(fmt.Sprintf("failed to parse %s: %s, using default value: %d", env, err.Error(), defaultValue)) - return defaultValue - } - return num -} - -func GetOrDefaultString(env string, defaultValue string) string { - if env == "" || os.Getenv(env) == "" { - return defaultValue - } - return os.Getenv(env) -} - func MessageWithRequestId(message string, id string) string { return fmt.Sprintf("%s (request id: %s)", message, id) } diff --git a/constant/env.go b/constant/env.go new file mode 100644 index 00000000..6355a6b6 --- /dev/null +++ b/constant/env.go @@ -0,0 +1,7 @@ +package constant + +import ( + "one-api/common" +) + +var StreamingTimeout = common.GetEnvOrDefault("STREAMING_TIMEOUT", 30) diff --git a/constant/system-setting.go b/constant/system-setting.go index de20fef8..b2976e49 100644 --- a/constant/system-setting.go +++ b/constant/system-setting.go @@ -1,13 +1,9 @@ package constant -import "one-api/common" - var ServerAddress = "http://localhost:3000" var WorkerUrl = "" var WorkerValidKey = "" -var StreamingTimeout = common.GetOrDefault("STREAMING_TIMEOUT", 30) - func EnableWorker() bool { return WorkerUrl != "" } diff --git a/controller/topup.go b/controller/topup.go index 90c3f779..87c68c32 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -5,11 +5,10 @@ import ( "github.com/Calcium-Ion/go-epay/epay" "github.com/gin-gonic/gin" "github.com/samber/lo" - "one-api/constant" - "log" "net/url" "one-api/common" + "one-api/constant" "one-api/model" "one-api/service" "strconv" diff --git a/dto/pricing.go b/dto/pricing.go index b0497491..ee77c098 100644 --- a/dto/pricing.go +++ b/dto/pricing.go @@ -24,14 +24,3 @@ type OpenAIModels struct { Root string `json:"root"` Parent *string `json:"parent"` } - -type ModelPricing struct { - Available bool `json:"available"` - ModelName string `json:"model_name"` - QuotaType int `json:"quota_type"` - ModelRatio float64 `json:"model_ratio"` - ModelPrice float64 `json:"model_price"` - OwnerBy string `json:"owner_by"` - CompletionRatio float64 `json:"completion_ratio"` - EnableGroup []string `json:"enable_group,omitempty"` -} diff --git a/model/main.go b/model/main.go index 710ea059..a70f21bd 100644 --- a/model/main.go +++ b/model/main.go @@ -86,9 +86,9 @@ func InitDB() (err error) { if err != nil { return err } - sqlDB.SetMaxIdleConns(common.GetOrDefault("SQL_MAX_IDLE_CONNS", 100)) - sqlDB.SetMaxOpenConns(common.GetOrDefault("SQL_MAX_OPEN_CONNS", 1000)) - sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetOrDefault("SQL_MAX_LIFETIME", 60))) + sqlDB.SetMaxIdleConns(common.GetEnvOrDefault("SQL_MAX_IDLE_CONNS", 100)) + sqlDB.SetMaxOpenConns(common.GetEnvOrDefault("SQL_MAX_OPEN_CONNS", 1000)) + sqlDB.SetConnMaxLifetime(time.Second * time.Duration(common.GetEnvOrDefault("SQL_MAX_LIFETIME", 60))) if !common.IsMasterNode { return nil diff --git a/model/pricing.go b/model/pricing.go index 90d8bc7e..7384a2fa 100644 --- a/model/pricing.go +++ b/model/pricing.go @@ -2,18 +2,28 @@ package model import ( "one-api/common" - "one-api/dto" "sync" "time" ) +type Pricing struct { + Available bool `json:"available"` + ModelName string `json:"model_name"` + QuotaType int `json:"quota_type"` + ModelRatio float64 `json:"model_ratio"` + ModelPrice float64 `json:"model_price"` + OwnerBy string `json:"owner_by"` + CompletionRatio float64 `json:"completion_ratio"` + EnableGroup []string `json:"enable_group,omitempty"` +} + var ( - pricingMap []dto.ModelPricing + pricingMap []Pricing lastGetPricingTime time.Time updatePricingLock sync.Mutex ) -func GetPricing(group string) []dto.ModelPricing { +func GetPricing(group string) []Pricing { updatePricingLock.Lock() defer updatePricingLock.Unlock() @@ -21,7 +31,7 @@ func GetPricing(group string) []dto.ModelPricing { updatePricing() } if group != "" { - userPricingMap := make([]dto.ModelPricing, 0) + userPricingMap := make([]Pricing, 0) models := GetGroupModels(group) for _, pricing := range pricingMap { if !common.StringsContains(models, pricing.ModelName) { @@ -42,9 +52,9 @@ func updatePricing() { allModels[model] = i } - pricingMap = make([]dto.ModelPricing, 0) + pricingMap = make([]Pricing, 0) for model, _ := range allModels { - pricing := dto.ModelPricing{ + pricing := Pricing{ Available: true, ModelName: model, } diff --git a/service/sensitive.go b/service/sensitive.go index a9b51983..17da69aa 100644 --- a/service/sensitive.go +++ b/service/sensitive.go @@ -3,7 +3,6 @@ package service import ( "errors" "fmt" - "one-api/common" "one-api/constant" "one-api/dto" "strings" @@ -62,7 +61,7 @@ func SensitiveWordContains(text string) (bool, []string) { } checkText := strings.ToLower(text) // 构建一个AC自动机 - m := common.InitAc() + m := InitAc() hits := m.MultiPatternSearch([]rune(checkText), false) if len(hits) > 0 { words := make([]string, 0) @@ -80,7 +79,7 @@ func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, return false, nil, text } checkText := strings.ToLower(text) - m := common.InitAc() + m := InitAc() hits := m.MultiPatternSearch([]rune(checkText), returnImmediately) if len(hits) > 0 { words := make([]string, 0) diff --git a/common/str.go b/service/str.go similarity index 99% rename from common/str.go rename to service/str.go index bab252c6..a2152bfe 100644 --- a/common/str.go +++ b/service/str.go @@ -1,4 +1,4 @@ -package common +package service import ( "bytes"