This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package. Key changes include: - Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices. - Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability. - Minor adjustments to improve code clarity and organization throughout various modules. This change aims to streamline logging and improve the overall architecture of the codebase.
654 lines
22 KiB
Go
654 lines
22 KiB
Go
package relay
|
||
|
||
import (
|
||
"bytes"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"net/http"
|
||
"one-api/common"
|
||
"one-api/constant"
|
||
"one-api/dto"
|
||
"one-api/model"
|
||
relaycommon "one-api/relay/common"
|
||
relayconstant "one-api/relay/constant"
|
||
"one-api/relay/helper"
|
||
"one-api/service"
|
||
"one-api/setting"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
)
|
||
|
||
func RelayMidjourneyImage(c *gin.Context) {
|
||
taskId := c.Param("id")
|
||
midjourneyTask := model.GetByOnlyMJId(taskId)
|
||
if midjourneyTask == nil {
|
||
c.JSON(400, gin.H{
|
||
"error": "midjourney_task_not_found",
|
||
})
|
||
return
|
||
}
|
||
var httpClient *http.Client
|
||
if channel, err := model.CacheGetChannel(midjourneyTask.ChannelId); err == nil {
|
||
proxy := channel.GetSetting().Proxy
|
||
if proxy != "" {
|
||
if httpClient, err = service.NewProxyHttpClient(proxy); err != nil {
|
||
c.JSON(400, gin.H{
|
||
"error": "proxy_url_invalid",
|
||
})
|
||
return
|
||
}
|
||
}
|
||
}
|
||
if httpClient == nil {
|
||
httpClient = service.GetHttpClient()
|
||
}
|
||
resp, err := httpClient.Get(midjourneyTask.ImageUrl)
|
||
if err != nil {
|
||
c.JSON(http.StatusInternalServerError, gin.H{
|
||
"error": "http_get_image_failed",
|
||
})
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
if resp.StatusCode != http.StatusOK {
|
||
responseBody, _ := io.ReadAll(resp.Body)
|
||
c.JSON(resp.StatusCode, gin.H{
|
||
"error": string(responseBody),
|
||
})
|
||
return
|
||
}
|
||
// 从Content-Type头获取MIME类型
|
||
contentType := resp.Header.Get("Content-Type")
|
||
if contentType == "" {
|
||
// 如果无法确定内容类型,则默认为jpeg
|
||
contentType = "image/jpeg"
|
||
}
|
||
// 设置响应的内容类型
|
||
c.Writer.Header().Set("Content-Type", contentType)
|
||
// 将图片流式传输到响应体
|
||
_, err = io.Copy(c.Writer, resp.Body)
|
||
if err != nil {
|
||
log.Println("Failed to stream image:", err)
|
||
}
|
||
return
|
||
}
|
||
|
||
func RelayMidjourneyNotify(c *gin.Context) *dto.MidjourneyResponse {
|
||
var midjRequest dto.MidjourneyDto
|
||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "bind_request_body_failed",
|
||
Properties: nil,
|
||
Result: "",
|
||
}
|
||
}
|
||
midjourneyTask := model.GetByOnlyMJId(midjRequest.MjId)
|
||
if midjourneyTask == nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "midjourney_task_not_found",
|
||
Properties: nil,
|
||
Result: "",
|
||
}
|
||
}
|
||
midjourneyTask.Progress = midjRequest.Progress
|
||
midjourneyTask.PromptEn = midjRequest.PromptEn
|
||
midjourneyTask.State = midjRequest.State
|
||
midjourneyTask.SubmitTime = midjRequest.SubmitTime
|
||
midjourneyTask.StartTime = midjRequest.StartTime
|
||
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||
midjourneyTask.VideoUrl = midjRequest.VideoUrl
|
||
videoUrlsStr, _ := json.Marshal(midjRequest.VideoUrls)
|
||
midjourneyTask.VideoUrls = string(videoUrlsStr)
|
||
midjourneyTask.Status = midjRequest.Status
|
||
midjourneyTask.FailReason = midjRequest.FailReason
|
||
err = midjourneyTask.Update()
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "update_midjourney_task_failed",
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
func coverMidjourneyTaskDto(c *gin.Context, originTask *model.Midjourney) (midjourneyTask dto.MidjourneyDto) {
|
||
midjourneyTask.MjId = originTask.MjId
|
||
midjourneyTask.Progress = originTask.Progress
|
||
midjourneyTask.PromptEn = originTask.PromptEn
|
||
midjourneyTask.State = originTask.State
|
||
midjourneyTask.SubmitTime = originTask.SubmitTime
|
||
midjourneyTask.StartTime = originTask.StartTime
|
||
midjourneyTask.FinishTime = originTask.FinishTime
|
||
midjourneyTask.ImageUrl = ""
|
||
if originTask.ImageUrl != "" && setting.MjForwardUrlEnabled {
|
||
midjourneyTask.ImageUrl = setting.ServerAddress + "/mj/image/" + originTask.MjId
|
||
if originTask.Status != "SUCCESS" {
|
||
midjourneyTask.ImageUrl += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
|
||
}
|
||
} else {
|
||
midjourneyTask.ImageUrl = originTask.ImageUrl
|
||
}
|
||
if originTask.VideoUrl != "" {
|
||
midjourneyTask.VideoUrl = originTask.VideoUrl
|
||
}
|
||
midjourneyTask.Status = originTask.Status
|
||
midjourneyTask.FailReason = originTask.FailReason
|
||
midjourneyTask.Action = originTask.Action
|
||
midjourneyTask.Description = originTask.Description
|
||
midjourneyTask.Prompt = originTask.Prompt
|
||
if originTask.Buttons != "" {
|
||
var buttons []dto.ActionButton
|
||
err := json.Unmarshal([]byte(originTask.Buttons), &buttons)
|
||
if err == nil {
|
||
midjourneyTask.Buttons = buttons
|
||
}
|
||
}
|
||
if originTask.VideoUrls != "" {
|
||
var videoUrls []dto.ImgUrls
|
||
err := json.Unmarshal([]byte(originTask.VideoUrls), &videoUrls)
|
||
if err == nil {
|
||
midjourneyTask.VideoUrls = videoUrls
|
||
}
|
||
}
|
||
if originTask.Properties != "" {
|
||
var properties dto.Properties
|
||
err := json.Unmarshal([]byte(originTask.Properties), &properties)
|
||
if err == nil {
|
||
midjourneyTask.Properties = &properties
|
||
}
|
||
}
|
||
return
|
||
}
|
||
|
||
func RelaySwapFace(c *gin.Context, info *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
||
var swapFaceRequest dto.SwapFaceRequest
|
||
err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||
}
|
||
if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "sour_base64_and_target_base64_is_required")
|
||
}
|
||
modelName := service.CoverActionToModelName(constant.MjActionSwapFace)
|
||
|
||
priceData := helper.ModelPriceHelperPerCall(c, info)
|
||
|
||
userQuota, err := model.GetUserQuota(info.UserId, false)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: err.Error(),
|
||
}
|
||
}
|
||
|
||
if userQuota-priceData.Quota < 0 {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "quota_not_enough",
|
||
}
|
||
}
|
||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||
baseURL := c.GetString("base_url")
|
||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||
mjResp, _, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||
if err != nil {
|
||
return &mjResp.Response
|
||
}
|
||
defer func() {
|
||
if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
|
||
err := service.PostConsumeQuota(info, priceData.Quota, 0, true)
|
||
if err != nil {
|
||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||
}
|
||
|
||
tokenName := c.GetString("token_name")
|
||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, constant.MjActionSwapFace)
|
||
other := service.GenerateMjOtherInfo(priceData)
|
||
model.RecordConsumeLog(c, info.UserId, model.RecordConsumeLogParams{
|
||
ChannelId: info.ChannelId,
|
||
ModelName: modelName,
|
||
TokenName: tokenName,
|
||
Quota: priceData.Quota,
|
||
Content: logContent,
|
||
TokenId: info.TokenId,
|
||
Group: info.UsingGroup,
|
||
Other: other,
|
||
})
|
||
model.UpdateUserUsedQuotaAndRequestCount(info.UserId, priceData.Quota)
|
||
model.UpdateChannelUsedQuota(info.ChannelId, priceData.Quota)
|
||
}
|
||
}()
|
||
midjResponse := &mjResp.Response
|
||
midjourneyTask := &model.Midjourney{
|
||
UserId: info.UserId,
|
||
Code: midjResponse.Code,
|
||
Action: constant.MjActionSwapFace,
|
||
MjId: midjResponse.Result,
|
||
Prompt: "InsightFace",
|
||
PromptEn: "",
|
||
Description: midjResponse.Description,
|
||
State: "",
|
||
SubmitTime: info.StartTime.UnixNano() / int64(time.Millisecond),
|
||
StartTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||
FinishTime: 0,
|
||
ImageUrl: "",
|
||
Status: "",
|
||
Progress: "0%",
|
||
FailReason: "",
|
||
ChannelId: c.GetInt("channel_id"),
|
||
Quota: priceData.Quota,
|
||
}
|
||
err = midjourneyTask.Insert()
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "insert_midjourney_task_failed")
|
||
}
|
||
c.Writer.WriteHeader(mjResp.StatusCode)
|
||
respBody, err := json.Marshal(midjResponse)
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||
}
|
||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "copy_response_body_failed")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func RelayMidjourneyTaskImageSeed(c *gin.Context) *dto.MidjourneyResponse {
|
||
taskId := c.Param("id")
|
||
userId := c.GetInt("id")
|
||
originTask := model.GetByMJId(userId, taskId)
|
||
if originTask == nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_no_found")
|
||
}
|
||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||
}
|
||
if channel.Status != common.ChannelStatusEnabled {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||
}
|
||
c.Set("channel_id", originTask.ChannelId)
|
||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||
|
||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||
fullRequestURL := fmt.Sprintf("%s%s", channel.GetBaseURL(), requestURL)
|
||
midjResponseWithStatus, _, err := service.DoMidjourneyHttpRequest(c, time.Second*30, fullRequestURL)
|
||
if err != nil {
|
||
return &midjResponseWithStatus.Response
|
||
}
|
||
midjResponse := &midjResponseWithStatus.Response
|
||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||
respBody, err := json.Marshal(midjResponse)
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "unmarshal_response_body_failed")
|
||
}
|
||
service.IOCopyBytesGracefully(c, nil, respBody)
|
||
return nil
|
||
}
|
||
|
||
func RelayMidjourneyTask(c *gin.Context, relayMode int) *dto.MidjourneyResponse {
|
||
userId := c.GetInt("id")
|
||
var err error
|
||
var respBody []byte
|
||
switch relayMode {
|
||
case relayconstant.RelayModeMidjourneyTaskFetch:
|
||
taskId := c.Param("id")
|
||
originTask := model.GetByMJId(userId, taskId)
|
||
if originTask == nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "task_no_found",
|
||
}
|
||
}
|
||
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
|
||
respBody, err = json.Marshal(midjourneyTask)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "unmarshal_response_body_failed",
|
||
}
|
||
}
|
||
case relayconstant.RelayModeMidjourneyTaskFetchByCondition:
|
||
var condition = struct {
|
||
IDs []string `json:"ids"`
|
||
}{}
|
||
err = c.BindJSON(&condition)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "do_request_failed",
|
||
}
|
||
}
|
||
var tasks []dto.MidjourneyDto
|
||
if len(condition.IDs) != 0 {
|
||
originTasks := model.GetByMJIds(userId, condition.IDs)
|
||
for _, originTask := range originTasks {
|
||
midjourneyTask := coverMidjourneyTaskDto(c, originTask)
|
||
tasks = append(tasks, midjourneyTask)
|
||
}
|
||
}
|
||
if tasks == nil {
|
||
tasks = make([]dto.MidjourneyDto, 0)
|
||
}
|
||
respBody, err = json.Marshal(tasks)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "unmarshal_response_body_failed",
|
||
}
|
||
}
|
||
}
|
||
|
||
c.Writer.Header().Set("Content-Type", "application/json")
|
||
|
||
_, err = io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "copy_response_body_failed",
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func RelayMidjourneySubmit(c *gin.Context, relayInfo *relaycommon.RelayInfo) *dto.MidjourneyResponse {
|
||
consumeQuota := true
|
||
var midjRequest dto.MidjourneyRequest
|
||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "bind_request_body_failed")
|
||
}
|
||
|
||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyAction { // midjourney plus,需要从customId中获取任务信息
|
||
mjErr := service.CoverPlusActionToNormalAction(&midjRequest)
|
||
if mjErr != nil {
|
||
return mjErr
|
||
}
|
||
relayInfo.RelayMode = relayconstant.RelayModeMidjourneyChange
|
||
}
|
||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
||
midjRequest.Action = constant.MjActionVideo
|
||
}
|
||
|
||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyImagine { //绘画任务,此类任务可重复
|
||
if midjRequest.Prompt == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "prompt_is_required")
|
||
}
|
||
midjRequest.Action = constant.MjActionImagine
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyDescribe { //按图生文任务,此类任务可重复
|
||
midjRequest.Action = constant.MjActionDescribe
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyEdits { //编辑任务,此类任务可重复
|
||
midjRequest.Action = constant.MjActionEdits
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyShorten { //缩短任务,此类任务可重复,plus only
|
||
midjRequest.Action = constant.MjActionShorten
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyBlend { //绘画任务,此类任务可重复
|
||
midjRequest.Action = constant.MjActionBlend
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyUpload { //绘画任务,此类任务可重复
|
||
midjRequest.Action = constant.MjActionUpload
|
||
} else if midjRequest.TaskId != "" { //放大、变换任务,此类任务,如果重复且已有结果,远端api会直接返回最终结果
|
||
mjId := ""
|
||
if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyChange {
|
||
if midjRequest.TaskId == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||
} else if midjRequest.Action == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
|
||
} else if midjRequest.Index == 0 {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "index_is_required")
|
||
}
|
||
//action = midjRequest.Action
|
||
mjId = midjRequest.TaskId
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneySimpleChange {
|
||
if midjRequest.Content == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_is_required")
|
||
}
|
||
params := service.ConvertSimpleChangeParams(midjRequest.Content)
|
||
if params == nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "content_parse_failed")
|
||
}
|
||
mjId = params.TaskId
|
||
midjRequest.Action = params.Action
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyModal {
|
||
//if midjRequest.MaskBase64 == "" {
|
||
// return service.MidjourneyErrorWrapper(constant.MjRequestError, "mask_base64_is_required")
|
||
//}
|
||
mjId = midjRequest.TaskId
|
||
midjRequest.Action = constant.MjActionModal
|
||
} else if relayInfo.RelayMode == relayconstant.RelayModeMidjourneyVideo {
|
||
midjRequest.Action = constant.MjActionVideo
|
||
if midjRequest.TaskId == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_id_is_required")
|
||
} else if midjRequest.Action == "" {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "action_is_required")
|
||
}
|
||
mjId = midjRequest.TaskId
|
||
}
|
||
|
||
originTask := model.GetByMJId(relayInfo.UserId, mjId)
|
||
if originTask == nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_not_found")
|
||
} else { //原任务的Status=SUCCESS,则可以做放大UPSCALE、变换VARIATION等动作,此时必须使用原来的请求地址才能正确处理
|
||
if setting.MjActionCheckSuccessEnabled {
|
||
if originTask.Status != "SUCCESS" && relayInfo.RelayMode != relayconstant.RelayModeMidjourneyModal {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "task_status_not_success")
|
||
}
|
||
}
|
||
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
||
if err != nil {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "get_channel_info_failed")
|
||
}
|
||
if channel.Status != common.ChannelStatusEnabled {
|
||
return service.MidjourneyErrorWrapper(constant.MjRequestError, "该任务所属渠道已被禁用")
|
||
}
|
||
c.Set("base_url", channel.GetBaseURL())
|
||
c.Set("channel_id", originTask.ChannelId)
|
||
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
||
log.Printf("检测到此操作为放大、变换、重绘,获取原channel信息: %s,%s", strconv.Itoa(originTask.ChannelId), channel.GetBaseURL())
|
||
}
|
||
midjRequest.Prompt = originTask.Prompt
|
||
|
||
//if channelType == common.ChannelTypeMidjourneyPlus {
|
||
// // plus
|
||
//} else {
|
||
// // 普通版渠道
|
||
//
|
||
//}
|
||
}
|
||
|
||
if midjRequest.Action == constant.MjActionInPaint || midjRequest.Action == constant.MjActionCustomZoom {
|
||
consumeQuota = false
|
||
}
|
||
|
||
//baseURL := common.ChannelBaseURLs[channelType]
|
||
requestURL := getMjRequestPath(c.Request.URL.String())
|
||
|
||
baseURL := c.GetString("base_url")
|
||
|
||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||
|
||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||
|
||
modelName := service.CoverActionToModelName(midjRequest.Action)
|
||
|
||
priceData := helper.ModelPriceHelperPerCall(c, relayInfo)
|
||
|
||
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: err.Error(),
|
||
}
|
||
}
|
||
|
||
if consumeQuota && userQuota-priceData.Quota < 0 {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "quota_not_enough",
|
||
}
|
||
}
|
||
|
||
midjResponseWithStatus, responseBody, err := service.DoMidjourneyHttpRequest(c, time.Second*60, fullRequestURL)
|
||
if err != nil {
|
||
return &midjResponseWithStatus.Response
|
||
}
|
||
midjResponse := &midjResponseWithStatus.Response
|
||
|
||
defer func() {
|
||
if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
|
||
err := service.PostConsumeQuota(relayInfo, priceData.Quota, 0, true)
|
||
if err != nil {
|
||
common.SysLog("error consuming token remain quota: " + err.Error())
|
||
}
|
||
tokenName := c.GetString("token_name")
|
||
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s,ID %s", priceData.ModelPrice, priceData.GroupRatioInfo.GroupRatio, midjRequest.Action, midjResponse.Result)
|
||
other := service.GenerateMjOtherInfo(priceData)
|
||
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
||
ChannelId: relayInfo.ChannelId,
|
||
ModelName: modelName,
|
||
TokenName: tokenName,
|
||
Quota: priceData.Quota,
|
||
Content: logContent,
|
||
TokenId: relayInfo.TokenId,
|
||
Group: relayInfo.UsingGroup,
|
||
Other: other,
|
||
})
|
||
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, priceData.Quota)
|
||
model.UpdateChannelUsedQuota(relayInfo.ChannelId, priceData.Quota)
|
||
}
|
||
}()
|
||
|
||
// 文档:https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md
|
||
//1-提交成功
|
||
// 21-任务已存在(处理中或者有结果了) {"code":21,"description":"任务已存在","result":"0741798445574458","properties":{"status":"SUCCESS","imageUrl":"https://xxxx"}}
|
||
// 22-排队中 {"code":22,"description":"排队中,前面还有1个任务","result":"0741798445574458","properties":{"numberOfQueues":1,"discordInstanceId":"1118138338562560102"}}
|
||
// 23-队列已满,请稍后再试 {"code":23,"description":"队列已满,请稍后尝试","result":"14001929738841620","properties":{"discordInstanceId":"1118138338562560102"}}
|
||
// 24-prompt包含敏感词 {"code":24,"description":"可能包含敏感词","properties":{"promptEn":"nude body","bannedWord":"nude"}}
|
||
// other: 提交错误,description为错误描述
|
||
midjourneyTask := &model.Midjourney{
|
||
UserId: relayInfo.UserId,
|
||
Code: midjResponse.Code,
|
||
Action: midjRequest.Action,
|
||
MjId: midjResponse.Result,
|
||
Prompt: midjRequest.Prompt,
|
||
PromptEn: "",
|
||
Description: midjResponse.Description,
|
||
State: "",
|
||
SubmitTime: time.Now().UnixNano() / int64(time.Millisecond),
|
||
StartTime: 0,
|
||
FinishTime: 0,
|
||
ImageUrl: "",
|
||
Status: "",
|
||
Progress: "0%",
|
||
FailReason: "",
|
||
ChannelId: c.GetInt("channel_id"),
|
||
Quota: priceData.Quota,
|
||
}
|
||
if midjResponse.Code == 3 {
|
||
//无实例账号自动禁用渠道(No available account instance)
|
||
channel, err := model.GetChannelById(midjourneyTask.ChannelId, true)
|
||
if err != nil {
|
||
common.SysLog("get_channel_null: " + err.Error())
|
||
}
|
||
if channel.GetAutoBan() && common.AutomaticDisableChannelEnabled {
|
||
model.UpdateChannelStatus(midjourneyTask.ChannelId, "", 2, "No available account instance")
|
||
}
|
||
}
|
||
if midjResponse.Code != 1 && midjResponse.Code != 21 && midjResponse.Code != 22 {
|
||
//非1-提交成功,21-任务已存在和22-排队中,则记录错误原因
|
||
midjourneyTask.FailReason = midjResponse.Description
|
||
consumeQuota = false
|
||
}
|
||
|
||
if midjResponse.Code == 21 { //21-任务已存在(处理中或者有结果了)
|
||
// 将 properties 转换为一个 map
|
||
properties, ok := midjResponse.Properties.(map[string]interface{})
|
||
if ok {
|
||
imageUrl, ok1 := properties["imageUrl"].(string)
|
||
status, ok2 := properties["status"].(string)
|
||
if ok1 && ok2 {
|
||
midjourneyTask.ImageUrl = imageUrl
|
||
midjourneyTask.Status = status
|
||
if status == "SUCCESS" {
|
||
midjourneyTask.Progress = "100%"
|
||
midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||
midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
|
||
midjResponse.Code = 1
|
||
}
|
||
}
|
||
}
|
||
//修改返回值
|
||
if midjRequest.Action != constant.MjActionInPaint && midjRequest.Action != constant.MjActionCustomZoom {
|
||
newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
|
||
responseBody = []byte(newBody)
|
||
}
|
||
}
|
||
if midjResponse.Code == 1 && midjRequest.Action == "UPLOAD" {
|
||
midjourneyTask.Progress = "100%"
|
||
midjourneyTask.Status = "SUCCESS"
|
||
}
|
||
err = midjourneyTask.Insert()
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "insert_midjourney_task_failed",
|
||
}
|
||
}
|
||
|
||
if midjResponse.Code == 22 { //22-排队中,说明任务已存在
|
||
//修改返回值
|
||
newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
|
||
responseBody = []byte(newBody)
|
||
}
|
||
//resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||
bodyReader := io.NopCloser(bytes.NewBuffer(responseBody))
|
||
|
||
//for k, v := range resp.Header {
|
||
// c.Writer.Header().Set(k, v[0])
|
||
//}
|
||
c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
|
||
|
||
_, err = io.Copy(c.Writer, bodyReader)
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "copy_response_body_failed",
|
||
}
|
||
}
|
||
err = bodyReader.Close()
|
||
if err != nil {
|
||
return &dto.MidjourneyResponse{
|
||
Code: 4,
|
||
Description: "close_response_body_failed",
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// taskChangeParams holds a task ID, an action name and an index.
// NOTE(review): nothing in the visible part of this file references this
// type; confirm whether it is still needed.
type taskChangeParams struct {
	ID, Action string
	Index      int
}
|
||
|
||
// getMjRequestPath normalizes a request path (or path plus query string) for
// forwarding to the upstream service.
//
// When path contains "/mj-" (for example "/mj-fast/mj/submit/imagine"),
// everything before the first "/mj/" is dropped, so the result starts with
// "/mj/". A path without "/mj-", or one that contains "/mj-" but no "/mj/",
// is returned unchanged.
func getMjRequestPath(path string) string {
	if !strings.Contains(path, "/mj-") {
		return path
	}
	// Slice at the first "/mj/" only: splitting on every occurrence would
	// silently discard anything after a second "/mj/" later in the path or
	// in the query string.
	idx := strings.Index(path, "/mj/")
	if idx < 0 {
		return path
	}
	return path[idx:]
}
|