From 30763deb965de7a135c53d85d8bdbfaf1828537d Mon Sep 17 00:00:00 2001 From: Seefs Date: Sat, 13 Sep 2025 16:26:14 +0800 Subject: [PATCH] fix veo3 adapter --- relay/channel/task/vertex/adaptor.go | 85 ++++++++++++++++------------ 1 file changed, 48 insertions(+), 37 deletions(-) diff --git a/relay/channel/task/vertex/adaptor.go b/relay/channel/task/vertex/adaptor.go index d2ab826d..4a236b2f 100644 --- a/relay/channel/task/vertex/adaptor.go +++ b/relay/channel/task/vertex/adaptor.go @@ -7,12 +7,12 @@ import ( "fmt" "io" "net/http" + "one-api/model" "regexp" "strings" "github.com/gin-gonic/gin" - "one-api/common" "one-api/constant" "one-api/dto" "one-api/relay/channel" @@ -21,6 +21,10 @@ import ( "one-api/service" ) +// ============================ +// Request / Response structures +// ============================ + type requestPayload struct { Instances []map[string]any `json:"instances"` Parameters map[string]any `json:"parameters,omitempty"` @@ -52,33 +56,35 @@ type operationResponse struct { } `json:"error"` } -type TaskAdaptor struct{} +// ============================ +// Adaptor implementation +// ============================ -func (a *TaskAdaptor) Init(info *relaycommon.TaskRelayInfo) {} - -func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.TaskRelayInfo) (taskErr *dto.TaskError) { - info.Action = constant.TaskActionTextGenerate - - req := relaycommon.TaskSubmitReq{} - if err := common.UnmarshalBodyReusable(c, &req); err != nil { - return service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest) - } - if strings.TrimSpace(req.Prompt) == "" { - return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest) - } - c.Set("task_request", req) - return nil +type TaskAdaptor struct { + ChannelType int + apiKey string + baseURL string } -func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, error) { +func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) { + a.ChannelType = info.ChannelType + a.baseURL = info.ChannelBaseUrl + a.apiKey = info.ApiKey +} + +// ValidateRequestAndSetAction parses body, validates fields and sets default action. +func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) { + // Use the standard validation method for TaskSubmitReq + return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate) +} + +// BuildRequestURL constructs the upstream URL. +func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) { adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return "", fmt.Errorf("failed to decode credentials: %w", err) } modelName := info.OriginModelName - if v, ok := getRequestModelFromContext(info); ok { - modelName = v - } if modelName == "" { modelName = "veo-3.0-generate-001" } @@ -103,16 +109,17 @@ func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.TaskRelayInfo) (string, ), nil } -func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.TaskRelayInfo) error { +// BuildRequestHeader sets required headers. +func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Accept", "application/json") adc := &vertexcore.Credentials{} - if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil { + if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil { return fmt.Errorf("failed to decode credentials: %w", err) } - token, err := vertexcore.AcquireAccessToken(*adc, info.ChannelSetting.Proxy) + token, err := vertexcore.AcquireAccessToken(*adc, "") if err != nil { return fmt.Errorf("failed to acquire access token: %w", err) } @@ -121,7 +128,8 @@ func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info return nil } -func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayInfo) (io.Reader, error) { +// BuildRequestBody converts request into Vertex specific format. +func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) { v, ok := c.Get("task_request") if !ok { return nil, fmt.Errorf("request not found in context") @@ -151,11 +159,13 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.TaskRelayI return bytes.NewReader(data), nil } -func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.TaskRelayInfo, requestBody io.Reader) (*http.Response, error) { +// DoRequest delegates to common helper. +func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) { return channel.DoTaskApiRequest(a, c, info, requestBody) } -func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relaycommon.TaskRelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { +// DoResponse handles upstream response, returns taskID etc. +func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) { responseBody, err := io.ReadAll(resp.Body) if err != nil { return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError) @@ -177,6 +187,7 @@ func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, _ *relayco func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} } func (a *TaskAdaptor) GetChannelName() string { return "vertex" } +// FetchTask fetch task status func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) { taskID, ok := body["task_id"].(string) if !ok { @@ -191,15 +202,15 @@ func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http region = "us-central1" } project := extractProjectFromOperationName(upstreamName) - model := extractModelFromOperationName(upstreamName) - if project == "" || model == "" { + modelName := extractModelFromOperationName(upstreamName) + if project == "" || modelName == "" { return nil, fmt.Errorf("cannot extract project/model from operation name") } var url string if region == "global" { - url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, model) + url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName) } else { - url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, model) + url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName) } payload := map[string]string{"operationName": upstreamName} data, err := json.Marshal(payload) @@ -232,17 +243,17 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e } ti := &relaycommon.TaskInfo{} if op.Error.Message != "" { - ti.Status = "FAILURE" + ti.Status = model.TaskStatusFailure ti.Reason = op.Error.Message ti.Progress = "100%" return ti, nil } if !op.Done { - ti.Status = "IN_PROGRESS" + ti.Status = model.TaskStatusInProgress ti.Progress = "50%" return ti, nil } - ti.Status = "SUCCESS" + ti.Status = model.TaskStatusSuccess ti.Progress = "100%" if len(op.Response.Videos) > 0 { v0 := op.Response.Videos[0] @@ -290,9 +301,9 @@ func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, e return ti, nil } -func getRequestModelFromContext(info *relaycommon.TaskRelayInfo) (string, bool) { - return info.OriginModelName, info.OriginModelName != "" -} +// ============================ +// helpers +// ============================ func encodeLocalTaskID(name string) string { return base64.RawURLEncoding.EncodeToString([]byte(name))