refactor(task): extract billing and polling logic from controller to service layer
Restructure the task relay system for better separation of concerns: - Extract task billing into service/task_billing.go with unified settlement flow - Move task polling loop from controller to service/task_polling.go (supports Suno + video platforms) - Split RelayTask into fetch/submit paths with dedicated retry logic (taskSubmitWithRetry) - Add TaskDto, TaskResponse generics, and FetchReq to dto/task.go - Add taskcommon/helpers.go for shared task adaptor utilities - Remove controller/task_video.go (logic consolidated into service layer) - Update all task adaptors (ali, doubao, gemini, hailuo, jimeng, kling, sora, suno, vertex, vidu) - Simplify frontend task logs to use new TaskDto response format
This commit is contained in:
@@ -2,13 +2,12 @@ package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/model"
|
||||
@@ -17,6 +16,7 @@ import (
|
||||
"github.com/QuantumNous/new-api/constant"
|
||||
"github.com/QuantumNous/new-api/dto"
|
||||
"github.com/QuantumNous/new-api/relay/channel"
|
||||
taskcommon "github.com/QuantumNous/new-api/relay/channel/task/taskcommon"
|
||||
vertexcore "github.com/QuantumNous/new-api/relay/channel/vertex"
|
||||
relaycommon "github.com/QuantumNous/new-api/relay/common"
|
||||
"github.com/QuantumNous/new-api/service"
|
||||
@@ -82,7 +82,7 @@ func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycom
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
@@ -116,7 +116,7 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
@@ -184,7 +184,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
// info.PriceData.OtherRatios["durationSeconds"] = float64(v.(int))
|
||||
// }
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
data, err := common.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -205,14 +205,19 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *rela
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
if err := common.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
localID := taskcommon.EncodeLocalTaskID(s.Name)
|
||||
ov := dto.NewOpenAIVideo()
|
||||
ov.ID = info.PublicTaskID
|
||||
ov.TaskID = info.PublicTaskID
|
||||
ov.CreatedAt = time.Now().Unix()
|
||||
ov.Model = info.OriginModelName
|
||||
c.JSON(http.StatusOK, ov)
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
@@ -225,7 +230,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
@@ -245,12 +250,12 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
data, err := common.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
if err := common.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, proxy)
|
||||
@@ -274,7 +279,7 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any, proxy
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
if err := common.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
@@ -338,7 +343,10 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
upstreamName, err := decodeLocalTaskID(task.TaskID)
|
||||
// Use GetUpstreamTaskID() to get the real upstream operation name for model extraction.
|
||||
// task.TaskID is now a public task_xxxx ID, no longer a base64-encoded upstream name.
|
||||
upstreamTaskID := task.GetUpstreamTaskID()
|
||||
upstreamName, err := taskcommon.DecodeLocalTaskID(upstreamTaskID)
|
||||
if err != nil {
|
||||
upstreamName = ""
|
||||
}
|
||||
@@ -353,8 +361,8 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
v.SetProgressStr(task.Progress)
|
||||
v.CreatedAt = task.CreatedAt
|
||||
v.CompletedAt = task.UpdatedAt
|
||||
if strings.HasPrefix(task.FailReason, "data:") && len(task.FailReason) > 0 {
|
||||
v.SetMetadata("url", task.FailReason)
|
||||
if resultURL := task.GetResultURL(); strings.HasPrefix(resultURL, "data:") && len(resultURL) > 0 {
|
||||
v.SetMetadata("url", resultURL)
|
||||
}
|
||||
|
||||
return common.Marshal(v)
|
||||
@@ -364,18 +372,6 @@ func (a *TaskAdaptor) ConvertToOpenAIVideo(task *model.Task) ([]byte, error) {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
|
||||
Reference in New Issue
Block a user