refactor: replace DeepCopy with Copy for request handling consistency

This commit is contained in:
CaIon
2025-08-28 14:57:47 +08:00
parent 7e698f658a
commit c21219fcff
12 changed files with 58 additions and 50 deletions

View File

@@ -6,14 +6,21 @@ import (
"github.com/jinzhu/copier" "github.com/jinzhu/copier"
) )
func DeepCopy[T any](src *T) (*T, error) { func Copy[T any](src *T, deepCopy bool) (*T, error) {
if src == nil { if src == nil {
return nil, fmt.Errorf("copy source cannot be nil") return nil, fmt.Errorf("copy source cannot be nil")
} }
var dst T var dst T
if deepCopy {
err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true}) err := copier.CopyWithOption(&dst, src, copier.Option{DeepCopy: true, IgnoreEmpty: true})
if err != nil { if err != nil {
return nil, err return nil, err
} }
} else {
err := copier.Copy(&dst, src)
if err != nil {
return nil, err
}
}
return &dst, nil return &dst, nil
} }

View File

@@ -2,11 +2,12 @@ package dto
import ( import (
"encoding/json" "encoding/json"
"github.com/gin-gonic/gin"
"one-api/common" "one-api/common"
"one-api/logger" "one-api/logger"
"one-api/types" "one-api/types"
"strings" "strings"
"github.com/gin-gonic/gin"
) )
type GeminiChatRequest struct { type GeminiChatRequest struct {

View File

@@ -265,7 +265,7 @@ type Message struct {
Reasoning string `json:"reasoning,omitempty"` Reasoning string `json:"reasoning,omitempty"`
ToolCalls json.RawMessage `json:"tool_calls,omitempty"` ToolCalls json.RawMessage `json:"tool_calls,omitempty"`
ToolCallId string `json:"tool_call_id,omitempty"` ToolCallId string `json:"tool_call_id,omitempty"`
parsedContent []MediaContent parsedContent *[]MediaContent
//parsedStringContent *string //parsedStringContent *string
} }
@@ -441,7 +441,7 @@ func (m *Message) SetStringContent(content string) {
func (m *Message) SetMediaContent(content []MediaContent) { func (m *Message) SetMediaContent(content []MediaContent) {
m.Content = content m.Content = content
m.parsedContent = content m.parsedContent = &content
} }
func (m *Message) IsStringContent() bool { func (m *Message) IsStringContent() bool {
@@ -456,8 +456,8 @@ func (m *Message) ParseContent() []MediaContent {
if m.Content == nil { if m.Content == nil {
return nil return nil
} }
if len(m.parsedContent) > 0 { if m.parsedContent != nil && len(*m.parsedContent) > 0 {
return m.parsedContent return *m.parsedContent
} }
var contentList []MediaContent var contentList []MediaContent
@@ -468,7 +468,7 @@ func (m *Message) ParseContent() []MediaContent {
Type: ContentTypeText, Type: ContentTypeText,
Text: content, Text: content,
}} }}
m.parsedContent = contentList m.parsedContent = &contentList
return contentList return contentList
} }
@@ -580,7 +580,7 @@ func (m *Message) ParseContent() []MediaContent {
} }
if len(contentList) > 0 { if len(contentList) > 0 {
m.parsedContent = contentList m.parsedContent = &contentList
} }
return contentList return contentList
} }
@@ -767,7 +767,7 @@ type WebSearchOptions struct {
// https://platform.openai.com/docs/api-reference/responses/create // https://platform.openai.com/docs/api-reference/responses/create
type OpenAIResponsesRequest struct { type OpenAIResponsesRequest struct {
Model string `json:"model"` Model string `json:"model"`
Input json.RawMessage `json:"input,omitempty"` Input *json.RawMessage `json:"input,omitempty"`
Include json.RawMessage `json:"include,omitempty"` Include json.RawMessage `json:"include,omitempty"`
Instructions json.RawMessage `json:"instructions,omitempty"` Instructions json.RawMessage `json:"instructions,omitempty"`
MaxOutputTokens uint `json:"max_output_tokens,omitempty"` MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
@@ -781,7 +781,7 @@ type OpenAIResponsesRequest struct {
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
Text json.RawMessage `json:"text,omitempty"` Text json.RawMessage `json:"text,omitempty"`
ToolChoice json.RawMessage `json:"tool_choice,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"`
Tools json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少MCP 参数太多不确定,所以用 map Tools *json.RawMessage `json:"tools,omitempty"` // 需要处理的参数很少MCP 参数太多不确定,所以用 map
TopP float64 `json:"top_p,omitempty"` TopP float64 `json:"top_p,omitempty"`
Truncation string `json:"truncation,omitempty"` Truncation string `json:"truncation,omitempty"`
User string `json:"user,omitempty"` User string `json:"user,omitempty"`
@@ -837,8 +837,8 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
texts = append(texts, string(r.Prompt)) texts = append(texts, string(r.Prompt))
} }
if len(r.Tools) > 0 { if r.Tools != nil && len(*r.Tools) > 0 {
texts = append(texts, string(r.Tools)) texts = append(texts, string(*r.Tools))
} }
return &types.TokenCountMeta{ return &types.TokenCountMeta{
@@ -859,9 +859,9 @@ func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
} }
func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any { func (r *OpenAIResponsesRequest) GetToolsMap() []map[string]any {
var toolsMap []map[string]any var toolsMap = make([]map[string]any, 0)
if len(r.Tools) > 0 { if r.Tools != nil && len(*r.Tools) > 0 {
_ = common.Unmarshal(r.Tools, &toolsMap) _ = common.Unmarshal(*r.Tools, &toolsMap)
} }
return toolsMap return toolsMap
} }
@@ -896,17 +896,17 @@ func (r *OpenAIResponsesRequest) ParseInput() []MediaInput {
// inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) // inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
// return inputs // return inputs
// } // }
if common.GetJsonType(r.Input) == "string" { if common.GetJsonType(*r.Input) == "string" {
var str string var str string
_ = common.Unmarshal(r.Input, &str) _ = common.Unmarshal(*r.Input, &str)
inputs = append(inputs, MediaInput{Type: "input_text", Text: str}) inputs = append(inputs, MediaInput{Type: "input_text", Text: str})
return inputs return inputs
} }
// Try array of parts // Try array of parts
if common.GetJsonType(r.Input) == "array" { if common.GetJsonType(*r.Input) == "array" {
var array []any var array []any
_ = common.Unmarshal(r.Input, &array) _ = common.Unmarshal(*r.Input, &array)
for _, itemAny := range array { for _, itemAny := range array {
// Already parsed MediaInput // Already parsed MediaInput
if media, ok := itemAny.(MediaInput); ok { if media, ok := itemAny.(MediaInput); ok {

View File

@@ -22,7 +22,7 @@ func AudioHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(errors.New("invalid request type"), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(audioReq) request, err := common.Copy(audioReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to AudioRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -27,7 +27,7 @@ func ClaudeHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.ClaudeRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(claudeReq) request, err := common.Copy(claudeReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to ClaudeRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -313,7 +313,7 @@ func GenRelayInfoResponses(c *gin.Context, request *dto.OpenAIResponsesRequest)
info.ResponsesUsageInfo = &ResponsesUsageInfo{ info.ResponsesUsageInfo = &ResponsesUsageInfo{
BuiltInTools: make(map[string]*BuildInToolInfo), BuiltInTools: make(map[string]*BuildInToolInfo),
} }
if len(request.Tools) > 0 { if request.Tools != nil && len(*request.Tools) > 0 {
for _, tool := range request.GetToolsMap() { for _, tool := range request.GetToolsMap() {
toolType := common.Interface2String(tool["type"]) toolType := common.Interface2String(tool["type"])
info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{ info.ResponsesUsageInfo.BuiltInTools[toolType] = &BuildInToolInfo{

View File

@@ -32,7 +32,7 @@ func TextHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *types
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.GeneralOpenAIRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(textReq) request, err := common.Copy(textReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -23,7 +23,7 @@ func EmbeddingHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.EmbeddingRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(embeddingReq) request, err := common.Copy(embeddingReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to EmbeddingRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -58,7 +58,7 @@ func GeminiHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected *dto.GeminiChatRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(geminiReq) request, err := common.Copy(geminiReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to GeminiChatRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -26,7 +26,7 @@ func ImageHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *type
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.ImageRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(imageReq) request, err := common.Copy(imageReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -24,7 +24,7 @@ func RerankHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *typ
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.RerankRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(rerankReq) request, err := common.Copy(rerankReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to ImageRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }

View File

@@ -25,7 +25,7 @@ func ResponsesHelper(c *gin.Context, info *relaycommon.RelayInfo) (newAPIError *
return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry()) return types.NewErrorWithStatusCode(fmt.Errorf("invalid request type, expected dto.OpenAIResponsesRequest, got %T", info.Request), types.ErrorCodeInvalidRequest, http.StatusBadRequest, types.ErrOptionWithSkipRetry())
} }
request, err := common.DeepCopy(responsesReq) request, err := common.Copy(responsesReq, false)
if err != nil { if err != nil {
return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry()) return types.NewError(fmt.Errorf("failed to copy request to GeneralOpenAIRequest: %w", err), types.ErrorCodeInvalidRequest, types.ErrOptionWithSkipRetry())
} }