Merge branch 'main' of github.com:danding5/new-api
# Conflicts: # relay/relay_adaptor.go
This commit is contained in:
@@ -264,9 +264,8 @@ func doRequest(c *gin.Context, req *http.Request, info *common.RelayInfo) (*http
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, types.NewError(err, types.ErrorCodeDoRequestFailed, types.ErrOptionWithHideErrMsg("upstream error: do request failed"))
|
||||
}
|
||||
if resp == nil {
|
||||
return nil, errors.New("resp is nil")
|
||||
|
||||
@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
if request == nil {
|
||||
return nil, errors.New("request is nil")
|
||||
}
|
||||
// 检查是否为Nova模型
|
||||
if isNovaModel(request.Model) {
|
||||
novaReq := convertToNovaRequest(request)
|
||||
c.Set("request_model", request.Model)
|
||||
c.Set("converted_request", novaReq)
|
||||
c.Set("is_nova_model", true)
|
||||
return novaReq, nil
|
||||
}
|
||||
|
||||
// 原有的Claude模型处理逻辑
|
||||
var claudeReq *dto.ClaudeRequest
|
||||
var err error
|
||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
c.Set("request_model", claudeReq.Model)
|
||||
c.Set("converted_request", claudeReq)
|
||||
c.Set("is_nova_model", false)
|
||||
return claudeReq, err
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package aws
|
||||
|
||||
import "strings"
|
||||
|
||||
var awsModelIDMap = map[string]string{
|
||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||
"claude-2.0": "anthropic.claude-v2",
|
||||
@@ -14,6 +16,11 @@ var awsModelIDMap = map[string]string{
|
||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||
// Nova models
|
||||
"nova-micro-v1:0": "amazon.nova-micro-v1:0",
|
||||
"nova-lite-v1:0": "amazon.nova-lite-v1:0",
|
||||
"nova-pro-v1:0": "amazon.nova-pro-v1:0",
|
||||
"nova-premier-v1:0": "amazon.nova-premier-v1:0",
|
||||
}
|
||||
|
||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
@@ -58,7 +65,27 @@ var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||
"anthropic.claude-opus-4-1-20250805-v1:0": {
|
||||
"us": true,
|
||||
},
|
||||
}
|
||||
// Nova models - all support three major regions
|
||||
"amazon.nova-micro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-lite-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-pro-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
},
|
||||
"amazon.nova-premier-v1:0": {
|
||||
"us": true,
|
||||
"eu": true,
|
||||
"apac": true,
|
||||
}}
|
||||
|
||||
var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
"us": "us",
|
||||
@@ -67,3 +94,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
|
||||
}
|
||||
|
||||
var ChannelName = "aws"
|
||||
|
||||
// 判断是否为Nova模型
|
||||
func isNovaModel(modelId string) bool {
|
||||
return strings.HasPrefix(modelId, "nova-")
|
||||
}
|
||||
|
||||
@@ -34,3 +34,92 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
||||
Thinking: req.Thinking,
|
||||
}
|
||||
}
|
||||
|
||||
// NovaMessage Nova模型使用messages-v1格式
|
||||
type NovaMessage struct {
|
||||
Role string `json:"role"`
|
||||
Content []NovaContent `json:"content"`
|
||||
}
|
||||
|
||||
type NovaContent struct {
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
type NovaRequest struct {
|
||||
SchemaVersion string `json:"schemaVersion"` // 请求版本,例如 "1.0"
|
||||
Messages []NovaMessage `json:"messages"` // 对话消息列表
|
||||
InferenceConfig *NovaInferenceConfig `json:"inferenceConfig,omitempty"` // 推理配置,可选
|
||||
}
|
||||
|
||||
type NovaInferenceConfig struct {
|
||||
MaxTokens int `json:"maxTokens,omitempty"` // 最大生成的 token 数
|
||||
Temperature float64 `json:"temperature,omitempty"` // 随机性 (默认 0.7, 范围 0-1)
|
||||
TopP float64 `json:"topP,omitempty"` // nucleus sampling (默认 0.9, 范围 0-1)
|
||||
TopK int `json:"topK,omitempty"` // 限制候选 token 数 (默认 50, 范围 0-128)
|
||||
StopSequences []string `json:"stopSequences,omitempty"` // 停止生成的序列
|
||||
}
|
||||
|
||||
// 转换OpenAI请求为Nova格式
|
||||
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
|
||||
novaMessages := make([]NovaMessage, len(req.Messages))
|
||||
for i, msg := range req.Messages {
|
||||
novaMessages[i] = NovaMessage{
|
||||
Role: msg.Role,
|
||||
Content: []NovaContent{{Text: msg.StringContent()}},
|
||||
}
|
||||
}
|
||||
|
||||
novaReq := &NovaRequest{
|
||||
SchemaVersion: "messages-v1",
|
||||
Messages: novaMessages,
|
||||
}
|
||||
|
||||
// 设置推理配置
|
||||
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 || req.TopK != 0 || req.Stop != nil {
|
||||
novaReq.InferenceConfig = &NovaInferenceConfig{}
|
||||
if req.MaxTokens != 0 {
|
||||
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
|
||||
}
|
||||
if req.Temperature != nil && *req.Temperature != 0 {
|
||||
novaReq.InferenceConfig.Temperature = *req.Temperature
|
||||
}
|
||||
if req.TopP != 0 {
|
||||
novaReq.InferenceConfig.TopP = req.TopP
|
||||
}
|
||||
if req.TopK != 0 {
|
||||
novaReq.InferenceConfig.TopK = req.TopK
|
||||
}
|
||||
if req.Stop != nil {
|
||||
if stopSequences := parseStopSequences(req.Stop); len(stopSequences) > 0 {
|
||||
novaReq.InferenceConfig.StopSequences = stopSequences
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return novaReq
|
||||
}
|
||||
|
||||
// parseStopSequences 解析停止序列,支持字符串或字符串数组
|
||||
func parseStopSequences(stop any) []string {
|
||||
if stop == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
switch v := stop.(type) {
|
||||
case string:
|
||||
if v != "" {
|
||||
return []string{v}
|
||||
}
|
||||
case []string:
|
||||
return v
|
||||
case []interface{}:
|
||||
var sequences []string
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok && str != "" {
|
||||
sequences = append(sequences, str)
|
||||
}
|
||||
}
|
||||
return sequences
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
@@ -93,7 +94,19 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
||||
}
|
||||
|
||||
awsModelId := awsModelID(c.GetString("request_model"))
|
||||
// 检查是否为Nova模型
|
||||
isNova, _ := c.Get("is_nova_model")
|
||||
if isNova == true {
|
||||
// Nova模型也支持跨区域
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
|
||||
}
|
||||
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||
}
|
||||
|
||||
// 原有的Claude处理逻辑
|
||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||
if canCrossRegion {
|
||||
@@ -209,3 +222,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||
return nil, claudeInfo.Usage
|
||||
}
|
||||
|
||||
// Nova模型处理函数
|
||||
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||
novaReq_, ok := c.Get("converted_request")
|
||||
if !ok {
|
||||
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||
}
|
||||
novaReq := novaReq_.(*NovaRequest)
|
||||
|
||||
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||
awsReq := &bedrockruntime.InvokeModelInput{
|
||||
ModelId: aws.String(awsModelId),
|
||||
Accept: aws.String("application/json"),
|
||||
ContentType: aws.String("application/json"),
|
||||
}
|
||||
|
||||
reqBody, err := json.Marshal(novaReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
awsReq.Body = reqBody
|
||||
|
||||
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||
if err != nil {
|
||||
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||
}
|
||||
|
||||
// 解析Nova响应
|
||||
var novaResp struct {
|
||||
Output struct {
|
||||
Message struct {
|
||||
Content []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"content"`
|
||||
} `json:"message"`
|
||||
} `json:"output"`
|
||||
Usage struct {
|
||||
InputTokens int `json:"inputTokens"`
|
||||
OutputTokens int `json:"outputTokens"`
|
||||
TotalTokens int `json:"totalTokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
|
||||
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
|
||||
}
|
||||
|
||||
// 构造OpenAI格式响应
|
||||
response := dto.OpenAITextResponse{
|
||||
Id: helper.GetResponseID(c),
|
||||
Object: "chat.completion",
|
||||
Created: common.GetTimestamp(),
|
||||
Model: info.UpstreamModelName,
|
||||
Choices: []dto.OpenAITextResponseChoice{{
|
||||
Index: 0,
|
||||
Message: dto.Message{
|
||||
Role: "assistant",
|
||||
Content: novaResp.Output.Message.Content[0].Text,
|
||||
},
|
||||
FinishReason: "stop",
|
||||
}},
|
||||
Usage: dto.Usage{
|
||||
PromptTokens: novaResp.Usage.InputTokens,
|
||||
CompletionTokens: novaResp.Usage.OutputTokens,
|
||||
TotalTokens: novaResp.Usage.TotalTokens,
|
||||
},
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, response)
|
||||
return nil, &response.Usage
|
||||
}
|
||||
|
||||
@@ -46,32 +46,6 @@ func GeminiTextGenerationHandler(c *gin.Context, info *relaycommon.RelayInfo, re
|
||||
|
||||
usage.CompletionTokenDetails.ReasoningTokens = geminiResponse.UsageMetadata.ThoughtsTokenCount
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
imageOutputCounts := 0
|
||||
for _, candidate := range geminiResponse.Candidates {
|
||||
for _, part := range candidate.Content.Parts {
|
||||
if part.InlineData != nil && strings.HasPrefix(part.InlineData.MimeType, "image/") {
|
||||
imageOutputCounts++
|
||||
}
|
||||
}
|
||||
}
|
||||
if imageOutputCounts != 0 {
|
||||
usage.CompletionTokens = usage.CompletionTokens - imageOutputCounts*1290
|
||||
usage.TotalTokens = usage.TotalTokens - imageOutputCounts*1290
|
||||
c.Set("gemini_image_tokens", imageOutputCounts*1290)
|
||||
}
|
||||
}
|
||||
|
||||
// if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
// for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
// if detail.Modality == "IMAGE" {
|
||||
// usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
// usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
// c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
for _, detail := range geminiResponse.UsageMetadata.PromptTokensDetails {
|
||||
if detail.Modality == "AUDIO" {
|
||||
usage.PromptTokensDetails.AudioTokens = detail.TokenCount
|
||||
@@ -162,16 +136,6 @@ func GeminiTextGenerationStreamHandler(c *gin.Context, info *relaycommon.RelayIn
|
||||
usage.PromptTokensDetails.TextTokens = detail.TokenCount
|
||||
}
|
||||
}
|
||||
|
||||
if strings.HasPrefix(info.UpstreamModelName, "gemini-2.5-flash-image-preview") {
|
||||
for _, detail := range geminiResponse.UsageMetadata.CandidatesTokensDetails {
|
||||
if detail.Modality == "IMAGE" {
|
||||
usage.CompletionTokens = usage.CompletionTokens - detail.TokenCount
|
||||
usage.TotalTokens = usage.TotalTokens - detail.TokenCount
|
||||
c.Set("gemini_image_tokens", detail.TokenCount)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 直接发送 GeminiChatResponse 响应
|
||||
|
||||
@@ -17,15 +17,15 @@ type Adaptor struct {
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertClaudeRequest(*gin.Context, *relaycommon.RelayInfo, *dto.ClaudeRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("submodel channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("submodel channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("submodel channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
@@ -49,15 +49,15 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("submodel channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("submodel channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
return nil, errors.New("submodel channel: endpoint not supported")
|
||||
}
|
||||
|
||||
func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) {
|
||||
@@ -79,4 +79,4 @@ func (a *Adaptor) GetModelList() []string {
|
||||
|
||||
func (a *Adaptor) GetChannelName() string {
|
||||
return ChannelName
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -89,22 +88,7 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
req := relaycommon.TaskSubmitReq{}
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -334,11 +318,11 @@ func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*
|
||||
}
|
||||
|
||||
// Handle one-of image_urls or binary_data_base64
|
||||
if req.Image != "" {
|
||||
if strings.HasPrefix(req.Image, "http") {
|
||||
r.ImageUrls = []string{req.Image}
|
||||
if req.HasImage() {
|
||||
if strings.HasPrefix(req.Images[0], "http") {
|
||||
r.ImageUrls = req.Images
|
||||
} else {
|
||||
r.BinaryDataBase64 = []string{req.Image}
|
||||
r.BinaryDataBase64 = req.Images
|
||||
}
|
||||
}
|
||||
metadata := req.Metadata
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"one-api/common"
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
@@ -28,16 +27,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type TrajectoryPoint struct {
|
||||
X int `json:"x"`
|
||||
Y int `json:"y"`
|
||||
@@ -121,23 +110,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Accept only POST /v1/video/generations as "generate" action.
|
||||
action := constant.TaskActionGenerate
|
||||
info.Action = action
|
||||
|
||||
var req SubmitReq
|
||||
if err := common.UnmarshalBodyReusable(c, &req); err != nil {
|
||||
taskErr = service.TaskErrorWrapperLocal(err, "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.Prompt) == "" {
|
||||
taskErr = service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "invalid_request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Store into context for later usage
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
@@ -166,7 +140,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayIn
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -255,7 +229,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
r := requestPayload{
|
||||
Prompt: req.Prompt,
|
||||
Image: req.Image,
|
||||
|
||||
355
relay/channel/task/vertex/adaptor.go
Normal file
355
relay/channel/task/vertex/adaptor.go
Normal file
@@ -0,0 +1,355 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/model"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"one-api/constant"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
vertexcore "one-api/relay/channel/vertex"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
)
|
||||
|
||||
// ============================
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type requestPayload struct {
|
||||
Instances []map[string]any `json:"instances"`
|
||||
Parameters map[string]any `json:"parameters,omitempty"`
|
||||
}
|
||||
|
||||
type submitResponse struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type operationVideo struct {
|
||||
MimeType string `json:"mimeType"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
}
|
||||
|
||||
type operationResponse struct {
|
||||
Name string `json:"name"`
|
||||
Done bool `json:"done"`
|
||||
Response struct {
|
||||
Type string `json:"@type"`
|
||||
RaiMediaFilteredCount int `json:"raiMediaFilteredCount"`
|
||||
Videos []operationVideo `json:"videos"`
|
||||
BytesBase64Encoded string `json:"bytesBase64Encoded"`
|
||||
Encoding string `json:"encoding"`
|
||||
Video string `json:"video"`
|
||||
} `json:"response"`
|
||||
Error struct {
|
||||
Message string `json:"message"`
|
||||
} `json:"error"`
|
||||
}
|
||||
|
||||
// ============================
|
||||
// Adaptor implementation
|
||||
// ============================
|
||||
|
||||
type TaskAdaptor struct {
|
||||
ChannelType int
|
||||
apiKey string
|
||||
baseURL string
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
a.ChannelType = info.ChannelType
|
||||
a.baseURL = info.ChannelBaseUrl
|
||||
a.apiKey = info.ApiKey
|
||||
}
|
||||
|
||||
// ValidateRequestAndSetAction parses body, validates fields and sets default action.
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) (taskErr *dto.TaskError) {
|
||||
// Use the standard validation method for TaskSubmitReq
|
||||
return relaycommon.ValidateBasicTaskRequest(c, info, constant.TaskActionTextGenerate)
|
||||
}
|
||||
|
||||
// BuildRequestURL constructs the upstream URL.
|
||||
func (a *TaskAdaptor) BuildRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
modelName := info.OriginModelName
|
||||
if modelName == "" {
|
||||
modelName = "veo-3.0-generate-001"
|
||||
}
|
||||
|
||||
region := vertexcore.GetModelRegion(info.ApiVersion, modelName)
|
||||
if strings.TrimSpace(region) == "" {
|
||||
region = "global"
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:predictLongRunning",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:predictLongRunning",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
), nil
|
||||
}
|
||||
|
||||
// BuildRequestHeader sets required headers.
|
||||
func (a *TaskAdaptor) BuildRequestHeader(c *gin.Context, req *http.Request, info *relaycommon.RelayInfo) error {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(a.apiKey), adc); err != nil {
|
||||
return fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildRequestBody converts request into Vertex specific format.
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, info *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
v, ok := c.Get("task_request")
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body := requestPayload{
|
||||
Instances: []map[string]any{{"prompt": req.Prompt}},
|
||||
Parameters: map[string]any{},
|
||||
}
|
||||
if req.Metadata != nil {
|
||||
if v, ok := req.Metadata["storageUri"]; ok {
|
||||
body.Parameters["storageUri"] = v
|
||||
}
|
||||
if v, ok := req.Metadata["sampleCount"]; ok {
|
||||
body.Parameters["sampleCount"] = v
|
||||
}
|
||||
}
|
||||
if _, ok := body.Parameters["sampleCount"]; !ok {
|
||||
body.Parameters["sampleCount"] = 1
|
||||
}
|
||||
|
||||
data, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return bytes.NewReader(data), nil
|
||||
}
|
||||
|
||||
// DoRequest delegates to common helper.
|
||||
func (a *TaskAdaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (*http.Response, error) {
|
||||
return channel.DoTaskApiRequest(a, c, info, requestBody)
|
||||
}
|
||||
|
||||
// DoResponse handles upstream response, returns taskID etc.
|
||||
func (a *TaskAdaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (taskID string, taskData []byte, taskErr *dto.TaskError) {
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "read_response_body_failed", http.StatusInternalServerError)
|
||||
}
|
||||
_ = resp.Body.Close()
|
||||
|
||||
var s submitResponse
|
||||
if err := json.Unmarshal(responseBody, &s); err != nil {
|
||||
return "", nil, service.TaskErrorWrapper(err, "unmarshal_response_failed", http.StatusInternalServerError)
|
||||
}
|
||||
if strings.TrimSpace(s.Name) == "" {
|
||||
return "", nil, service.TaskErrorWrapper(fmt.Errorf("missing operation name"), "invalid_response", http.StatusInternalServerError)
|
||||
}
|
||||
localID := encodeLocalTaskID(s.Name)
|
||||
c.JSON(http.StatusOK, gin.H{"task_id": localID})
|
||||
return localID, responseBody, nil
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) GetModelList() []string { return []string{"veo-3.0-generate-001"} }
|
||||
func (a *TaskAdaptor) GetChannelName() string { return "vertex" }
|
||||
|
||||
// FetchTask fetch task status
|
||||
func (a *TaskAdaptor) FetchTask(baseUrl, key string, body map[string]any) (*http.Response, error) {
|
||||
taskID, ok := body["task_id"].(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("invalid task_id")
|
||||
}
|
||||
upstreamName, err := decodeLocalTaskID(taskID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode task_id failed: %w", err)
|
||||
}
|
||||
region := extractRegionFromOperationName(upstreamName)
|
||||
if region == "" {
|
||||
region = "us-central1"
|
||||
}
|
||||
project := extractProjectFromOperationName(upstreamName)
|
||||
modelName := extractModelFromOperationName(upstreamName)
|
||||
if project == "" || modelName == "" {
|
||||
return nil, fmt.Errorf("cannot extract project/model from operation name")
|
||||
}
|
||||
var url string
|
||||
if region == "global" {
|
||||
url = fmt.Sprintf("https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:fetchPredictOperation", project, modelName)
|
||||
} else {
|
||||
url = fmt.Sprintf("https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:fetchPredictOperation", region, project, region, modelName)
|
||||
}
|
||||
payload := map[string]string{"operationName": upstreamName}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
adc := &vertexcore.Credentials{}
|
||||
if err := json.Unmarshal([]byte(key), adc); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode credentials: %w", err)
|
||||
}
|
||||
token, err := vertexcore.AcquireAccessToken(*adc, "")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to acquire access token: %w", err)
|
||||
}
|
||||
req, err := http.NewRequest(http.MethodPost, url, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("x-goog-user-project", adc.ProjectID)
|
||||
return service.GetHttpClient().Do(req)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ParseTaskResult(respBody []byte) (*relaycommon.TaskInfo, error) {
|
||||
var op operationResponse
|
||||
if err := json.Unmarshal(respBody, &op); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal operation response failed: %w", err)
|
||||
}
|
||||
ti := &relaycommon.TaskInfo{}
|
||||
if op.Error.Message != "" {
|
||||
ti.Status = model.TaskStatusFailure
|
||||
ti.Reason = op.Error.Message
|
||||
ti.Progress = "100%"
|
||||
return ti, nil
|
||||
}
|
||||
if !op.Done {
|
||||
ti.Status = model.TaskStatusInProgress
|
||||
ti.Progress = "50%"
|
||||
return ti, nil
|
||||
}
|
||||
ti.Status = model.TaskStatusSuccess
|
||||
ti.Progress = "100%"
|
||||
if len(op.Response.Videos) > 0 {
|
||||
v0 := op.Response.Videos[0]
|
||||
if v0.BytesBase64Encoded != "" {
|
||||
mime := strings.TrimSpace(v0.MimeType)
|
||||
if mime == "" {
|
||||
enc := strings.TrimSpace(v0.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
if strings.Contains(enc, "/") {
|
||||
mime = enc
|
||||
} else {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + v0.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
}
|
||||
if op.Response.BytesBase64Encoded != "" {
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.BytesBase64Encoded
|
||||
return ti, nil
|
||||
}
|
||||
if op.Response.Video != "" { // some variants use `video` as base64
|
||||
enc := strings.TrimSpace(op.Response.Encoding)
|
||||
if enc == "" {
|
||||
enc = "mp4"
|
||||
}
|
||||
mime := enc
|
||||
if !strings.Contains(enc, "/") {
|
||||
mime = "video/" + enc
|
||||
}
|
||||
ti.Url = "data:" + mime + ";base64," + op.Response.Video
|
||||
return ti, nil
|
||||
}
|
||||
return ti, nil
|
||||
}
|
||||
|
||||
// ============================
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func encodeLocalTaskID(name string) string {
|
||||
return base64.RawURLEncoding.EncodeToString([]byte(name))
|
||||
}
|
||||
|
||||
func decodeLocalTaskID(local string) (string, error) {
|
||||
b, err := base64.RawURLEncoding.DecodeString(local)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
var regionRe = regexp.MustCompile(`locations/([a-z0-9-]+)/`)
|
||||
|
||||
func extractRegionFromOperationName(name string) string {
|
||||
m := regionRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var modelRe = regexp.MustCompile(`models/([^/]+)/operations/`)
|
||||
|
||||
func extractModelFromOperationName(name string) string {
|
||||
m := modelRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
idx := strings.Index(name, "models/")
|
||||
if idx >= 0 {
|
||||
s := name[idx+len("models/"):]
|
||||
if p := strings.Index(s, "/operations/"); p > 0 {
|
||||
return s[:p]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
var projectRe = regexp.MustCompile(`projects/([^/]+)/locations/`)
|
||||
|
||||
func extractProjectFromOperationName(name string) string {
|
||||
m := projectRe.FindStringSubmatch(name)
|
||||
if len(m) == 2 {
|
||||
return m[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -23,16 +23,6 @@ import (
|
||||
// Request / Response structures
|
||||
// ============================
|
||||
|
||||
type SubmitReq struct {
|
||||
Prompt string `json:"prompt"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Mode string `json:"mode,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Duration int `json:"duration,omitempty"`
|
||||
Metadata map[string]interface{} `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
type requestPayload struct {
|
||||
Model string `json:"model"`
|
||||
Images []string `json:"images"`
|
||||
@@ -90,23 +80,8 @@ func (a *TaskAdaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) ValidateRequestAndSetAction(c *gin.Context, info *relaycommon.RelayInfo) *dto.TaskError {
|
||||
var req SubmitReq
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
return service.TaskErrorWrapper(err, "invalid_request_body", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Prompt == "" {
|
||||
return service.TaskErrorWrapperLocal(fmt.Errorf("prompt is required"), "missing_prompt", http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if req.Image != "" {
|
||||
info.Action = constant.TaskActionGenerate
|
||||
} else {
|
||||
info.Action = constant.TaskActionTextGenerate
|
||||
}
|
||||
|
||||
c.Set("task_request", req)
|
||||
return nil
|
||||
// Use the unified validation method for TaskSubmitReq with image-based action determination
|
||||
return relaycommon.ValidateTaskRequestWithImageBinding(c, info)
|
||||
}
|
||||
|
||||
func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo) (io.Reader, error) {
|
||||
@@ -114,7 +89,7 @@ func (a *TaskAdaptor) BuildRequestBody(c *gin.Context, _ *relaycommon.RelayInfo)
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("request not found in context")
|
||||
}
|
||||
req := v.(SubmitReq)
|
||||
req := v.(relaycommon.TaskSubmitReq)
|
||||
|
||||
body, err := a.convertToRequestPayload(&req)
|
||||
if err != nil {
|
||||
@@ -211,7 +186,7 @@ func (a *TaskAdaptor) GetChannelName() string {
|
||||
// helpers
|
||||
// ============================
|
||||
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *SubmitReq) (*requestPayload, error) {
|
||||
func (a *TaskAdaptor) convertToRequestPayload(req *relaycommon.TaskSubmitReq) (*requestPayload, error) {
|
||||
var images []string
|
||||
if req.Image != "" {
|
||||
images = []string{req.Image}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/dto"
|
||||
"one-api/relay/channel"
|
||||
"one-api/relay/channel/claude"
|
||||
@@ -80,16 +81,64 @@ func (a *Adaptor) Init(info *relaycommon.RelayInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
adc := &Credentials{}
|
||||
if err := json.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
func (a *Adaptor) getRequestUrl(info *relaycommon.RelayInfo, modelName, suffix string) (string, error) {
|
||||
region := GetModelRegion(info.ApiVersion, info.OriginModelName)
|
||||
a.AccountCredentials = *adc
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
adc := &Credentials{}
|
||||
if err := common.Unmarshal([]byte(info.ApiKey), adc); err != nil {
|
||||
return "", fmt.Errorf("failed to decode credentials file: %w", err)
|
||||
}
|
||||
a.AccountCredentials = *adc
|
||||
|
||||
if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
} else {
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/publishers/google/models/%s:%s?key=%s",
|
||||
region,
|
||||
modelName,
|
||||
suffix,
|
||||
info.ApiKey,
|
||||
), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
suffix := ""
|
||||
if a.RequestMode == RequestModeGemini {
|
||||
|
||||
if model_setting.GetGeminiSettings().ThinkingAdapterEnabled {
|
||||
// 新增逻辑:处理 -thinking-<budget> 格式
|
||||
if strings.Contains(info.UpstreamModelName, "-thinking-") {
|
||||
@@ -111,24 +160,7 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if strings.HasPrefix(info.UpstreamModelName, "imagen") {
|
||||
suffix = "predict"
|
||||
}
|
||||
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/google/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
info.UpstreamModelName,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, info.UpstreamModelName, suffix)
|
||||
} else if a.RequestMode == RequestModeClaude {
|
||||
if info.IsStream {
|
||||
suffix = "streamRawPredict?alt=sse"
|
||||
@@ -139,41 +171,25 @@ func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) {
|
||||
if v, ok := claudeModelMap[info.UpstreamModelName]; ok {
|
||||
model = v
|
||||
}
|
||||
if region == "global" {
|
||||
return fmt.Sprintf(
|
||||
"https://aiplatform.googleapis.com/v1/projects/%s/locations/global/publishers/anthropic/models/%s:%s",
|
||||
adc.ProjectID,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
} else {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1/projects/%s/locations/%s/publishers/anthropic/models/%s:%s",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
model,
|
||||
suffix,
|
||||
), nil
|
||||
}
|
||||
return a.getRequestUrl(info, model, suffix)
|
||||
} else if a.RequestMode == RequestModeLlama {
|
||||
return fmt.Sprintf(
|
||||
"https://%s-aiplatform.googleapis.com/v1beta1/projects/%s/locations/%s/endpoints/openapi/chat/completions",
|
||||
region,
|
||||
adc.ProjectID,
|
||||
region,
|
||||
), nil
|
||||
return a.getRequestUrl(info, "", "")
|
||||
}
|
||||
return "", errors.New("unsupported request mode")
|
||||
}
|
||||
|
||||
func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error {
|
||||
channel.SetupApiRequestHeader(info, c, req)
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
if info.ChannelOtherSettings.VertexKeyType != dto.VertexKeyTypeAPIKey {
|
||||
accessToken, err := getAccessToken(a, info)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
}
|
||||
if a.AccountCredentials.ProjectID != "" {
|
||||
req.Set("x-goog-user-project", a.AccountCredentials.ProjectID)
|
||||
}
|
||||
req.Set("Authorization", "Bearer "+accessToken)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -12,7 +12,10 @@ func GetModelRegion(other string, localModelName string) string {
|
||||
if m[localModelName] != nil {
|
||||
return m[localModelName].(string)
|
||||
} else {
|
||||
return m["default"].(string)
|
||||
if v, ok := m["default"]; ok {
|
||||
return v.(string)
|
||||
}
|
||||
return "global"
|
||||
}
|
||||
}
|
||||
return other
|
||||
|
||||
@@ -6,14 +6,15 @@ import (
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
relaycommon "one-api/relay/common"
|
||||
"one-api/service"
|
||||
"strings"
|
||||
|
||||
"github.com/bytedance/gopkg/cache/asynccache"
|
||||
"github.com/golang-jwt/jwt"
|
||||
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
@@ -137,3 +138,45 @@ func exchangeJwtForAccessToken(signedJWT string, info *relaycommon.RelayInfo) (s
|
||||
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
func AcquireAccessToken(creds Credentials, proxy string) (string, error) {
|
||||
signedJWT, err := createSignedJWT(creds.ClientEmail, creds.PrivateKey)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create signed JWT: %w", err)
|
||||
}
|
||||
return exchangeJwtForAccessTokenWithProxy(signedJWT, proxy)
|
||||
}
|
||||
|
||||
func exchangeJwtForAccessTokenWithProxy(signedJWT string, proxy string) (string, error) {
|
||||
authURL := "https://www.googleapis.com/oauth2/v4/token"
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer")
|
||||
data.Set("assertion", signedJWT)
|
||||
|
||||
var client *http.Client
|
||||
var err error
|
||||
if proxy != "" {
|
||||
client, err = service.NewProxyHttpClient(proxy)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("new proxy http client failed: %w", err)
|
||||
}
|
||||
} else {
|
||||
client = service.GetHttpClient()
|
||||
}
|
||||
|
||||
resp, err := client.PostForm(authURL, data)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if accessToken, ok := result["access_token"].(string); ok {
|
||||
return accessToken, nil
|
||||
}
|
||||
return "", fmt.Errorf("failed to get access token: %v", result)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user