feat: Introduce settings package and refactor constants

- Added a new `setting` package to replace the `constant` package for configuration management, improving code organization and clarity.
- Moved various configuration variables such as `ServerAddress`, `PayAddress`, and `SensitiveWords` to the new `setting` package.
- Updated references throughout the codebase to use the new `setting` package, ensuring consistent access to configuration values.
- Introduced new files for managing chat settings and midjourney settings, enhancing modularity and maintainability of the code.
This commit is contained in:
CalciumIon
2024-12-22 17:24:29 +08:00
parent c4e256e69b
commit a7e1d17c3e
21 changed files with 99 additions and 94 deletions

1
constant/context_key.go Normal file
View File

@@ -0,0 +1 @@
package constant

View File

@@ -1,11 +1,5 @@
package constant package constant
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true
const ( const (
MjErrorUnknown = 5 MjErrorUnknown = 5
MjRequestError = 4 MjRequestError = 4

View File

@@ -10,10 +10,10 @@ import (
"log" "log"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting"
"strconv" "strconv"
"time" "time"
) )
@@ -231,9 +231,9 @@ func GetAllMidjourney(c *gin.Context) {
if logs == nil { if logs == nil {
logs = make([]*model.Midjourney, 0) logs = make([]*model.Midjourney, 0)
} }
if constant.MjForwardUrlEnabled { if setting.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }
@@ -263,9 +263,9 @@ func GetUserMidjourney(c *gin.Context) {
if logs == nil { if logs == nil {
logs = make([]*model.Midjourney, 0) logs = make([]*model.Midjourney, 0)
} }
if constant.MjForwardUrlEnabled { if setting.MjForwardUrlEnabled {
for i, midjourney := range logs { for i, midjourney := range logs {
midjourney.ImageUrl = constant.ServerAddress + "/mj/image/" + midjourney.MjId midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId
logs[i] = midjourney logs[i] = midjourney
} }
} }

View File

@@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/model" "one-api/model"
"one-api/setting"
"strings" "strings"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -47,9 +47,9 @@ func GetStatus(c *gin.Context) {
"footer_html": common.Footer, "footer_html": common.Footer,
"wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_qrcode": common.WeChatAccountQRCodeImageURL,
"wechat_login": common.WeChatAuthEnabled, "wechat_login": common.WeChatAuthEnabled,
"server_address": constant.ServerAddress, "server_address": setting.ServerAddress,
"price": constant.Price, "price": setting.Price,
"min_topup": constant.MinTopUp, "min_topup": setting.MinTopUp,
"turnstile_check": common.TurnstileCheckEnabled, "turnstile_check": common.TurnstileCheckEnabled,
"turnstile_site_key": common.TurnstileSiteKey, "turnstile_site_key": common.TurnstileSiteKey,
"top_up_link": common.TopUpLink, "top_up_link": common.TopUpLink,
@@ -63,9 +63,9 @@ func GetStatus(c *gin.Context) {
"enable_data_export": common.DataExportEnabled, "enable_data_export": common.DataExportEnabled,
"data_export_default_time": common.DataExportDefaultTime, "data_export_default_time": common.DataExportDefaultTime,
"default_collapse_sidebar": common.DefaultCollapseSidebar, "default_collapse_sidebar": common.DefaultCollapseSidebar,
"enable_online_topup": constant.PayAddress != "" && constant.EpayId != "" && constant.EpayKey != "", "enable_online_topup": setting.PayAddress != "" && setting.EpayId != "" && setting.EpayKey != "",
"mj_notify_enabled": constant.MjNotifyEnabled, "mj_notify_enabled": setting.MjNotifyEnabled,
"chats": constant.Chats, "chats": setting.Chats,
}, },
}) })
return return
@@ -207,7 +207,7 @@ func SendPasswordResetEmail(c *gin.Context) {
} }
code := common.GenerateVerificationCode(0) code := common.GenerateVerificationCode(0)
common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose)
link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", constant.ServerAddress, email, code) link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code)
subject := fmt.Sprintf("%s密码重置", common.SystemName) subject := fmt.Sprintf("%s密码重置", common.SystemName)
content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+ content := fmt.Sprintf("<p>您好,你正在进行%s密码重置。</p>"+
"<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+ "<p>点击 <a href='%s'>此处</a> 进行密码重置。</p>"+

View File

@@ -8,9 +8,9 @@ import (
"log" "log"
"net/url" "net/url"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/model" "one-api/model"
"one-api/service" "one-api/service"
"one-api/setting"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -28,13 +28,13 @@ type AmountRequest struct {
} }
func GetEpayClient() *epay.Client { func GetEpayClient() *epay.Client {
if constant.PayAddress == "" || constant.EpayId == "" || constant.EpayKey == "" { if setting.PayAddress == "" || setting.EpayId == "" || setting.EpayKey == "" {
return nil return nil
} }
withUrl, err := epay.NewClient(&epay.Config{ withUrl, err := epay.NewClient(&epay.Config{
PartnerID: constant.EpayId, PartnerID: setting.EpayId,
Key: constant.EpayKey, Key: setting.EpayKey,
}, constant.PayAddress) }, setting.PayAddress)
if err != nil { if err != nil {
return nil return nil
} }
@@ -50,12 +50,12 @@ func getPayMoney(amount float64, group string) float64 {
if topupGroupRatio == 0 { if topupGroupRatio == 0 {
topupGroupRatio = 1 topupGroupRatio = 1
} }
payMoney := amount * constant.Price * topupGroupRatio payMoney := amount * setting.Price * topupGroupRatio
return payMoney return payMoney
} }
func getMinTopup() int { func getMinTopup() int {
minTopup := constant.MinTopUp minTopup := setting.MinTopUp
if !common.DisplayInCurrencyEnabled { if !common.DisplayInCurrencyEnabled {
minTopup = minTopup * int(common.QuotaPerUnit) minTopup = minTopup * int(common.QuotaPerUnit)
} }
@@ -94,7 +94,7 @@ func RequestEpay(c *gin.Context) {
payType = "wxpay" payType = "wxpay"
} }
callBackAddress := service.GetCallbackAddress() callBackAddress := service.GetCallbackAddress()
returnUrl, _ := url.Parse(constant.ServerAddress + "/log") returnUrl, _ := url.Parse(setting.ServerAddress + "/log")
notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify")
tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix())
tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo)

View File

@@ -2,7 +2,7 @@ package model
import ( import (
"one-api/common" "one-api/common"
"one-api/constant" "one-api/setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -61,16 +61,16 @@ func InitOptionMap() {
common.OptionMap["SystemName"] = common.SystemName common.OptionMap["SystemName"] = common.SystemName
common.OptionMap["Logo"] = common.Logo common.OptionMap["Logo"] = common.Logo
common.OptionMap["ServerAddress"] = "" common.OptionMap["ServerAddress"] = ""
common.OptionMap["WorkerUrl"] = constant.WorkerUrl common.OptionMap["WorkerUrl"] = setting.WorkerUrl
common.OptionMap["WorkerValidKey"] = constant.WorkerValidKey common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey
common.OptionMap["PayAddress"] = "" common.OptionMap["PayAddress"] = ""
common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["CustomCallbackAddress"] = ""
common.OptionMap["EpayId"] = "" common.OptionMap["EpayId"] = ""
common.OptionMap["EpayKey"] = "" common.OptionMap["EpayKey"] = ""
common.OptionMap["Price"] = strconv.FormatFloat(constant.Price, 'f', -1, 64) common.OptionMap["Price"] = strconv.FormatFloat(setting.Price, 'f', -1, 64)
common.OptionMap["MinTopUp"] = strconv.Itoa(constant.MinTopUp) common.OptionMap["MinTopUp"] = strconv.Itoa(setting.MinTopUp)
common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString() common.OptionMap["TopupGroupRatio"] = common.TopupGroupRatio2JSONString()
common.OptionMap["Chats"] = constant.Chats2JsonString() common.OptionMap["Chats"] = setting.Chats2JsonString()
common.OptionMap["GitHubClientId"] = "" common.OptionMap["GitHubClientId"] = ""
common.OptionMap["GitHubClientSecret"] = "" common.OptionMap["GitHubClientSecret"] = ""
common.OptionMap["TelegramBotToken"] = "" common.OptionMap["TelegramBotToken"] = ""
@@ -98,17 +98,17 @@ func InitOptionMap() {
common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval) common.OptionMap["DataExportInterval"] = strconv.Itoa(common.DataExportInterval)
common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime common.OptionMap["DataExportDefaultTime"] = common.DataExportDefaultTime
common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar) common.OptionMap["DefaultCollapseSidebar"] = strconv.FormatBool(common.DefaultCollapseSidebar)
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(constant.MjNotifyEnabled) common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(setting.MjNotifyEnabled)
common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(constant.MjAccountFilterEnabled) common.OptionMap["MjAccountFilterEnabled"] = strconv.FormatBool(setting.MjAccountFilterEnabled)
common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(constant.MjModeClearEnabled) common.OptionMap["MjModeClearEnabled"] = strconv.FormatBool(setting.MjModeClearEnabled)
common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(constant.MjForwardUrlEnabled) common.OptionMap["MjForwardUrlEnabled"] = strconv.FormatBool(setting.MjForwardUrlEnabled)
common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(constant.MjActionCheckSuccessEnabled) common.OptionMap["MjActionCheckSuccessEnabled"] = strconv.FormatBool(setting.MjActionCheckSuccessEnabled)
common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(constant.CheckSensitiveEnabled) common.OptionMap["CheckSensitiveEnabled"] = strconv.FormatBool(setting.CheckSensitiveEnabled)
common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnPromptEnabled) common.OptionMap["CheckSensitiveOnPromptEnabled"] = strconv.FormatBool(setting.CheckSensitiveOnPromptEnabled)
//common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled) //common.OptionMap["CheckSensitiveOnCompletionEnabled"] = strconv.FormatBool(constant.CheckSensitiveOnCompletionEnabled)
common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(constant.StopOnSensitiveEnabled) common.OptionMap["StopOnSensitiveEnabled"] = strconv.FormatBool(setting.StopOnSensitiveEnabled)
common.OptionMap["SensitiveWords"] = constant.SensitiveWordsToString() common.OptionMap["SensitiveWords"] = setting.SensitiveWordsToString()
common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(constant.StreamCacheQueueLength) common.OptionMap["StreamCacheQueueLength"] = strconv.Itoa(setting.StreamCacheQueueLength)
common.OptionMapRWMutex.Unlock() common.OptionMapRWMutex.Unlock()
loadOptionsFromDatabase() loadOptionsFromDatabase()
@@ -209,23 +209,23 @@ func updateOptionMap(key string, value string) (err error) {
case "DefaultCollapseSidebar": case "DefaultCollapseSidebar":
common.DefaultCollapseSidebar = boolValue common.DefaultCollapseSidebar = boolValue
case "MjNotifyEnabled": case "MjNotifyEnabled":
constant.MjNotifyEnabled = boolValue setting.MjNotifyEnabled = boolValue
case "MjAccountFilterEnabled": case "MjAccountFilterEnabled":
constant.MjAccountFilterEnabled = boolValue setting.MjAccountFilterEnabled = boolValue
case "MjModeClearEnabled": case "MjModeClearEnabled":
constant.MjModeClearEnabled = boolValue setting.MjModeClearEnabled = boolValue
case "MjForwardUrlEnabled": case "MjForwardUrlEnabled":
constant.MjForwardUrlEnabled = boolValue setting.MjForwardUrlEnabled = boolValue
case "MjActionCheckSuccessEnabled": case "MjActionCheckSuccessEnabled":
constant.MjActionCheckSuccessEnabled = boolValue setting.MjActionCheckSuccessEnabled = boolValue
case "CheckSensitiveEnabled": case "CheckSensitiveEnabled":
constant.CheckSensitiveEnabled = boolValue setting.CheckSensitiveEnabled = boolValue
case "CheckSensitiveOnPromptEnabled": case "CheckSensitiveOnPromptEnabled":
constant.CheckSensitiveOnPromptEnabled = boolValue setting.CheckSensitiveOnPromptEnabled = boolValue
//case "CheckSensitiveOnCompletionEnabled": //case "CheckSensitiveOnCompletionEnabled":
// constant.CheckSensitiveOnCompletionEnabled = boolValue // constant.CheckSensitiveOnCompletionEnabled = boolValue
case "StopOnSensitiveEnabled": case "StopOnSensitiveEnabled":
constant.StopOnSensitiveEnabled = boolValue setting.StopOnSensitiveEnabled = boolValue
case "SMTPSSLEnabled": case "SMTPSSLEnabled":
common.SMTPSSLEnabled = boolValue common.SMTPSSLEnabled = boolValue
} }
@@ -245,25 +245,25 @@ func updateOptionMap(key string, value string) (err error) {
case "SMTPToken": case "SMTPToken":
common.SMTPToken = value common.SMTPToken = value
case "ServerAddress": case "ServerAddress":
constant.ServerAddress = value setting.ServerAddress = value
case "WorkerUrl": case "WorkerUrl":
constant.WorkerUrl = value setting.WorkerUrl = value
case "WorkerValidKey": case "WorkerValidKey":
constant.WorkerValidKey = value setting.WorkerValidKey = value
case "PayAddress": case "PayAddress":
constant.PayAddress = value setting.PayAddress = value
case "Chats": case "Chats":
err = constant.UpdateChatsByJsonString(value) err = setting.UpdateChatsByJsonString(value)
case "CustomCallbackAddress": case "CustomCallbackAddress":
constant.CustomCallbackAddress = value setting.CustomCallbackAddress = value
case "EpayId": case "EpayId":
constant.EpayId = value setting.EpayId = value
case "EpayKey": case "EpayKey":
constant.EpayKey = value setting.EpayKey = value
case "Price": case "Price":
constant.Price, _ = strconv.ParseFloat(value, 64) setting.Price, _ = strconv.ParseFloat(value, 64)
case "MinTopUp": case "MinTopUp":
constant.MinTopUp, _ = strconv.Atoi(value) setting.MinTopUp, _ = strconv.Atoi(value)
case "TopupGroupRatio": case "TopupGroupRatio":
err = common.UpdateTopupGroupRatioByJSONString(value) err = common.UpdateTopupGroupRatioByJSONString(value)
case "GitHubClientId": case "GitHubClientId":
@@ -331,9 +331,9 @@ func updateOptionMap(key string, value string) (err error) {
case "QuotaPerUnit": case "QuotaPerUnit":
common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64) common.QuotaPerUnit, _ = strconv.ParseFloat(value, 64)
case "SensitiveWords": case "SensitiveWords":
constant.SensitiveWordsFromString(value) setting.SensitiveWordsFromString(value)
case "StreamCacheQueueLength": case "StreamCacheQueueLength":
constant.StreamCacheQueueLength, _ = strconv.Atoi(value) setting.StreamCacheQueueLength, _ = strconv.Atoi(value)
} }
return err return err
} }

View File

@@ -5,8 +5,8 @@ import (
"fmt" "fmt"
"gorm.io/gorm" "gorm.io/gorm"
"one-api/common" "one-api/common"
"one-api/constant"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/setting"
"strconv" "strconv"
"strings" "strings"
) )
@@ -325,7 +325,7 @@ func PostConsumeTokenQuota(relayInfo *relaycommon.RelayInfo, userQuota int, quot
prompt = "您的额度已用尽" prompt = "您的额度已用尽"
} }
if email != "" { if email != "" {
topUpLink := fmt.Sprintf("%s/topup", constant.ServerAddress) topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress)
err = common.SendEmail(prompt, email, err = common.SendEmail(prompt, email,
fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink)) fmt.Sprintf("%s当前剩余额度为 %d为了不影响您的使用请及时充值。<br/>充值链接:<a href='%s'>%s</a>", prompt, userQuota, topUpLink, topUpLink))
if err != nil { if err != nil {

View File

@@ -7,12 +7,12 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
) )
func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) { func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.AudioRequest, error) {
@@ -26,7 +26,7 @@ func getAndValidAudioRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
if audioRequest.Model == "" { if audioRequest.Model == "" {
return nil, errors.New("model is required") return nil, errors.New("model is required")
} }
if constant.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(audioRequest.Input) err := service.CheckSensitiveInput(audioRequest.Input)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -9,11 +9,11 @@ import (
"io" "io"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/model" "one-api/model"
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
"one-api/service" "one-api/service"
"one-api/setting"
"strings" "strings"
) )
@@ -59,7 +59,7 @@ func getAndValidImageRequest(c *gin.Context, info *relaycommon.RelayInfo) (*dto.
//if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) { //if imageRequest.N != 0 && (imageRequest.N < 1 || imageRequest.N > 10) {
// return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest) // return service.OpenAIErrorWrapper(errors.New("n must be between 1 and 10"), "invalid_field_value", http.StatusBadRequest)
//} //}
if constant.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
err := service.CheckSensitiveInput(imageRequest.Prompt) err := service.CheckSensitiveInput(imageRequest.Prompt)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -15,6 +15,7 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -111,8 +112,8 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo
midjourneyTask.StartTime = originTask.StartTime midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = "" midjourneyTask.ImageUrl = ""
if originTask.ImageUrl != "" && constant.MjForwardUrlEnabled { if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
midjourneyTask.ImageUrl = constant.ServerAddress + "/mj/image/" + originTask.MjId midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
if originTask.Status != "SUCCESS" { if originTask.Status != "SUCCESS" {
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
} }
@@ -421,7 +422,7 @@ func RelayMidjourneySubmit(c *gin.Context, relayMode int) *dto.MidjourneyRespons
if originTask == nil { if originTask == nil {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found") return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
} else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理 } else { //原任务的Status=SUCCESS则可以做放大UPSCALE、变换VARIATION等动作此时必须使用原来的请求地址才能正确处理
if constant.MjActionCheckSuccessEnabled { if setting.MjActionCheckSuccessEnabled {
if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal { if originTask.Status != "SUCCESS" && relayMode != relayconstant.RelayModeMidjourneyModal {
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success") return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
} }

View File

@@ -15,6 +15,7 @@ import (
relaycommon "one-api/relay/common" relaycommon "one-api/relay/common"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/service" "one-api/service"
"one-api/setting"
"strings" "strings"
"time" "time"
@@ -100,7 +101,7 @@ func TextHelper(c *gin.Context) (openaiErr *dto.OpenAIErrorWithStatusCode) {
var modelRatio float64 var modelRatio float64
//err := service.SensitiveWordsCheck(textRequest) //err := service.SensitiveWordsCheck(textRequest)
if constant.ShouldCheckPromptSensitive() { if setting.ShouldCheckPromptSensitive() {
err = checkRequestSensitive(textRequest, relayInfo) err = checkRequestSensitive(textRequest, relayInfo)
if err != nil { if err != nil {
return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest) return service.OpenAIErrorWrapperLocal(err, "sensitive_words_detected", http.StatusBadRequest)

View File

@@ -1,12 +1,12 @@
package service package service
import ( import (
"one-api/constant" "one-api/setting"
) )
func GetCallbackAddress() string { func GetCallbackAddress() string {
if constant.CustomCallbackAddress == "" { if setting.CustomCallbackAddress == "" {
return constant.ServerAddress return setting.ServerAddress
} }
return constant.CustomCallbackAddress return setting.CustomCallbackAddress
} }

View File

@@ -11,6 +11,7 @@ import (
"one-api/constant" "one-api/constant"
"one-api/dto" "one-api/dto"
relayconstant "one-api/relay/constant" relayconstant "one-api/relay/constant"
"one-api/setting"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@@ -167,16 +168,16 @@ func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestU
if err != nil { if err != nil {
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
} }
if !constant.MjAccountFilterEnabled { if !setting.MjAccountFilterEnabled {
delete(mapResult, "accountFilter") delete(mapResult, "accountFilter")
} }
if !constant.MjNotifyEnabled { if !setting.MjNotifyEnabled {
delete(mapResult, "notifyHook") delete(mapResult, "notifyHook")
} }
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody) //req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
// make new request with mapResult // make new request with mapResult
} }
if constant.MjModeClearEnabled { if setting.MjModeClearEnabled {
if prompt, ok := mapResult["prompt"].(string); ok { if prompt, ok := mapResult["prompt"].(string); ok {
prompt = strings.Replace(prompt, "--fast", "", -1) prompt = strings.Replace(prompt, "--fast", "", -1)
prompt = strings.Replace(prompt, "--relax", "", -1) prompt = strings.Replace(prompt, "--relax", "", -1)

View File

@@ -3,8 +3,8 @@ package service
import ( import (
"errors" "errors"
"fmt" "fmt"
"one-api/constant"
"one-api/dto" "one-api/dto"
"one-api/setting"
"strings" "strings"
) )
@@ -56,7 +56,7 @@ func CheckSensitiveInput(input any) error {
// SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表 // SensitiveWordContains 是否包含敏感词,返回是否包含敏感词和敏感词列表
func SensitiveWordContains(text string) (bool, []string) { func SensitiveWordContains(text string) (bool, []string) {
if len(constant.SensitiveWords) == 0 { if len(setting.SensitiveWords) == 0 {
return false, nil return false, nil
} }
checkText := strings.ToLower(text) checkText := strings.ToLower(text)
@@ -75,7 +75,7 @@ func SensitiveWordContains(text string) (bool, []string) {
// SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本 // SensitiveWordReplace 敏感词替换,返回是否包含敏感词和替换后的文本
func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) { func SensitiveWordReplace(text string, returnImmediately bool) (bool, []string, string) {
if len(constant.SensitiveWords) == 0 { if len(setting.SensitiveWords) == 0 {
return false, nil, text return false, nil, text
} }
checkText := strings.ToLower(text) checkText := strings.ToLower(text)

View File

@@ -4,7 +4,7 @@ import (
"bytes" "bytes"
"fmt" "fmt"
goahocorasick "github.com/anknown/ahocorasick" goahocorasick "github.com/anknown/ahocorasick"
"one-api/constant" "one-api/setting"
"strings" "strings"
) )
@@ -70,7 +70,7 @@ func InitAc() *goahocorasick.Machine {
func readRunes() [][]rune { func readRunes() [][]rune {
var dict [][]rune var dict [][]rune
for _, word := range constant.SensitiveWords { for _, word := range setting.SensitiveWords {
word = strings.ToLower(word) word = strings.ToLower(word)
l := bytes.TrimSpace([]byte(word)) l := bytes.TrimSpace([]byte(word))
dict = append(dict, bytes.Runes(l)) dict = append(dict, bytes.Runes(l))

View File

@@ -5,20 +5,20 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"one-api/common" "one-api/common"
"one-api/constant" "one-api/setting"
"strings" "strings"
) )
func DoImageRequest(originUrl string) (resp *http.Response, err error) { func DoImageRequest(originUrl string) (resp *http.Response, err error) {
if constant.EnableWorker() { if setting.EnableWorker() {
common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl)) common.SysLog(fmt.Sprintf("downloading image from worker: %s", originUrl))
workerUrl := constant.WorkerUrl workerUrl := setting.WorkerUrl
if !strings.HasSuffix(workerUrl, "/") { if !strings.HasSuffix(workerUrl, "/") {
workerUrl += "/" workerUrl += "/"
} }
// post request to worker // post request to worker
data := []byte(`{"url":"` + originUrl + `","key":"` + constant.WorkerValidKey + `"}`) data := []byte(`{"url":"` + originUrl + `","key":"` + setting.WorkerValidKey + `"}`)
return http.Post(constant.WorkerUrl, "application/json", bytes.NewBuffer(data)) return http.Post(setting.WorkerUrl, "application/json", bytes.NewBuffer(data))
} else { } else {
common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl)) common.SysLog(fmt.Sprintf("downloading image from origin: %s", originUrl))
return http.Get(originUrl) return http.Get(originUrl)

View File

@@ -1,4 +1,4 @@
package constant package setting
import ( import (
"encoding/json" "encoding/json"

7
setting/midjourney.go Normal file
View File

@@ -0,0 +1,7 @@
package setting
var MjNotifyEnabled = false
var MjAccountFilterEnabled = false
var MjModeClearEnabled = false
var MjForwardUrlEnabled = true
var MjActionCheckSuccessEnabled = true

View File

@@ -1,4 +1,4 @@
package constant package setting
var PayAddress = "" var PayAddress = ""
var CustomCallbackAddress = "" var CustomCallbackAddress = ""

View File

@@ -1,4 +1,4 @@
package constant package setting
import "strings" import "strings"

View File

@@ -1,4 +1,4 @@
package constant package setting
var ServerAddress = "http://localhost:3000" var ServerAddress = "http://localhost:3000"
var WorkerUrl = "" var WorkerUrl = ""