This commit refactors the logging mechanism across the application by replacing direct logger calls with a centralized logging approach using the `common` package.

Key changes include:
- Replaced instances of `logger.SysLog` and `logger.FatalLog` with `common.SysLog` and `common.FatalLog` for consistent logging practices.
- Updated resource initialization error handling to utilize the new logging structure, enhancing maintainability and readability.
- Minor adjustments to improve code clarity and organization throughout various modules.

This change aims to streamline logging and improve the overall architecture of the codebase.
299 lines
8.7 KiB
Go
299 lines
8.7 KiB
Go
package relay
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"one-api/common"
|
|
"one-api/constant"
|
|
"one-api/dto"
|
|
"one-api/model"
|
|
relaycommon "one-api/relay/common"
|
|
relayconstant "one-api/relay/constant"
|
|
"one-api/service"
|
|
"one-api/setting/ratio_setting"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
/*
|
|
Task 任务通过平台、Action 区分任务
|
|
*/
|
|
func RelayTaskSubmit(c *gin.Context, relayMode int) (taskErr *dto.TaskError) {
|
|
platform := constant.TaskPlatform(c.GetString("platform"))
|
|
if platform == "" {
|
|
platform = GetTaskPlatform(c)
|
|
}
|
|
|
|
relayInfo, err := relaycommon.GenTaskRelayInfo(c)
|
|
if err != nil {
|
|
return service.TaskErrorWrapper(err, "gen_relay_info_failed", http.StatusInternalServerError)
|
|
}
|
|
|
|
adaptor := GetTaskAdaptor(platform)
|
|
if adaptor == nil {
|
|
return service.TaskErrorWrapperLocal(fmt.Errorf("invalid api platform: %s", platform), "invalid_api_platform", http.StatusBadRequest)
|
|
}
|
|
adaptor.Init(relayInfo)
|
|
// get & validate taskRequest 获取并验证文本请求
|
|
taskErr = adaptor.ValidateRequestAndSetAction(c, relayInfo)
|
|
if taskErr != nil {
|
|
return
|
|
}
|
|
|
|
modelName := relayInfo.OriginModelName
|
|
if modelName == "" {
|
|
modelName = service.CoverTaskActionToModelName(platform, relayInfo.Action)
|
|
}
|
|
modelPrice, success := ratio_setting.GetModelPrice(modelName, true)
|
|
if !success {
|
|
defaultPrice, ok := ratio_setting.GetDefaultModelRatioMap()[modelName]
|
|
if !ok {
|
|
modelPrice = 0.1
|
|
} else {
|
|
modelPrice = defaultPrice
|
|
}
|
|
}
|
|
|
|
// 预扣
|
|
groupRatio := ratio_setting.GetGroupRatio(relayInfo.UsingGroup)
|
|
var ratio float64
|
|
userGroupRatio, hasUserGroupRatio := ratio_setting.GetGroupGroupRatio(relayInfo.UserGroup, relayInfo.UsingGroup)
|
|
if hasUserGroupRatio {
|
|
ratio = modelPrice * userGroupRatio
|
|
} else {
|
|
ratio = modelPrice * groupRatio
|
|
}
|
|
userQuota, err := model.GetUserQuota(relayInfo.UserId, false)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
quota := int(ratio * common.QuotaPerUnit)
|
|
if userQuota-quota < 0 {
|
|
taskErr = service.TaskErrorWrapperLocal(errors.New("user quota is not enough"), "quota_not_enough", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
if relayInfo.OriginTaskID != "" {
|
|
originTask, exist, err := model.GetByTaskId(relayInfo.UserId, relayInfo.OriginTaskID)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "get_origin_task_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if !exist {
|
|
taskErr = service.TaskErrorWrapperLocal(errors.New("task_origin_not_exist"), "task_not_exist", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if originTask.ChannelId != relayInfo.ChannelId {
|
|
channel, err := model.GetChannelById(originTask.ChannelId, true)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapperLocal(err, "channel_not_found", http.StatusBadRequest)
|
|
return
|
|
}
|
|
if channel.Status != common.ChannelStatusEnabled {
|
|
return service.TaskErrorWrapperLocal(errors.New("该任务所属渠道已被禁用"), "task_channel_disable", http.StatusBadRequest)
|
|
}
|
|
c.Set("base_url", channel.GetBaseURL())
|
|
c.Set("channel_id", originTask.ChannelId)
|
|
c.Request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", channel.Key))
|
|
|
|
relayInfo.ChannelBaseUrl = channel.GetBaseURL()
|
|
relayInfo.ChannelId = originTask.ChannelId
|
|
}
|
|
}
|
|
|
|
// build body
|
|
requestBody, err := adaptor.BuildRequestBody(c, relayInfo)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "build_request_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
// do request
|
|
resp, err := adaptor.DoRequest(c, relayInfo, requestBody)
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "do_request_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
// handle response
|
|
if resp != nil && resp.StatusCode != http.StatusOK {
|
|
responseBody, _ := io.ReadAll(resp.Body)
|
|
taskErr = service.TaskErrorWrapper(fmt.Errorf(string(responseBody)), "fail_to_fetch_task", resp.StatusCode)
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
// release quota
|
|
if relayInfo.ConsumeQuota && taskErr == nil {
|
|
|
|
err := service.PostConsumeQuota(relayInfo.RelayInfo, quota, 0, true)
|
|
if err != nil {
|
|
common.SysLog("error consuming token remain quota: " + err.Error())
|
|
}
|
|
if quota != 0 {
|
|
tokenName := c.GetString("token_name")
|
|
gRatio := groupRatio
|
|
if hasUserGroupRatio {
|
|
gRatio = userGroupRatio
|
|
}
|
|
logContent := fmt.Sprintf("模型固定价格 %.2f,分组倍率 %.2f,操作 %s", modelPrice, gRatio, relayInfo.Action)
|
|
other := make(map[string]interface{})
|
|
other["model_price"] = modelPrice
|
|
other["group_ratio"] = groupRatio
|
|
if hasUserGroupRatio {
|
|
other["user_group_ratio"] = userGroupRatio
|
|
}
|
|
model.RecordConsumeLog(c, relayInfo.UserId, model.RecordConsumeLogParams{
|
|
ChannelId: relayInfo.ChannelId,
|
|
ModelName: modelName,
|
|
TokenName: tokenName,
|
|
Quota: quota,
|
|
Content: logContent,
|
|
TokenId: relayInfo.TokenId,
|
|
Group: relayInfo.UsingGroup,
|
|
Other: other,
|
|
})
|
|
model.UpdateUserUsedQuotaAndRequestCount(relayInfo.UserId, quota)
|
|
model.UpdateChannelUsedQuota(relayInfo.ChannelId, quota)
|
|
}
|
|
}
|
|
}()
|
|
|
|
taskID, taskData, taskErr := adaptor.DoResponse(c, resp, relayInfo)
|
|
if taskErr != nil {
|
|
return
|
|
}
|
|
relayInfo.ConsumeQuota = true
|
|
// insert task
|
|
task := model.InitTask(platform, relayInfo)
|
|
task.TaskID = taskID
|
|
task.Quota = quota
|
|
task.Data = taskData
|
|
task.Action = relayInfo.Action
|
|
err = task.Insert()
|
|
if err != nil {
|
|
taskErr = service.TaskErrorWrapper(err, "insert_task_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
return nil
|
|
}
|
|
|
|
var fetchRespBuilders = map[int]func(c *gin.Context) (respBody []byte, taskResp *dto.TaskError){
|
|
relayconstant.RelayModeSunoFetchByID: sunoFetchByIDRespBodyBuilder,
|
|
relayconstant.RelayModeSunoFetch: sunoFetchRespBodyBuilder,
|
|
relayconstant.RelayModeVideoFetchByID: videoFetchByIDRespBodyBuilder,
|
|
}
|
|
|
|
func RelayTaskFetch(c *gin.Context, relayMode int) (taskResp *dto.TaskError) {
|
|
respBuilder, ok := fetchRespBuilders[relayMode]
|
|
if !ok {
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New("invalid_relay_mode"), "invalid_relay_mode", http.StatusBadRequest)
|
|
}
|
|
|
|
respBody, taskErr := respBuilder(c)
|
|
if taskErr != nil {
|
|
return taskErr
|
|
}
|
|
|
|
c.Writer.Header().Set("Content-Type", "application/json")
|
|
_, err := io.Copy(c.Writer, bytes.NewBuffer(respBody))
|
|
if err != nil {
|
|
taskResp = service.TaskErrorWrapper(err, "copy_response_body_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
return
|
|
}
|
|
|
|
func sunoFetchRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
|
userId := c.GetInt("id")
|
|
var condition = struct {
|
|
IDs []any `json:"ids"`
|
|
Action string `json:"action"`
|
|
}{}
|
|
err := c.BindJSON(&condition)
|
|
if err != nil {
|
|
taskResp = service.TaskErrorWrapper(err, "invalid_request", http.StatusBadRequest)
|
|
return
|
|
}
|
|
var tasks []any
|
|
if len(condition.IDs) > 0 {
|
|
taskModels, err := model.GetByTaskIds(userId, condition.IDs)
|
|
if err != nil {
|
|
taskResp = service.TaskErrorWrapper(err, "get_tasks_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
for _, task := range taskModels {
|
|
tasks = append(tasks, TaskModel2Dto(task))
|
|
}
|
|
} else {
|
|
tasks = make([]any, 0)
|
|
}
|
|
respBody, err = json.Marshal(dto.TaskResponse[[]any]{
|
|
Code: "success",
|
|
Data: tasks,
|
|
})
|
|
return
|
|
}
|
|
|
|
func sunoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
|
taskId := c.Param("id")
|
|
userId := c.GetInt("id")
|
|
|
|
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
|
if err != nil {
|
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if !exist {
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
|
Code: "success",
|
|
Data: TaskModel2Dto(originTask),
|
|
})
|
|
return
|
|
}
|
|
|
|
func videoFetchByIDRespBodyBuilder(c *gin.Context) (respBody []byte, taskResp *dto.TaskError) {
|
|
taskId := c.Param("task_id")
|
|
if taskId == "" {
|
|
taskId = c.GetString("task_id")
|
|
}
|
|
userId := c.GetInt("id")
|
|
|
|
originTask, exist, err := model.GetByTaskId(userId, taskId)
|
|
if err != nil {
|
|
taskResp = service.TaskErrorWrapper(err, "get_task_failed", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if !exist {
|
|
taskResp = service.TaskErrorWrapperLocal(errors.New("task_not_exist"), "task_not_exist", http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
respBody, err = json.Marshal(dto.TaskResponse[any]{
|
|
Code: "success",
|
|
Data: TaskModel2Dto(originTask),
|
|
})
|
|
return
|
|
}
|
|
|
|
func TaskModel2Dto(task *model.Task) *dto.TaskDto {
|
|
return &dto.TaskDto{
|
|
TaskID: task.TaskID,
|
|
Action: task.Action,
|
|
Status: string(task.Status),
|
|
FailReason: task.FailReason,
|
|
SubmitTime: task.SubmitTime,
|
|
StartTime: task.StartTime,
|
|
FinishTime: task.FinishTime,
|
|
Progress: task.Progress,
|
|
Data: task.Data,
|
|
}
|
|
}
|