fix: preserve explicit zero values in native relay requests
This commit is contained in:
@@ -15,7 +15,7 @@ type AudioRequest struct {
|
||||
Voice string `json:"voice"`
|
||||
Instructions string `json:"instructions,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
Speed float64 `json:"speed,omitempty"`
|
||||
Speed *float64 `json:"speed,omitempty"`
|
||||
StreamFormat string `json:"stream_format,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
}
|
||||
|
||||
@@ -197,13 +197,13 @@ type ClaudeRequest struct {
|
||||
// InferenceGeo controls Claude data residency region.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_inference_geo.
|
||||
InferenceGeo string `json:"inference_geo,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample uint `json:"max_tokens_to_sample,omitempty"`
|
||||
MaxTokens *uint `json:"max_tokens,omitempty"`
|
||||
MaxTokensToSample *uint `json:"max_tokens_to_sample,omitempty"`
|
||||
StopSequences []string `json:"stop_sequences,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Tools any `json:"tools,omitempty"`
|
||||
ContextManagement json.RawMessage `json:"context_management,omitempty"`
|
||||
OutputConfig json.RawMessage `json:"output_config,omitempty"`
|
||||
@@ -227,9 +227,13 @@ func createClaudeFileSource(data string) *types.FileSource {
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
maxTokens := 0
|
||||
if c.MaxTokens != nil {
|
||||
maxTokens = int(*c.MaxTokens)
|
||||
}
|
||||
var tokenCountMeta = types.TokenCountMeta{
|
||||
TokenType: types.TokenTypeTokenizer,
|
||||
MaxTokens: int(c.MaxTokens),
|
||||
MaxTokens: maxTokens,
|
||||
}
|
||||
|
||||
var texts = make([]string, 0)
|
||||
@@ -352,7 +356,10 @@ func (c *ClaudeRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) IsStream(ctx *gin.Context) bool {
|
||||
return c.Stream
|
||||
if c.Stream == nil {
|
||||
return false
|
||||
}
|
||||
return *c.Stream
|
||||
}
|
||||
|
||||
func (c *ClaudeRequest) SetModelName(modelName string) {
|
||||
|
||||
@@ -23,13 +23,13 @@ type EmbeddingRequest struct {
|
||||
Model string `json:"model"`
|
||||
Input any `json:"input"`
|
||||
EncodingFormat string `json:"encoding_format,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Seed *float64 `json:"seed,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
}
|
||||
|
||||
func (r *EmbeddingRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
|
||||
@@ -77,8 +77,8 @@ func (r *GeminiChatRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
|
||||
var maxTokens int
|
||||
|
||||
if r.GenerationConfig.MaxOutputTokens > 0 {
|
||||
maxTokens = int(r.GenerationConfig.MaxOutputTokens)
|
||||
if r.GenerationConfig.MaxOutputTokens != nil && *r.GenerationConfig.MaxOutputTokens > 0 {
|
||||
maxTokens = int(*r.GenerationConfig.MaxOutputTokens)
|
||||
}
|
||||
|
||||
var inputTexts []string
|
||||
@@ -325,21 +325,21 @@ type GeminiChatTool struct {
|
||||
|
||||
type GeminiChatGenerationConfig struct {
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount int `json:"candidateCount,omitempty"`
|
||||
TopP *float64 `json:"topP,omitempty"`
|
||||
TopK *float64 `json:"topK,omitempty"`
|
||||
MaxOutputTokens *uint `json:"maxOutputTokens,omitempty"`
|
||||
CandidateCount *int `json:"candidateCount,omitempty"`
|
||||
StopSequences []string `json:"stopSequences,omitempty"`
|
||||
ResponseMimeType string `json:"responseMimeType,omitempty"`
|
||||
ResponseSchema any `json:"responseSchema,omitempty"`
|
||||
ResponseJsonSchema json.RawMessage `json:"responseJsonSchema,omitempty"`
|
||||
PresencePenalty *float32 `json:"presencePenalty,omitempty"`
|
||||
FrequencyPenalty *float32 `json:"frequencyPenalty,omitempty"`
|
||||
ResponseLogprobs bool `json:"responseLogprobs,omitempty"`
|
||||
ResponseLogprobs *bool `json:"responseLogprobs,omitempty"`
|
||||
Logprobs *int32 `json:"logprobs,omitempty"`
|
||||
EnableEnhancedCivicAnswers *bool `json:"enableEnhancedCivicAnswers,omitempty"`
|
||||
MediaResolution MediaResolution `json:"mediaResolution,omitempty"`
|
||||
Seed int64 `json:"seed,omitempty"`
|
||||
Seed *int64 `json:"seed,omitempty"`
|
||||
ResponseModalities []string `json:"responseModalities,omitempty"`
|
||||
ThinkingConfig *GeminiThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
SpeechConfig json.RawMessage `json:"speechConfig,omitempty"` // RawMessage to allow flexible speech config
|
||||
@@ -351,17 +351,17 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
|
||||
type Alias GeminiChatGenerationConfig
|
||||
var aux struct {
|
||||
Alias
|
||||
TopPSnake float64 `json:"top_p,omitempty"`
|
||||
TopKSnake float64 `json:"top_k,omitempty"`
|
||||
MaxOutputTokensSnake uint `json:"max_output_tokens,omitempty"`
|
||||
CandidateCountSnake int `json:"candidate_count,omitempty"`
|
||||
TopPSnake *float64 `json:"top_p,omitempty"`
|
||||
TopKSnake *float64 `json:"top_k,omitempty"`
|
||||
MaxOutputTokensSnake *uint `json:"max_output_tokens,omitempty"`
|
||||
CandidateCountSnake *int `json:"candidate_count,omitempty"`
|
||||
StopSequencesSnake []string `json:"stop_sequences,omitempty"`
|
||||
ResponseMimeTypeSnake string `json:"response_mime_type,omitempty"`
|
||||
ResponseSchemaSnake any `json:"response_schema,omitempty"`
|
||||
ResponseJsonSchemaSnake json.RawMessage `json:"response_json_schema,omitempty"`
|
||||
PresencePenaltySnake *float32 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenaltySnake *float32 `json:"frequency_penalty,omitempty"`
|
||||
ResponseLogprobsSnake bool `json:"response_logprobs,omitempty"`
|
||||
ResponseLogprobsSnake *bool `json:"response_logprobs,omitempty"`
|
||||
EnableEnhancedCivicAnswersSnake *bool `json:"enable_enhanced_civic_answers,omitempty"`
|
||||
MediaResolutionSnake MediaResolution `json:"media_resolution,omitempty"`
|
||||
ResponseModalitiesSnake []string `json:"response_modalities,omitempty"`
|
||||
@@ -377,16 +377,16 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
|
||||
*c = GeminiChatGenerationConfig(aux.Alias)
|
||||
|
||||
// Prioritize snake_case if present
|
||||
if aux.TopPSnake != 0 {
|
||||
if aux.TopPSnake != nil {
|
||||
c.TopP = aux.TopPSnake
|
||||
}
|
||||
if aux.TopKSnake != 0 {
|
||||
if aux.TopKSnake != nil {
|
||||
c.TopK = aux.TopKSnake
|
||||
}
|
||||
if aux.MaxOutputTokensSnake != 0 {
|
||||
if aux.MaxOutputTokensSnake != nil {
|
||||
c.MaxOutputTokens = aux.MaxOutputTokensSnake
|
||||
}
|
||||
if aux.CandidateCountSnake != 0 {
|
||||
if aux.CandidateCountSnake != nil {
|
||||
c.CandidateCount = aux.CandidateCountSnake
|
||||
}
|
||||
if len(aux.StopSequencesSnake) > 0 {
|
||||
@@ -407,7 +407,7 @@ func (c *GeminiChatGenerationConfig) UnmarshalJSON(data []byte) error {
|
||||
if aux.FrequencyPenaltySnake != nil {
|
||||
c.FrequencyPenalty = aux.FrequencyPenaltySnake
|
||||
}
|
||||
if aux.ResponseLogprobsSnake {
|
||||
if aux.ResponseLogprobsSnake != nil {
|
||||
c.ResponseLogprobs = aux.ResponseLogprobsSnake
|
||||
}
|
||||
if aux.EnableEnhancedCivicAnswersSnake != nil {
|
||||
|
||||
89
dto/gemini_generation_config_test.go
Normal file
89
dto/gemini_generation_config_test.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGeminiChatGenerationConfigPreservesExplicitZeroValuesCamelCase(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"contents":[{"role":"user","parts":[{"text":"hello"}]}],
|
||||
"generationConfig":{
|
||||
"topP":0,
|
||||
"topK":0,
|
||||
"maxOutputTokens":0,
|
||||
"candidateCount":0,
|
||||
"seed":0,
|
||||
"responseLogprobs":false
|
||||
}
|
||||
}`)
|
||||
|
||||
var req GeminiChatRequest
|
||||
require.NoError(t, common.Unmarshal(raw, &req))
|
||||
|
||||
encoded, err := common.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out map[string]any
|
||||
require.NoError(t, common.Unmarshal(encoded, &out))
|
||||
|
||||
generationConfig, ok := out["generationConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Contains(t, generationConfig, "topP")
|
||||
assert.Contains(t, generationConfig, "topK")
|
||||
assert.Contains(t, generationConfig, "maxOutputTokens")
|
||||
assert.Contains(t, generationConfig, "candidateCount")
|
||||
assert.Contains(t, generationConfig, "seed")
|
||||
assert.Contains(t, generationConfig, "responseLogprobs")
|
||||
|
||||
assert.Equal(t, float64(0), generationConfig["topP"])
|
||||
assert.Equal(t, float64(0), generationConfig["topK"])
|
||||
assert.Equal(t, float64(0), generationConfig["maxOutputTokens"])
|
||||
assert.Equal(t, float64(0), generationConfig["candidateCount"])
|
||||
assert.Equal(t, float64(0), generationConfig["seed"])
|
||||
assert.Equal(t, false, generationConfig["responseLogprobs"])
|
||||
}
|
||||
|
||||
func TestGeminiChatGenerationConfigPreservesExplicitZeroValuesSnakeCase(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"contents":[{"role":"user","parts":[{"text":"hello"}]}],
|
||||
"generationConfig":{
|
||||
"top_p":0,
|
||||
"top_k":0,
|
||||
"max_output_tokens":0,
|
||||
"candidate_count":0,
|
||||
"seed":0,
|
||||
"response_logprobs":false
|
||||
}
|
||||
}`)
|
||||
|
||||
var req GeminiChatRequest
|
||||
require.NoError(t, common.Unmarshal(raw, &req))
|
||||
|
||||
encoded, err := common.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
var out map[string]any
|
||||
require.NoError(t, common.Unmarshal(encoded, &out))
|
||||
|
||||
generationConfig, ok := out["generationConfig"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
|
||||
assert.Contains(t, generationConfig, "topP")
|
||||
assert.Contains(t, generationConfig, "topK")
|
||||
assert.Contains(t, generationConfig, "maxOutputTokens")
|
||||
assert.Contains(t, generationConfig, "candidateCount")
|
||||
assert.Contains(t, generationConfig, "seed")
|
||||
assert.Contains(t, generationConfig, "responseLogprobs")
|
||||
|
||||
assert.Equal(t, float64(0), generationConfig["topP"])
|
||||
assert.Equal(t, float64(0), generationConfig["topK"])
|
||||
assert.Equal(t, float64(0), generationConfig["maxOutputTokens"])
|
||||
assert.Equal(t, float64(0), generationConfig["candidateCount"])
|
||||
assert.Equal(t, float64(0), generationConfig["seed"])
|
||||
assert.Equal(t, false, generationConfig["responseLogprobs"])
|
||||
}
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
type ImageRequest struct {
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt" binding:"required"`
|
||||
N uint `json:"n,omitempty"`
|
||||
N *uint `json:"n,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Quality string `json:"quality,omitempty"`
|
||||
ResponseFormat string `json:"response_format,omitempty"`
|
||||
@@ -149,10 +149,14 @@ func (i *ImageRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
|
||||
// not support token count for dalle
|
||||
n := uint(1)
|
||||
if i.N != nil {
|
||||
n = *i.N
|
||||
}
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: i.Prompt,
|
||||
MaxTokens: 1584,
|
||||
ImagePriceRatio: sizeRatio * qualityRatio * float64(i.N),
|
||||
ImagePriceRatio: sizeRatio * qualityRatio * float64(n),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/QuantumNous/new-api/types"
|
||||
"github.com/samber/lo"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
@@ -31,26 +32,26 @@ type GeneralOpenAIRequest struct {
|
||||
Prompt any `json:"prompt,omitempty"`
|
||||
Prefix any `json:"prefix,omitempty"`
|
||||
Suffix any `json:"suffix,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
MaxTokens uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens uint `json:"max_completion_tokens,omitempty"`
|
||||
MaxTokens *uint `json:"max_tokens,omitempty"`
|
||||
MaxCompletionTokens *uint `json:"max_completion_tokens,omitempty"`
|
||||
ReasoningEffort string `json:"reasoning_effort,omitempty"`
|
||||
Verbosity json.RawMessage `json:"verbosity,omitempty"` // gpt-5
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"top_p,omitempty"`
|
||||
TopK int `json:"top_k,omitempty"`
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
TopK *int `json:"top_k,omitempty"`
|
||||
Stop any `json:"stop,omitempty"`
|
||||
N int `json:"n,omitempty"`
|
||||
N *int `json:"n,omitempty"`
|
||||
Input any `json:"input,omitempty"`
|
||||
Instruction string `json:"instruction,omitempty"`
|
||||
Size string `json:"size,omitempty"`
|
||||
Functions json.RawMessage `json:"functions,omitempty"`
|
||||
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
|
||||
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
|
||||
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
|
||||
EncodingFormat json.RawMessage `json:"encoding_format,omitempty"`
|
||||
Seed float64 `json:"seed,omitempty"`
|
||||
Seed *float64 `json:"seed,omitempty"`
|
||||
ParallelTooCalls *bool `json:"parallel_tool_calls,omitempty"`
|
||||
Tools []ToolCallRequest `json:"tools,omitempty"`
|
||||
ToolChoice any `json:"tool_choice,omitempty"`
|
||||
@@ -59,9 +60,9 @@ type GeneralOpenAIRequest struct {
|
||||
// ServiceTier specifies upstream service level and may affect billing.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_service_tier.
|
||||
ServiceTier string `json:"service_tier,omitempty"`
|
||||
LogProbs bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs int `json:"top_logprobs,omitempty"`
|
||||
Dimensions int `json:"dimensions,omitempty"`
|
||||
LogProbs *bool `json:"logprobs,omitempty"`
|
||||
TopLogProbs *int `json:"top_logprobs,omitempty"`
|
||||
Dimensions *int `json:"dimensions,omitempty"`
|
||||
Modalities json.RawMessage `json:"modalities,omitempty"`
|
||||
Audio json.RawMessage `json:"audio,omitempty"`
|
||||
// 安全标识符,用于帮助 OpenAI 检测可能违反使用政策的应用程序用户
|
||||
@@ -100,8 +101,8 @@ type GeneralOpenAIRequest struct {
|
||||
// pplx Params
|
||||
SearchDomainFilter json.RawMessage `json:"search_domain_filter,omitempty"`
|
||||
SearchRecencyFilter string `json:"search_recency_filter,omitempty"`
|
||||
ReturnImages bool `json:"return_images,omitempty"`
|
||||
ReturnRelatedQuestions bool `json:"return_related_questions,omitempty"`
|
||||
ReturnImages *bool `json:"return_images,omitempty"`
|
||||
ReturnRelatedQuestions *bool `json:"return_related_questions,omitempty"`
|
||||
SearchMode string `json:"search_mode,omitempty"`
|
||||
// Minimax
|
||||
ReasoningSplit json.RawMessage `json:"reasoning_split,omitempty"`
|
||||
@@ -140,10 +141,12 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
texts = append(texts, inputs...)
|
||||
}
|
||||
|
||||
if r.MaxCompletionTokens > r.MaxTokens {
|
||||
tokenCountMeta.MaxTokens = int(r.MaxCompletionTokens)
|
||||
maxTokens := lo.FromPtrOr(r.MaxTokens, uint(0))
|
||||
maxCompletionTokens := lo.FromPtrOr(r.MaxCompletionTokens, uint(0))
|
||||
if maxCompletionTokens > maxTokens {
|
||||
tokenCountMeta.MaxTokens = int(maxCompletionTokens)
|
||||
} else {
|
||||
tokenCountMeta.MaxTokens = int(r.MaxTokens)
|
||||
tokenCountMeta.MaxTokens = int(maxTokens)
|
||||
}
|
||||
|
||||
for _, message := range r.Messages {
|
||||
@@ -222,7 +225,7 @@ func (r *GeneralOpenAIRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) IsStream(c *gin.Context) bool {
|
||||
return r.Stream
|
||||
return lo.FromPtrOr(r.Stream, false)
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) SetModelName(modelName string) {
|
||||
@@ -273,10 +276,11 @@ type StreamOptions struct {
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) GetMaxTokens() uint {
|
||||
if r.MaxCompletionTokens != 0 {
|
||||
return r.MaxCompletionTokens
|
||||
maxCompletionTokens := lo.FromPtrOr(r.MaxCompletionTokens, uint(0))
|
||||
if maxCompletionTokens != 0 {
|
||||
return maxCompletionTokens
|
||||
}
|
||||
return r.MaxTokens
|
||||
return lo.FromPtrOr(r.MaxTokens, uint(0))
|
||||
}
|
||||
|
||||
func (r *GeneralOpenAIRequest) ParseInput() []string {
|
||||
@@ -816,7 +820,7 @@ type OpenAIResponsesRequest struct {
|
||||
Conversation json.RawMessage `json:"conversation,omitempty"`
|
||||
ContextManagement json.RawMessage `json:"context_management,omitempty"`
|
||||
Instructions json.RawMessage `json:"instructions,omitempty"`
|
||||
MaxOutputTokens uint `json:"max_output_tokens,omitempty"`
|
||||
MaxOutputTokens *uint `json:"max_output_tokens,omitempty"`
|
||||
TopLogProbs *int `json:"top_logprobs,omitempty"`
|
||||
Metadata json.RawMessage `json:"metadata,omitempty"`
|
||||
ParallelToolCalls json.RawMessage `json:"parallel_tool_calls,omitempty"`
|
||||
@@ -833,7 +837,7 @@ type OpenAIResponsesRequest struct {
|
||||
// SafetyIdentifier carries client identity for policy abuse detection.
|
||||
// This field is filtered by default and can be enabled via channel setting allow_safety_identifier.
|
||||
SafetyIdentifier string `json:"safety_identifier,omitempty"`
|
||||
Stream bool `json:"stream,omitempty"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Text json.RawMessage `json:"text,omitempty"`
|
||||
@@ -842,7 +846,7 @@ type OpenAIResponsesRequest struct {
|
||||
TopP *float64 `json:"top_p,omitempty"`
|
||||
Truncation string `json:"truncation,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
MaxToolCalls uint `json:"max_tool_calls,omitempty"`
|
||||
MaxToolCalls *uint `json:"max_tool_calls,omitempty"`
|
||||
Prompt json.RawMessage `json:"prompt,omitempty"`
|
||||
// qwen
|
||||
EnableThinking json.RawMessage `json:"enable_thinking,omitempty"`
|
||||
@@ -905,12 +909,12 @@ func (r *OpenAIResponsesRequest) GetTokenCountMeta() *types.TokenCountMeta {
|
||||
return &types.TokenCountMeta{
|
||||
CombineText: strings.Join(texts, "\n"),
|
||||
Files: fileMeta,
|
||||
MaxTokens: int(r.MaxOutputTokens),
|
||||
MaxTokens: int(lo.FromPtrOr(r.MaxOutputTokens, uint(0))),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) IsStream(c *gin.Context) bool {
|
||||
return r.Stream
|
||||
return lo.FromPtrOr(r.Stream, false)
|
||||
}
|
||||
|
||||
func (r *OpenAIResponsesRequest) SetModelName(modelName string) {
|
||||
|
||||
73
dto/openai_request_zero_value_test.go
Normal file
73
dto/openai_request_zero_value_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package dto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/QuantumNous/new-api/common"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestGeneralOpenAIRequestPreserveExplicitZeroValues(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"model":"gpt-4.1",
|
||||
"stream":false,
|
||||
"max_tokens":0,
|
||||
"max_completion_tokens":0,
|
||||
"top_p":0,
|
||||
"top_k":0,
|
||||
"n":0,
|
||||
"frequency_penalty":0,
|
||||
"presence_penalty":0,
|
||||
"seed":0,
|
||||
"logprobs":false,
|
||||
"top_logprobs":0,
|
||||
"dimensions":0,
|
||||
"return_images":false,
|
||||
"return_related_questions":false
|
||||
}`)
|
||||
|
||||
var req GeneralOpenAIRequest
|
||||
err := common.Unmarshal(raw, &req)
|
||||
require.NoError(t, err)
|
||||
|
||||
encoded, err := common.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, gjson.GetBytes(encoded, "stream").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "max_tokens").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "max_completion_tokens").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "top_p").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "top_k").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "n").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "frequency_penalty").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "presence_penalty").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "seed").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "logprobs").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "top_logprobs").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "dimensions").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "return_images").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "return_related_questions").Exists())
|
||||
}
|
||||
|
||||
func TestOpenAIResponsesRequestPreserveExplicitZeroValues(t *testing.T) {
|
||||
raw := []byte(`{
|
||||
"model":"gpt-4.1",
|
||||
"max_output_tokens":0,
|
||||
"max_tool_calls":0,
|
||||
"stream":false,
|
||||
"top_p":0
|
||||
}`)
|
||||
|
||||
var req OpenAIResponsesRequest
|
||||
err := common.Unmarshal(raw, &req)
|
||||
require.NoError(t, err)
|
||||
|
||||
encoded, err := common.Marshal(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, gjson.GetBytes(encoded, "max_output_tokens").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "max_tool_calls").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "stream").Exists())
|
||||
require.True(t, gjson.GetBytes(encoded, "top_p").Exists())
|
||||
}
|
||||
@@ -12,10 +12,10 @@ type RerankRequest struct {
|
||||
Documents []any `json:"documents"`
|
||||
Query string `json:"query"`
|
||||
Model string `json:"model"`
|
||||
TopN int `json:"top_n,omitempty"`
|
||||
TopN *int `json:"top_n,omitempty"`
|
||||
ReturnDocuments *bool `json:"return_documents,omitempty"`
|
||||
MaxChunkPerDoc int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens int `json:"overlap_tokens,omitempty"`
|
||||
MaxChunkPerDoc *int `json:"max_chunk_per_doc,omitempty"`
|
||||
OverLapTokens *int `json:"overlap_tokens,omitempty"`
|
||||
}
|
||||
|
||||
func (r *RerankRequest) IsStream(c *gin.Context) bool {
|
||||
|
||||
Reference in New Issue
Block a user