This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
259 lines
8.7 KiB
Go
259 lines
8.7 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"log"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
relayconstant "one-api/relay/constant"
|
|
"one-api/setting"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
func CoverActionToModelName(mjAction string) string {
|
|
modelName := "mj_" + strings.ToLower(mjAction)
|
|
if mjAction == constant.MjActionSwapFace {
|
|
modelName = "swap_face"
|
|
}
|
|
return modelName
|
|
}
|
|
|
|
func GetMjRequestModel(relayMode int, midjRequest *dto.MidjourneyRequest) (string, *dto.MidjourneyResponse, bool) {
|
|
action := ""
|
|
if relayMode == relayconstant.RelayModeMidjourneyAction {
|
|
// plus request
|
|
err := CoverPlusActionToNormalAction(midjRequest)
|
|
if err != nil {
|
|
return "", err, false
|
|
}
|
|
action = midjRequest.Action
|
|
} else {
|
|
switch relayMode {
|
|
case relayconstant.RelayModeMidjourneyImagine:
|
|
action = constant.MjActionImagine
|
|
case relayconstant.RelayModeMidjourneyVideo:
|
|
action = constant.MjActionVideo
|
|
case relayconstant.RelayModeMidjourneyEdits:
|
|
action = constant.MjActionEdits
|
|
case relayconstant.RelayModeMidjourneyDescribe:
|
|
action = constant.MjActionDescribe
|
|
case relayconstant.RelayModeMidjourneyBlend:
|
|
action = constant.MjActionBlend
|
|
case relayconstant.RelayModeMidjourneyShorten:
|
|
action = constant.MjActionShorten
|
|
case relayconstant.RelayModeMidjourneyChange:
|
|
action = midjRequest.Action
|
|
case relayconstant.RelayModeMidjourneyModal:
|
|
action = constant.MjActionModal
|
|
case relayconstant.RelayModeSwapFace:
|
|
action = constant.MjActionSwapFace
|
|
case relayconstant.RelayModeMidjourneyUpload:
|
|
action = constant.MjActionUpload
|
|
case relayconstant.RelayModeMidjourneySimpleChange:
|
|
params := ConvertSimpleChangeParams(midjRequest.Content)
|
|
if params == nil {
|
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "invalid_request"), false
|
|
}
|
|
action = params.Action
|
|
case relayconstant.RelayModeMidjourneyTaskFetch, relayconstant.RelayModeMidjourneyTaskFetchByCondition, relayconstant.RelayModeMidjourneyNotify:
|
|
return "", nil, true
|
|
default:
|
|
return "", MidjourneyErrorWrapper(constant.MjRequestError, "unknown_relay_action"), false
|
|
}
|
|
}
|
|
modelName := CoverActionToModelName(action)
|
|
return modelName, nil, true
|
|
}
|
|
|
|
func CoverPlusActionToNormalAction(midjRequest *dto.MidjourneyRequest) *dto.MidjourneyResponse {
|
|
// "customId": "MJ::JOB::upsample::2::3dbbd469-36af-4a0f-8f02-df6c579e7011"
|
|
customId := midjRequest.CustomId
|
|
if customId == "" {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "custom_id_is_required")
|
|
}
|
|
splits := strings.Split(customId, "::")
|
|
var action string
|
|
if splits[1] == "JOB" {
|
|
action = splits[2]
|
|
} else {
|
|
action = splits[1]
|
|
}
|
|
|
|
if action == "" {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action")
|
|
}
|
|
if strings.Contains(action, "upsample") {
|
|
index, err := strconv.Atoi(splits[3])
|
|
if err != nil {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
}
|
|
midjRequest.Index = index
|
|
midjRequest.Action = constant.MjActionUpscale
|
|
} else if strings.Contains(action, "variation") {
|
|
midjRequest.Index = 1
|
|
if action == "variation" {
|
|
index, err := strconv.Atoi(splits[3])
|
|
if err != nil {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "index_parse_failed")
|
|
}
|
|
midjRequest.Index = index
|
|
midjRequest.Action = constant.MjActionVariation
|
|
} else if action == "low_variation" {
|
|
midjRequest.Action = constant.MjActionLowVariation
|
|
} else if action == "high_variation" {
|
|
midjRequest.Action = constant.MjActionHighVariation
|
|
}
|
|
} else if strings.Contains(action, "pan") {
|
|
midjRequest.Action = constant.MjActionPan
|
|
midjRequest.Index = 1
|
|
} else if strings.Contains(action, "reroll") {
|
|
midjRequest.Action = constant.MjActionReRoll
|
|
midjRequest.Index = 1
|
|
} else if action == "Outpaint" {
|
|
midjRequest.Action = constant.MjActionZoom
|
|
midjRequest.Index = 1
|
|
} else if action == "CustomZoom" {
|
|
midjRequest.Action = constant.MjActionCustomZoom
|
|
midjRequest.Index = 1
|
|
} else if action == "Inpaint" {
|
|
midjRequest.Action = constant.MjActionInPaint
|
|
midjRequest.Index = 1
|
|
} else {
|
|
return MidjourneyErrorWrapper(constant.MjRequestError, "unknown_action:"+customId)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func ConvertSimpleChangeParams(content string) *dto.MidjourneyRequest {
|
|
split := strings.Split(content, " ")
|
|
if len(split) != 2 {
|
|
return nil
|
|
}
|
|
|
|
action := strings.ToLower(split[1])
|
|
changeParams := &dto.MidjourneyRequest{}
|
|
changeParams.TaskId = split[0]
|
|
|
|
if action[0] == 'u' {
|
|
changeParams.Action = "UPSCALE"
|
|
} else if action[0] == 'v' {
|
|
changeParams.Action = "VARIATION"
|
|
} else if action == "r" {
|
|
changeParams.Action = "REROLL"
|
|
return changeParams
|
|
} else {
|
|
return nil
|
|
}
|
|
|
|
index, err := strconv.Atoi(action[1:2])
|
|
if err != nil || index < 1 || index > 4 {
|
|
return nil
|
|
}
|
|
changeParams.Index = index
|
|
return changeParams
|
|
}
|
|
|
|
func DoMidjourneyHttpRequest(c *gin.Context, timeout time.Duration, fullRequestURL string) (*dto.MidjourneyResponseWithStatusCode, []byte, error) {
|
|
var nullBytes []byte
|
|
//var requestBody io.Reader
|
|
//requestBody = c.Request.Body
|
|
// read request body to json, delete accountFilter and notifyHook
|
|
var mapResult map[string]interface{}
|
|
// if get request, no need to read request body
|
|
if c.Request.Method != "GET" {
|
|
err := json.NewDecoder(c.Request.Body).Decode(&mapResult)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
if !setting.MjAccountFilterEnabled {
|
|
delete(mapResult, "accountFilter")
|
|
}
|
|
if !setting.MjNotifyEnabled {
|
|
delete(mapResult, "notifyHook")
|
|
}
|
|
//req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
|
// make new request with mapResult
|
|
}
|
|
if setting.MjModeClearEnabled {
|
|
if prompt, ok := mapResult["prompt"].(string); ok {
|
|
prompt = strings.Replace(prompt, "--fast", "", -1)
|
|
prompt = strings.Replace(prompt, "--relax", "", -1)
|
|
prompt = strings.Replace(prompt, "--turbo", "", -1)
|
|
|
|
mapResult["prompt"] = prompt
|
|
}
|
|
}
|
|
reqBody, err := json.Marshal(mapResult)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "marshal_request_body_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
req, err := http.NewRequest(c.Request.Method, fullRequestURL, strings.NewReader(string(reqBody)))
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "create_request_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
// 使用带有超时的 context 创建新的请求
|
|
req = req.WithContext(ctx)
|
|
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
|
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
|
auth := common.GetContextKeyString(c, constant.ContextKeyChannelKey)
|
|
if auth != "" {
|
|
auth = strings.TrimPrefix(auth, "Bearer ")
|
|
req.Header.Set("mj-api-secret", auth)
|
|
}
|
|
defer cancel()
|
|
resp, err := GetHttpClient().Do(req)
|
|
if err != nil {
|
|
common.SysLog("do request failed: " + err.Error())
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "do_request_failed", http.StatusInternalServerError), nullBytes, err
|
|
}
|
|
statusCode := resp.StatusCode
|
|
//if statusCode != 200 {
|
|
// return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "bad_response_status_code", statusCode), nullBytes, nil
|
|
//}
|
|
err = req.Body.Close()
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
|
}
|
|
err = c.Request.Body.Close()
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "close_request_body_failed", statusCode), nullBytes, err
|
|
}
|
|
var midjResponse dto.MidjourneyResponse
|
|
var midjourneyUploadsResponse dto.MidjourneyUploadResponse
|
|
responseBody, err := io.ReadAll(resp.Body)
|
|
if err != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "read_response_body_failed", statusCode), nullBytes, err
|
|
}
|
|
CloseResponseBodyGracefully(resp)
|
|
respStr := string(responseBody)
|
|
log.Printf("respStr: %s", respStr)
|
|
if respStr == "" {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "empty_response_body", statusCode), responseBody, nil
|
|
} else {
|
|
err = json.Unmarshal(responseBody, &midjResponse)
|
|
if err != nil {
|
|
err2 := json.Unmarshal(responseBody, &midjourneyUploadsResponse)
|
|
if err2 != nil {
|
|
return MidjourneyErrorWithStatusCodeWrapper(constant.MjErrorUnknown, "unmarshal_response_body_failed", statusCode), responseBody, err
|
|
}
|
|
}
|
|
}
|
|
//log.Printf("midjResponse: %v", midjResponse)
|
|
//for k, v := range resp.Header {
|
|
// c.Writer.Header().Set(k, v[0])
|
|
//}
|
|
return &dto.MidjourneyResponseWithStatusCode{
|
|
StatusCode: statusCode,
|
|
Response: midjResponse,
|
|
}, responseBody, nil
|
|
}
|