refactor: Introduce pre-consume quota and unify relay handlers
This commit introduces a major architectural refactoring to improve quota management, centralize logging, and streamline the relay handling logic. Key changes: - **Pre-consume Quota:** Implements a new mechanism to check and reserve user quota *before* making the request to the upstream provider. This ensures more accurate quota deduction and prevents users from exceeding their limits due to concurrent requests. - **Unified Relay Handlers:** Refactors the relay logic to use generic handlers (e.g., `ChatHandler`, `ImageHandler`) instead of provider-specific implementations. This significantly reduces code duplication and simplifies adding new channels. - **Centralized Logger:** A new dedicated `logger` package is introduced, and all system logging calls are migrated to use it, moving this responsibility out of the `common` package. - **Code Reorganization:** DTOs are generalized (e.g., `dalle.go` -> `openai_image.go`) and utility code is moved to more appropriate packages (e.g., `common/http.go` -> `service/http.go`) for better code structure.
This commit is contained in:
@@ -10,6 +10,7 @@ import (
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/logger"
|
||||
"one-api/model"
|
||||
"one-api/relay"
|
||||
"sort"
|
||||
@@ -25,7 +26,7 @@ func UpdateTaskBulk() {
|
||||
//imageModel := "midjourney"
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
common.SysLog("任务进度轮询开始")
|
||||
logger.SysLog("任务进度轮询开始")
|
||||
ctx := context.TODO()
|
||||
allTasks := model.GetAllUnFinishSyncTasks(500)
|
||||
platformTask := make(map[constant.TaskPlatform][]*model.Task)
|
||||
@@ -54,9 +55,9 @@ func UpdateTaskBulk() {
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
logger.LogError(ctx, fmt.Sprintf("Fix null task_id task error: %v", err))
|
||||
} else {
|
||||
common.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("Fix null task_id task success: %v", nullTaskIds))
|
||||
}
|
||||
}
|
||||
if len(taskChannelM) == 0 {
|
||||
@@ -65,7 +66,7 @@ func UpdateTaskBulk() {
|
||||
|
||||
UpdateTaskByPlatform(platform, taskChannelM, taskM)
|
||||
}
|
||||
common.SysLog("任务进度轮询完成")
|
||||
logger.SysLog("任务进度轮询完成")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,7 +78,7 @@ func UpdateTaskByPlatform(platform constant.TaskPlatform, taskChannelM map[int][
|
||||
_ = UpdateSunoTaskAll(context.Background(), taskChannelM, taskM)
|
||||
default:
|
||||
if err := UpdateVideoTaskAll(context.Background(), platform, taskChannelM, taskM); err != nil {
|
||||
common.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
logger.SysLog(fmt.Sprintf("UpdateVideoTaskAll fail: %s", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,27 +87,27 @@ func UpdateSunoTaskAll(ctx context.Context, taskChannelM map[int][]string, taskM
|
||||
for channelId, taskIds := range taskChannelM {
|
||||
err := updateSunoTaskAll(ctx, channelId, taskIds, taskM)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||
logger.LogError(ctx, fmt.Sprintf("渠道 #%d 更新异步任务失败: %d", channelId, err.Error()))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, taskM map[string]*model.Task) error {
|
||||
common.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
logger.LogInfo(ctx, fmt.Sprintf("渠道 #%d 未完成的任务有: %d", channelId, len(taskIds)))
|
||||
if len(taskIds) == 0 {
|
||||
return nil
|
||||
}
|
||||
channel, err := model.CacheGetChannel(channelId)
|
||||
if err != nil {
|
||||
common.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
logger.SysLog(fmt.Sprintf("CacheGetChannel: %v", err))
|
||||
err = model.TaskBulkUpdate(taskIds, map[string]any{
|
||||
"fail_reason": fmt.Sprintf("获取渠道信息失败,请联系管理员,渠道ID:%d", channelId),
|
||||
"status": "FAILURE",
|
||||
"progress": "100%",
|
||||
})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
logger.SysError(fmt.Sprintf("UpdateMidjourneyTask error2: %v", err))
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -118,27 +119,27 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
"ids": taskIds,
|
||||
})
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
logger.SysError(fmt.Sprintf("Get Task Do req error: %v", err))
|
||||
return err
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
return errors.New(fmt.Sprintf("Get Task status code: %d", resp.StatusCode))
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
common.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
logger.SysError(fmt.Sprintf("Get Task parse body error: %v", err))
|
||||
return err
|
||||
}
|
||||
var responseItems dto.TaskResponse[[]dto.SunoDataResponse]
|
||||
err = json.Unmarshal(responseBody, &responseItems)
|
||||
if err != nil {
|
||||
common.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
logger.LogError(ctx, fmt.Sprintf("Get Task parse body error2: %v, body: %s", err, string(responseBody)))
|
||||
return err
|
||||
}
|
||||
if !responseItems.IsSuccess() {
|
||||
common.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
||||
logger.SysLog(fmt.Sprintf("渠道 #%d 未完成的任务有: %d, 成功获取到任务数: %d", channelId, len(taskIds), string(responseBody)))
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -154,19 +155,19 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
task.StartTime = lo.If(responseItem.StartTime != 0, responseItem.StartTime).Else(task.StartTime)
|
||||
task.FinishTime = lo.If(responseItem.FinishTime != 0, responseItem.FinishTime).Else(task.FinishTime)
|
||||
if responseItem.FailReason != "" || task.Status == model.TaskStatusFailure {
|
||||
common.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
logger.LogInfo(ctx, task.TaskID+" 构建失败,"+task.FailReason)
|
||||
task.Progress = "100%"
|
||||
//err = model.CacheUpdateUserQuota(task.UserId) ?
|
||||
if err != nil {
|
||||
common.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
logger.LogError(ctx, "error update user quota cache: "+err.Error())
|
||||
} else {
|
||||
quota := task.Quota
|
||||
if quota != 0 {
|
||||
err = model.IncreaseUserQuota(task.UserId, quota, false)
|
||||
if err != nil {
|
||||
common.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
logger.LogError(ctx, "fail to increase user quota: "+err.Error())
|
||||
}
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, common.LogQuota(quota))
|
||||
logContent := fmt.Sprintf("异步任务执行失败 %s,补偿 %s", task.TaskID, logger.LogQuota(quota))
|
||||
model.RecordLog(task.UserId, model.LogTypeSystem, logContent)
|
||||
}
|
||||
}
|
||||
@@ -178,7 +179,7 @@ func updateSunoTaskAll(ctx context.Context, channelId int, taskIds []string, tas
|
||||
|
||||
err = task.Update()
|
||||
if err != nil {
|
||||
common.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
||||
logger.SysError("UpdateMidjourneyTask task error: " + err.Error())
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user