diff --git a/controller/midjourney.go b/controller/midjourney.go index a67d39c2..3a730441 100644 --- a/controller/midjourney.go +++ b/controller/midjourney.go @@ -13,6 +13,7 @@ import ( "one-api/model" "one-api/service" "one-api/setting" + "one-api/setting/system_setting" "time" "github.com/gin-gonic/gin" @@ -259,7 +260,7 @@ func GetAllMidjourney(c *gin.Context) { if setting.MjForwardUrlEnabled { for i, midjourney := range items { - midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId + midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId items[i] = midjourney } } @@ -284,7 +285,7 @@ func GetUserMidjourney(c *gin.Context) { if setting.MjForwardUrlEnabled { for i, midjourney := range items { - midjourney.ImageUrl = setting.ServerAddress + "/mj/image/" + midjourney.MjId + midjourney.ImageUrl = system_setting.ServerAddress + "/mj/image/" + midjourney.MjId items[i] = midjourney } } diff --git a/controller/misc.go b/controller/misc.go index 08582930..875142ff 100644 --- a/controller/misc.go +++ b/controller/misc.go @@ -58,7 +58,7 @@ func GetStatus(c *gin.Context) { "footer_html": common.Footer, "wechat_qrcode": common.WeChatAccountQRCodeImageURL, "wechat_login": common.WeChatAuthEnabled, - "server_address": setting.ServerAddress, + "server_address": system_setting.ServerAddress, "turnstile_check": common.TurnstileCheckEnabled, "turnstile_site_key": common.TurnstileSiteKey, "top_up_link": common.TopUpLink, @@ -249,7 +249,7 @@ func SendPasswordResetEmail(c *gin.Context) { } code := common.GenerateVerificationCode(0) common.RegisterVerificationCodeWithKey(email, code, common.PasswordResetPurpose) - link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", setting.ServerAddress, email, code) + link := fmt.Sprintf("%s/user/reset?email=%s&token=%s", system_setting.ServerAddress, email, code) subject := fmt.Sprintf("%s密码重置", common.SystemName) content := fmt.Sprintf("
您好,你正在进行%s密码重置。
"+ "点击 此处 进行密码重置。
"+ diff --git a/controller/oidc.go b/controller/oidc.go index f3def0e3..8e254d38 100644 --- a/controller/oidc.go +++ b/controller/oidc.go @@ -8,7 +8,6 @@ import ( "net/url" "one-api/common" "one-api/model" - "one-api/setting" "one-api/setting/system_setting" "strconv" "strings" @@ -45,7 +44,7 @@ func getOidcUserInfoByCode(code string) (*OidcUser, error) { values.Set("client_secret", system_setting.GetOIDCSettings().ClientSecret) values.Set("code", code) values.Set("grant_type", "authorization_code") - values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", setting.ServerAddress)) + values.Set("redirect_uri", fmt.Sprintf("%s/oauth/oidc", system_setting.ServerAddress)) formData := values.Encode() req, err := http.NewRequest("POST", system_setting.GetOIDCSettings().TokenEndpoint, strings.NewReader(formData)) if err != nil { diff --git a/controller/topup.go b/controller/topup.go index 93f3e58e..243e6794 100644 --- a/controller/topup.go +++ b/controller/topup.go @@ -10,6 +10,7 @@ import ( "one-api/service" "one-api/setting" "one-api/setting/operation_setting" + "one-api/setting/system_setting" "strconv" "sync" "time" @@ -152,7 +153,7 @@ func RequestEpay(c *gin.Context) { } callBackAddress := service.GetCallbackAddress() - returnUrl, _ := url.Parse(setting.ServerAddress + "/console/log") + returnUrl, _ := url.Parse(system_setting.ServerAddress + "/console/log") notifyUrl, _ := url.Parse(callBackAddress + "/api/user/epay/notify") tradeNo := fmt.Sprintf("%s%d", common.GetRandomString(6), time.Now().Unix()) tradeNo = fmt.Sprintf("USR%dNO%s", id, tradeNo) diff --git a/controller/topup_stripe.go b/controller/topup_stripe.go index bf0d7bf3..d462acb4 100644 --- a/controller/topup_stripe.go +++ b/controller/topup_stripe.go @@ -9,6 +9,7 @@ import ( "one-api/model" "one-api/setting" "one-api/setting/operation_setting" + "one-api/setting/system_setting" "strconv" "strings" "time" @@ -216,8 +217,8 @@ func genStripeLink(referenceId string, customerId string, email string, amount i params := &stripe.CheckoutSessionParams{ ClientReferenceID: stripe.String(referenceId), - SuccessURL: stripe.String(setting.ServerAddress + "/log"), - CancelURL: stripe.String(setting.ServerAddress + "/topup"), + SuccessURL: stripe.String(system_setting.ServerAddress + "/log"), + CancelURL: stripe.String(system_setting.ServerAddress + "/topup"), LineItems: []*stripe.CheckoutSessionLineItemParams{ { Price: stripe.String(setting.StripePriceId), diff --git a/model/option.go b/model/option.go index 73fe92ad..fefee4e7 100644 --- a/model/option.go +++ b/model/option.go @@ -6,6 +6,7 @@ import ( "one-api/setting/config" "one-api/setting/operation_setting" "one-api/setting/ratio_setting" + "one-api/setting/system_setting" "strconv" "strings" "time" @@ -66,9 +67,9 @@ func InitOptionMap() { common.OptionMap["SystemName"] = common.SystemName common.OptionMap["Logo"] = common.Logo common.OptionMap["ServerAddress"] = "" - common.OptionMap["WorkerUrl"] = setting.WorkerUrl - common.OptionMap["WorkerValidKey"] = setting.WorkerValidKey - common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(setting.WorkerAllowHttpImageRequestEnabled) + common.OptionMap["WorkerUrl"] = system_setting.WorkerUrl + common.OptionMap["WorkerValidKey"] = system_setting.WorkerValidKey + common.OptionMap["WorkerAllowHttpImageRequestEnabled"] = strconv.FormatBool(system_setting.WorkerAllowHttpImageRequestEnabled) common.OptionMap["PayAddress"] = "" common.OptionMap["CustomCallbackAddress"] = "" common.OptionMap["EpayId"] = "" @@ -271,7 +272,7 @@ func updateOptionMap(key string, value string) (err error) { case "SMTPSSLEnabled": common.SMTPSSLEnabled = boolValue case "WorkerAllowHttpImageRequestEnabled": - setting.WorkerAllowHttpImageRequestEnabled = boolValue + system_setting.WorkerAllowHttpImageRequestEnabled = boolValue case "DefaultUseAutoGroup": setting.DefaultUseAutoGroup = boolValue case "ExposeRatioEnabled": @@ -293,11 +294,11 @@ func updateOptionMap(key string, value string) (err error) { case "SMTPToken": common.SMTPToken = value case "ServerAddress": - setting.ServerAddress = value + system_setting.ServerAddress = value case "WorkerUrl": - setting.WorkerUrl = value + system_setting.WorkerUrl = value case "WorkerValidKey": - setting.WorkerValidKey = value + system_setting.WorkerValidKey = value case "PayAddress": operation_setting.PayAddress = value case "Chats": diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index d2ab826d..4a236b2f 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -7,12 +7,12 @@ import ( "fmt" "io" "net/http" + "one-api/model" "regexp" "strings" "github.com/gin-gonic/gin" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -21,6 +21,10 @@ import ( "one-api/service" ) +// ============================ +// Request / Response structures +// ============================ + type requestPayload struct { Instances []map[string]any `json:"instances"` Parameters map[string]any `json:"parameters,omitempty"` @@ -52,33 +56,35 @@ type operationResponse struct { } `json:"error"` } -type TaskAdaptor struct{} +// ============================ +// Adaptor implementation +// ============================ -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {} - -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { - info.Action = constant.TaskActionTextGenerate - - req := relaycommon.TaskSubmitReq{} - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - } - if strings.TrimSpace(req.Prompt) == "" { - return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - } - c.Set("task_request", req) - return nil +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + // Use the standard validation method for TaskSubmitReq + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate) +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName - if v, ok := getRequestModelFromContext(info); ok { - modelName = v - } if modelName == "" { modelName = "veo-3.0-generate-001" } @@ -103,16 +109,17 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, ), nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } - token, err := vertexcore.AcquireAccessToken(*adc, info.ChannelSetting.Proxy) + token, err := vertexcore.AcquireAccessToken(*adc, "") if err != nil { return fmt.Errorf("failed to acquire access token: %w", err) } @@ -121,7 +128,8 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { +// BuildRequestBody converts request into Vertex specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") if !ok { return nil, fmt.Errorf("request not found in context") @@ -151,11 +159,13 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayI return bytes.NewReader(data), nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -177,6 +187,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } func (a *TaskAdaptor) GetChannelName() string { return "vertex" } +// FetchTask fetch task status func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { @@ -191,15 +202,15 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http region = "us-central1" } project := extractProjectFromOperationName(upstreamName) - model := extractModelFromOperationName(upstreamName) - if project == "" || model == "" { + modelName := extractModelFromOperationName(upstreamName) + if project == "" || modelName == "" { return nil, fmt.Errorf("cannot extract project/model from operation name") } var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, model) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName) } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, model) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} data, err := json.Marshal(payload) @@ -232,17 +243,17 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } ti := &relaycommon.TaskInfo{} if op.Error.Message != "" { - ti.Status = "FAILURE" + ti.Status = model.TaskStatusFailure ti.Reason = op.Error.Message ti.Progress = "100%" return ti, nil } if !op.Done { - ti.Status = "IN_PROGRESS" + ti.Status = model.TaskStatusInProgress ti.Progress = "50%" return ti, nil } - ti.Status = "SUCCESS" + ti.Status = model.TaskStatusSuccess ti.Progress = "100%" if len(op.Response.Videos) > 0 { v0 := op.Response.Videos[0] @@ -290,9 +301,9 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e return ti, nil } -func getRequestModelFromContext(info *relaycommon.TaskRelayInfo) (string, bool) { - return info.OriginModelName, info.OriginModelName != "" -} +// ============================ +// helpers +// ============================ func encodeLocalTaskID(name string) string { return base64.RawURLEncoding.EncodeToString([]byte(name)) diff --git a/relay/mjproxy_handler.go b/relay/mjproxy_handler.go index 7c52cb6b..ec8dfc6b 100644 --- a/relay/mjproxy_handler.go +++ b/relay/mjproxy_handler.go @@ -16,6 +16,7 @@ import ( "one-api/relay/helper" "one-api/service" "one-api/setting" + "one-api/setting/system_setting" "strconv" "strings" "time" @@ -131,7 +132,7 @@ func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjo midjourneyTask.FinishTime = originTask.FinishTime midjourneyTask.ImageUrl = "" if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled { - midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId + midjourneyTask.ImageUrl = system_setting.ServerAddress + "/mj/image/" + originTask.MjId if originTask.Status != "SUCCESS" { midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10) } diff --git a/service/epay.go b/service/epay.go index a1ff484e..48b84dd5 100644 --- a/service/epay.go +++ b/service/epay.go @@ -1,13 +1,13 @@ package service import ( - "one-api/setting" "one-api/setting/operation_setting" + "one-api/setting/system_setting" ) func GetCallbackAddress() string { if operation_setting.CustomCallbackAddress == "" { - return setting.ServerAddress + return system_setting.ServerAddress } return operation_setting.CustomCallbackAddress } diff --git a/service/quota.go b/service/quota.go index e078a1ad..12017e11 100644 --- a/service/quota.go +++ b/service/quota.go @@ -11,8 +11,8 @@ import ( "one-api/logger" "one-api/model" relaycommon "one-api/relay/common" - "one-api/setting" "one-api/setting/ratio_setting" + "one-api/setting/system_setting" "one-api/types" "strings" "time" @@ -534,7 +534,7 @@ func checkAndSendQuotaNotify(relayInfo *relaycommon.RelayInfo, quota int, preCon } if quotaTooLow { prompt := "您的额度即将用尽" - topUpLink := fmt.Sprintf("%s/topup", setting.ServerAddress) + topUpLink := fmt.Sprintf("%s/topup", system_setting.ServerAddress) // 根据通知方式生成不同的内容格式 var content string diff --git a/setting/system_setting.go b/setting/system_setting/system_setting_old.go similarity index 89% rename from setting/system_setting.go rename to setting/system_setting/system_setting_old.go index c37a6123..4e0f1a50 100644 --- a/setting/system_setting.go +++ b/setting/system_setting/system_setting_old.go @@ -1,4 +1,4 @@ -package setting +package system_setting var ServerAddress = "http://localhost:3000" var WorkerUrl = ""