feat: support amazon nova
This commit is contained in:
@@ -60,7 +60,16 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
if request == nil {
|
if request == nil {
|
||||||
return nil, errors.New("request is nil")
|
return nil, errors.New("request is nil")
|
||||||
}
|
}
|
||||||
|
// 检查是否为Nova模型
|
||||||
|
if isNovaModel(request.Model) {
|
||||||
|
novaReq := convertToNovaRequest(request)
|
||||||
|
c.Set("request_model", request.Model)
|
||||||
|
c.Set("converted_request", novaReq)
|
||||||
|
c.Set("is_nova_model", true)
|
||||||
|
return novaReq, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的Claude模型处理逻辑
|
||||||
var claudeReq *dto.ClaudeRequest
|
var claudeReq *dto.ClaudeRequest
|
||||||
var err error
|
var err error
|
||||||
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
claudeReq, err = claude.RequestOpenAI2ClaudeMessage(c, *request)
|
||||||
@@ -69,6 +78,7 @@ func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayIn
|
|||||||
}
|
}
|
||||||
c.Set("request_model", claudeReq.Model)
|
c.Set("request_model", claudeReq.Model)
|
||||||
c.Set("converted_request", claudeReq)
|
c.Set("converted_request", claudeReq)
|
||||||
|
c.Set("is_nova_model", false)
|
||||||
return claudeReq, err
|
return claudeReq, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
|
import "strings"
|
||||||
|
|
||||||
var awsModelIDMap = map[string]string{
|
var awsModelIDMap = map[string]string{
|
||||||
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
"claude-instant-1.2": "anthropic.claude-instant-v1",
|
||||||
"claude-2.0": "anthropic.claude-v2",
|
"claude-2.0": "anthropic.claude-v2",
|
||||||
@@ -14,6 +16,10 @@ var awsModelIDMap = map[string]string{
|
|||||||
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
"claude-sonnet-4-20250514": "anthropic.claude-sonnet-4-20250514-v1:0",
|
||||||
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
"claude-opus-4-20250514": "anthropic.claude-opus-4-20250514-v1:0",
|
||||||
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
"claude-opus-4-1-20250805": "anthropic.claude-opus-4-1-20250805-v1:0",
|
||||||
|
// Nova models
|
||||||
|
"amazon.nova-micro-v1:0": "us.amazon.nova-micro-v1:0",
|
||||||
|
"amazon.nova-lite-v1:0": "us.amazon.nova-lite-v1:0",
|
||||||
|
"amazon.nova-pro-v1:0": "us.amazon.nova-pro-v1:0",
|
||||||
}
|
}
|
||||||
|
|
||||||
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
var awsModelCanCrossRegionMap = map[string]map[string]bool{
|
||||||
@@ -67,3 +73,8 @@ var awsRegionCrossModelPrefixMap = map[string]string{
|
|||||||
}
|
}
|
||||||
|
|
||||||
var ChannelName = "aws"
|
var ChannelName = "aws"
|
||||||
|
|
||||||
|
// 判断是否为Nova模型
|
||||||
|
func isNovaModel(modelId string) bool {
|
||||||
|
return strings.HasPrefix(modelId, "amazon.nova-")
|
||||||
|
}
|
||||||
|
|||||||
@@ -34,3 +34,56 @@ func copyRequest(req *dto.ClaudeRequest) *AwsClaudeRequest {
|
|||||||
Thinking: req.Thinking,
|
Thinking: req.Thinking,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Nova模型使用messages-v1格式
|
||||||
|
type NovaMessage struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content []NovaContent `json:"content"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NovaContent struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NovaRequest struct {
|
||||||
|
SchemaVersion string `json:"schemaVersion"`
|
||||||
|
Messages []NovaMessage `json:"messages"`
|
||||||
|
InferenceConfig NovaInferenceConfig `json:"inferenceConfig,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type NovaInferenceConfig struct {
|
||||||
|
MaxTokens int `json:"maxTokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"topP,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换OpenAI请求为Nova格式
|
||||||
|
func convertToNovaRequest(req *dto.GeneralOpenAIRequest) *NovaRequest {
|
||||||
|
novaMessages := make([]NovaMessage, len(req.Messages))
|
||||||
|
for i, msg := range req.Messages {
|
||||||
|
novaMessages[i] = NovaMessage{
|
||||||
|
Role: msg.Role,
|
||||||
|
Content: []NovaContent{{Text: msg.StringContent()}},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
novaReq := &NovaRequest{
|
||||||
|
SchemaVersion: "messages-v1",
|
||||||
|
Messages: novaMessages,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置推理配置
|
||||||
|
if req.MaxTokens != 0 || (req.Temperature != nil && *req.Temperature != 0) || req.TopP != 0 {
|
||||||
|
if req.MaxTokens != 0 {
|
||||||
|
novaReq.InferenceConfig.MaxTokens = int(req.MaxTokens)
|
||||||
|
}
|
||||||
|
if req.Temperature != nil && *req.Temperature != 0 {
|
||||||
|
novaReq.InferenceConfig.Temperature = *req.Temperature
|
||||||
|
}
|
||||||
|
if req.TopP != 0 {
|
||||||
|
novaReq.InferenceConfig.TopP = req.TopP
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return novaReq
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package aws
|
package aws
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"one-api/common"
|
"one-api/common"
|
||||||
@@ -93,7 +94,13 @@ func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*
|
|||||||
}
|
}
|
||||||
|
|
||||||
awsModelId := awsModelID(c.GetString("request_model"))
|
awsModelId := awsModelID(c.GetString("request_model"))
|
||||||
|
// 检查是否为Nova模型
|
||||||
|
isNova, _ := c.Get("is_nova_model")
|
||||||
|
if isNova == true {
|
||||||
|
return handleNovaRequest(c, awsCli, info, awsModelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 原有的Claude处理逻辑
|
||||||
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
|
||||||
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
|
||||||
if canCrossRegion {
|
if canCrossRegion {
|
||||||
@@ -209,3 +216,74 @@ func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.Rel
|
|||||||
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
|
||||||
return nil, claudeInfo.Usage
|
return nil, claudeInfo.Usage
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Nova模型处理函数
|
||||||
|
func handleNovaRequest(c *gin.Context, awsCli *bedrockruntime.Client, info *relaycommon.RelayInfo, awsModelId string) (*types.NewAPIError, *dto.Usage) {
|
||||||
|
novaReq_, ok := c.Get("converted_request")
|
||||||
|
if !ok {
|
||||||
|
return types.NewError(errors.New("nova request not found"), types.ErrorCodeInvalidRequest), nil
|
||||||
|
}
|
||||||
|
novaReq := novaReq_.(*NovaRequest)
|
||||||
|
|
||||||
|
// 使用InvokeModel API,但使用Nova格式的请求体
|
||||||
|
awsReq := &bedrockruntime.InvokeModelInput{
|
||||||
|
ModelId: aws.String(awsModelId),
|
||||||
|
Accept: aws.String("application/json"),
|
||||||
|
ContentType: aws.String("application/json"),
|
||||||
|
}
|
||||||
|
|
||||||
|
reqBody, err := json.Marshal(novaReq)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(errors.Wrap(err, "marshal nova request"), types.ErrorCodeBadResponseBody), nil
|
||||||
|
}
|
||||||
|
awsReq.Body = reqBody
|
||||||
|
|
||||||
|
awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
|
||||||
|
if err != nil {
|
||||||
|
return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析Nova响应
|
||||||
|
var novaResp struct {
|
||||||
|
Output struct {
|
||||||
|
Message struct {
|
||||||
|
Content []struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
} `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
} `json:"output"`
|
||||||
|
Usage struct {
|
||||||
|
InputTokens int `json:"inputTokens"`
|
||||||
|
OutputTokens int `json:"outputTokens"`
|
||||||
|
TotalTokens int `json:"totalTokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(awsResp.Body, &novaResp); err != nil {
|
||||||
|
return types.NewError(errors.Wrap(err, "unmarshal nova response"), types.ErrorCodeBadResponseBody), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构造OpenAI格式响应
|
||||||
|
response := dto.OpenAITextResponse{
|
||||||
|
Id: helper.GetResponseID(c),
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: common.GetTimestamp(),
|
||||||
|
Model: info.UpstreamModelName,
|
||||||
|
Choices: []dto.OpenAITextResponseChoice{{
|
||||||
|
Index: 0,
|
||||||
|
Message: dto.Message{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: novaResp.Output.Message.Content[0].Text,
|
||||||
|
},
|
||||||
|
FinishReason: "stop",
|
||||||
|
}},
|
||||||
|
Usage: dto.Usage{
|
||||||
|
PromptTokens: novaResp.Usage.InputTokens,
|
||||||
|
CompletionTokens: novaResp.Usage.OutputTokens,
|
||||||
|
TotalTokens: novaResp.Usage.TotalTokens,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(http.StatusOK, response)
|
||||||
|
return nil, &response.Usage
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user