perf(backend): 使用 gjson/sjson 优化热路径 JSON 处理
将 API 网关热路径中的 json.Unmarshal+json.Marshal 替换为 gjson 零拷贝查询和 sjson 精准写入: - unwrapV1InternalResponse 性能提升 22x(4009ns→182ns),内存分配减少 28.5x - unwrapGeminiResponse、extractGeminiUsage、estimateGeminiCountTokens、ParseGeminiRateLimitResetTime 改为接收 []byte 使用 gjson 提取 - ParseGatewayRequest 的 model/stream/metadata/thinking/max_tokens 改用 gjson 类型安全提取 - Handler 层(sora/openai)改用 gjson 提取字段、sjson 注入/修改字段,移除 map[string]any 中间变量 - Sora Client 响应解析改用 gjson ForEach 遍历,减少内存分配 - 新增约 100 个单元测试用例,所有改动函数覆盖率 >85% Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -18,6 +18,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
// OpenAIGatewayHandler handles OpenAI API gateway requests
|
||||||
@@ -93,16 +95,9 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
// Parse request body to map for potential modification
|
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||||
var reqBody map[string]any
|
reqModel := gjson.GetBytes(body, "model").String()
|
||||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract model and stream
|
|
||||||
reqModel, _ := reqBody["model"].(string)
|
|
||||||
reqStream, _ := reqBody["stream"].(bool)
|
|
||||||
|
|
||||||
// 验证 model 必填
|
// 验证 model 必填
|
||||||
if reqModel == "" {
|
if reqModel == "" {
|
||||||
@@ -113,16 +108,10 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
userAgent := c.GetHeader("User-Agent")
|
userAgent := c.GetHeader("User-Agent")
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
|
isCodexCLI := openai.IsCodexCLIRequest(userAgent) || (h.cfg != nil && h.cfg.Gateway.ForceCodexCLI)
|
||||||
if !isCodexCLI {
|
if !isCodexCLI {
|
||||||
existingInstructions, _ := reqBody["instructions"].(string)
|
existingInstructions := gjson.GetBytes(body, "instructions").String()
|
||||||
if strings.TrimSpace(existingInstructions) == "" {
|
if strings.TrimSpace(existingInstructions) == "" {
|
||||||
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
if instructions := strings.TrimSpace(service.GetOpenCodeInstructions()); instructions != "" {
|
||||||
reqBody["instructions"] = instructions
|
body, _ = sjson.SetBytes(body, "instructions", instructions)
|
||||||
// Re-serialize body
|
|
||||||
body, err = json.Marshal(reqBody)
|
|
||||||
if err != nil {
|
|
||||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -132,19 +121,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||||
if service.HasFunctionCallOutput(reqBody) {
|
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
|
||||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
var reqBody map[string]any
|
||||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
if err := json.Unmarshal(body, &reqBody); err == nil {
|
||||||
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
if service.HasFunctionCallOutput(reqBody) {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
return
|
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||||
}
|
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
log.Printf("[OpenAI Handler] function_call_output 缺少 call_id: model=%s", reqModel)
|
||||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
return
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
}
|
||||||
return
|
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||||
|
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||||
|
log.Printf("[OpenAI Handler] function_call_output 缺少匹配的 item_reference: model=%s", reqModel)
|
||||||
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -207,7 +202,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Generate session hash (header first; fallback to prompt_cache_key)
|
// Generate session hash (header first; fallback to prompt_cache_key)
|
||||||
sessionHash := h.gatewayService.GenerateSessionHash(c, reqBody)
|
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
|
|||||||
@@ -10,6 +10,8 @@ import (
|
|||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
func TestOpenAIHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||||
@@ -102,3 +104,48 @@ func TestOpenAIHandleStreamingAwareError_NonStreaming(t *testing.T) {
|
|||||||
assert.Equal(t, "upstream_error", errorObj["type"])
|
assert.Equal(t, "upstream_error", errorObj["type"])
|
||||||
assert.Equal(t, "test error", errorObj["message"])
|
assert.Equal(t, "test error", errorObj["message"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestOpenAIHandler_GjsonExtraction 验证 gjson 从请求体中提取 model/stream 的正确性
|
||||||
|
func TestOpenAIHandler_GjsonExtraction(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantModel string
|
||||||
|
wantStream bool
|
||||||
|
}{
|
||||||
|
{"正常提取", `{"model":"gpt-4","stream":true,"input":"hello"}`, "gpt-4", true},
|
||||||
|
{"stream false", `{"model":"gpt-4","stream":false}`, "gpt-4", false},
|
||||||
|
{"无 stream 字段", `{"model":"gpt-4"}`, "gpt-4", false},
|
||||||
|
{"model 缺失", `{"stream":true}`, "", true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
body := []byte(tt.body)
|
||||||
|
model := gjson.GetBytes(body, "model").String()
|
||||||
|
stream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
require.Equal(t, tt.wantModel, model)
|
||||||
|
require.Equal(t, tt.wantStream, stream)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestOpenAIHandler_InstructionsInjection 验证 instructions 的 gjson/sjson 注入逻辑
|
||||||
|
func TestOpenAIHandler_InstructionsInjection(t *testing.T) {
|
||||||
|
// 测试 1:无 instructions → 注入
|
||||||
|
body := []byte(`{"model":"gpt-4"}`)
|
||||||
|
existing := gjson.GetBytes(body, "instructions").String()
|
||||||
|
require.Empty(t, existing)
|
||||||
|
newBody, err := sjson.SetBytes(body, "instructions", "test instruction")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "test instruction", gjson.GetBytes(newBody, "instructions").String())
|
||||||
|
|
||||||
|
// 测试 2:已有 instructions → 不覆盖
|
||||||
|
body2 := []byte(`{"model":"gpt-4","instructions":"existing"}`)
|
||||||
|
existing2 := gjson.GetBytes(body2, "instructions").String()
|
||||||
|
require.Equal(t, "existing", existing2)
|
||||||
|
|
||||||
|
// 测试 3:空白 instructions → 注入
|
||||||
|
body3 := []byte(`{"model":"gpt-4","instructions":" "}`)
|
||||||
|
existing3 := strings.TrimSpace(gjson.GetBytes(body3, "instructions").String())
|
||||||
|
require.Empty(t, existing3)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -23,6 +22,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SoraGatewayHandler handles Sora chat completions requests
|
// SoraGatewayHandler handles Sora chat completions requests
|
||||||
@@ -105,36 +106,29 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
|
|
||||||
setOpsRequestContext(c, "", false, body)
|
setOpsRequestContext(c, "", false, body)
|
||||||
|
|
||||||
var reqBody map[string]any
|
// 使用 gjson 只读提取字段做校验,避免完整 Unmarshal
|
||||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
reqModel := gjson.GetBytes(body, "model").String()
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
reqModel, _ := reqBody["model"].(string)
|
|
||||||
if reqModel == "" {
|
if reqModel == "" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqMessages, _ := reqBody["messages"].([]any)
|
if !gjson.GetBytes(body, "messages").Exists() || gjson.GetBytes(body, "messages").Type != gjson.JSON {
|
||||||
if len(reqMessages) == 0 {
|
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "messages is required")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
clientStream, _ := reqBody["stream"].(bool)
|
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
if !clientStream {
|
if !clientStream {
|
||||||
if h.streamMode == "error" {
|
if h.streamMode == "error" {
|
||||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Sora requires stream=true")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
reqBody["stream"] = true
|
var err error
|
||||||
updated, err := json.Marshal(reqBody)
|
body, err = sjson.SetBytes(body, "stream", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
h.errorResponse(c, http.StatusInternalServerError, "api_error", "Failed to process request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
body = updated
|
|
||||||
}
|
}
|
||||||
|
|
||||||
setOpsRequestContext(c, reqModel, clientStream, body)
|
setOpsRequestContext(c, reqModel, clientStream, body)
|
||||||
@@ -193,7 +187,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionHash := generateOpenAISessionHash(c, reqBody)
|
sessionHash := generateOpenAISessionHash(c, body)
|
||||||
|
|
||||||
maxAccountSwitches := h.maxAccountSwitches
|
maxAccountSwitches := h.maxAccountSwitches
|
||||||
switchCount := 0
|
switchCount := 0
|
||||||
@@ -302,7 +296,7 @@ func (h *SoraGatewayHandler) ChatCompletions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string {
|
func generateOpenAISessionHash(c *gin.Context, body []byte) string {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -310,10 +304,8 @@ func generateOpenAISessionHash(c *gin.Context, reqBody map[string]any) string {
|
|||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||||
}
|
}
|
||||||
if sessionID == "" && reqBody != nil {
|
if sessionID == "" && len(body) > 0 {
|
||||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||||
sessionID = strings.TrimSpace(v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/testutil"
|
"github.com/Wei-Shaw/sub2api/internal/testutil"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// 编译期接口断言
|
// 编译期接口断言
|
||||||
@@ -414,3 +416,65 @@ func TestSoraGatewayHandler_ChatCompletions(t *testing.T) {
|
|||||||
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
require.NoError(t, json.Unmarshal(rec.Body.Bytes(), &resp))
|
||||||
require.NotEmpty(t, resp["media_url"])
|
require.NotEmpty(t, resp["media_url"])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSoraHandler_StreamForcing 验证 sora handler 的 stream 强制逻辑
|
||||||
|
func TestSoraHandler_StreamForcing(t *testing.T) {
|
||||||
|
// 测试 1:stream=false 时 sjson 强制修改为 true
|
||||||
|
body := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":false}`)
|
||||||
|
clientStream := gjson.GetBytes(body, "stream").Bool()
|
||||||
|
require.False(t, clientStream)
|
||||||
|
newBody, err := sjson.SetBytes(body, "stream", true)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, gjson.GetBytes(newBody, "stream").Bool())
|
||||||
|
|
||||||
|
// 测试 2:stream=true 时不修改
|
||||||
|
body2 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}],"stream":true}`)
|
||||||
|
require.True(t, gjson.GetBytes(body2, "stream").Bool())
|
||||||
|
|
||||||
|
// 测试 3:无 stream 字段时 gjson 返回 false(零值)
|
||||||
|
body3 := []byte(`{"model":"sora","messages":[{"role":"user","content":"test"}]}`)
|
||||||
|
require.False(t, gjson.GetBytes(body3, "stream").Bool())
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestSoraHandler_ValidationExtraction 验证 sora handler 中 gjson 字段校验逻辑
|
||||||
|
func TestSoraHandler_ValidationExtraction(t *testing.T) {
|
||||||
|
// model 缺失
|
||||||
|
body := []byte(`{"messages":[{"role":"user","content":"test"}]}`)
|
||||||
|
model := gjson.GetBytes(body, "model").String()
|
||||||
|
require.Empty(t, model)
|
||||||
|
|
||||||
|
// messages 缺失
|
||||||
|
body2 := []byte(`{"model":"sora"}`)
|
||||||
|
require.False(t, gjson.GetBytes(body2, "messages").Exists())
|
||||||
|
|
||||||
|
// messages 不是 JSON 数组
|
||||||
|
body3 := []byte(`{"model":"sora","messages":"not array"}`)
|
||||||
|
msgResult := gjson.GetBytes(body3, "messages")
|
||||||
|
require.True(t, msgResult.Exists())
|
||||||
|
require.NotEqual(t, gjson.JSON, msgResult.Type) // string 类型,不是 JSON 数组
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestGenerateOpenAISessionHash_WithBody 验证 generateOpenAISessionHash 的 body/header 解析逻辑
|
||||||
|
func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// 从 body 提取 prompt_cache_key
|
||||||
|
body := []byte(`{"model":"sora","prompt_cache_key":"session-abc"}`)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest("POST", "/", nil)
|
||||||
|
|
||||||
|
hash := generateOpenAISessionHash(c, body)
|
||||||
|
require.NotEmpty(t, hash)
|
||||||
|
|
||||||
|
// 无 prompt_cache_key 且无 header → 空 hash
|
||||||
|
body2 := []byte(`{"model":"sora"}`)
|
||||||
|
hash2 := generateOpenAISessionHash(c, body2)
|
||||||
|
require.Empty(t, hash2)
|
||||||
|
|
||||||
|
// header 优先于 body
|
||||||
|
c.Request.Header.Set("session_id", "from-header")
|
||||||
|
hash3 := generateOpenAISessionHash(c, body)
|
||||||
|
require.NotEmpty(t, hash3)
|
||||||
|
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||||
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -981,16 +982,12 @@ func (s *AntigravityGatewayService) wrapV1InternalRequest(projectID, model strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// unwrapV1InternalResponse 解包 v1internal 响应
|
// unwrapV1InternalResponse 解包 v1internal 响应
|
||||||
|
// 使用 gjson 零拷贝提取 response 字段,避免 Unmarshal+Marshal 双重开销
|
||||||
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
|
func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byte, error) {
|
||||||
var outer map[string]any
|
result := gjson.GetBytes(body, "response")
|
||||||
if err := json.Unmarshal(body, &outer); err != nil {
|
if result.Exists() {
|
||||||
return nil, err
|
return []byte(result.Raw), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp, ok := outer["response"]; ok {
|
|
||||||
return json.Marshal(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
return body, nil
|
return body, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2516,11 +2513,11 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 解析 usage
|
// 解析 usage
|
||||||
|
if u := extractGeminiUsage(inner); u != nil {
|
||||||
|
usage = u
|
||||||
|
}
|
||||||
var parsed map[string]any
|
var parsed map[string]any
|
||||||
if json.Unmarshal(inner, &parsed) == nil {
|
if json.Unmarshal(inner, &parsed) == nil {
|
||||||
if u := extractGeminiUsage(parsed); u != nil {
|
|
||||||
usage = u
|
|
||||||
}
|
|
||||||
// Check for MALFORMED_FUNCTION_CALL
|
// Check for MALFORMED_FUNCTION_CALL
|
||||||
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
if candidates, ok := parsed["candidates"].([]any); ok && len(candidates) > 0 {
|
||||||
if cand, ok := candidates[0].(map[string]any); ok {
|
if cand, ok := candidates[0].(map[string]any); ok {
|
||||||
@@ -2676,7 +2673,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamToNonStreaming(c *gin.Cont
|
|||||||
last = parsed
|
last = parsed
|
||||||
|
|
||||||
// 提取 usage
|
// 提取 usage
|
||||||
if u := extractGeminiUsage(parsed); u != nil {
|
if u := extractGeminiUsage(inner); u != nil {
|
||||||
usage = u
|
usage = u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -889,3 +890,144 @@ func TestAntigravityClientWriter(t *testing.T) {
|
|||||||
require.True(t, cw.Disconnected())
|
require.True(t, cw.Disconnected())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUnwrapV1InternalResponse 测试 unwrapV1InternalResponse 的各种输入场景
|
||||||
|
func TestUnwrapV1InternalResponse(t *testing.T) {
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
|
||||||
|
// 构造 >50KB 的大型 JSON
|
||||||
|
largePadding := strings.Repeat("x", 50*1024)
|
||||||
|
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
|
||||||
|
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expected string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常 response 包装",
|
||||||
|
input: []byte(`{"response":{"id":"123","content":"hello"}}`),
|
||||||
|
expected: `{"id":"123","content":"hello"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 response 透传",
|
||||||
|
input: []byte(`{"id":"456"}`),
|
||||||
|
expected: `{"id":"456"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 JSON",
|
||||||
|
input: []byte(`{}`),
|
||||||
|
expected: `{}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "response 为 null",
|
||||||
|
input: []byte(`{"response":null}`),
|
||||||
|
expected: `null`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "response 为基础类型 string",
|
||||||
|
input: []byte(`{"response":"hello"}`),
|
||||||
|
expected: `"hello"`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非法 JSON",
|
||||||
|
input: []byte(`not json`),
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "嵌套 response 只解一层",
|
||||||
|
input: []byte(`{"response":{"response":{"inner":true}}}`),
|
||||||
|
expected: `{"response":{"inner":true}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "大型 JSON >50KB",
|
||||||
|
input: largeInput,
|
||||||
|
expected: largeExpected,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := svc.unwrapV1InternalResponse(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- unwrapV1InternalResponse benchmark 对照组 ---
|
||||||
|
|
||||||
|
// unwrapV1InternalResponseOld 旧实现:Unmarshal+Marshal 双重开销(仅用于 benchmark 对照)
|
||||||
|
func unwrapV1InternalResponseOld(body []byte) ([]byte, error) {
|
||||||
|
var outer map[string]any
|
||||||
|
if err := json.Unmarshal(body, &outer); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if resp, ok := outer["response"]; ok {
|
||||||
|
return json.Marshal(resp)
|
||||||
|
}
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnwrapV1Internal_Old_Small(b *testing.B) {
|
||||||
|
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = unwrapV1InternalResponseOld(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnwrapV1Internal_New_Small(b *testing.B) {
|
||||||
|
body := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"hello world"}]}}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5}}}`)
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = svc.unwrapV1InternalResponse(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnwrapV1Internal_Old_Large(b *testing.B) {
|
||||||
|
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = unwrapV1InternalResponseOld(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUnwrapV1Internal_New_Large(b *testing.B) {
|
||||||
|
body := generateLargeUnwrapJSON(10 * 1024) // ~10KB
|
||||||
|
svc := &AntigravityGatewayService{}
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = svc.unwrapV1InternalResponse(body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateLargeUnwrapJSON 生成指定最小大小的包含 response 包装的 JSON
|
||||||
|
func generateLargeUnwrapJSON(minSize int) []byte {
|
||||||
|
parts := make([]map[string]string, 0)
|
||||||
|
current := 0
|
||||||
|
for current < minSize {
|
||||||
|
text := fmt.Sprintf("这是第 %d 段内容,用于填充 JSON 到目标大小。", len(parts)+1)
|
||||||
|
parts = append(parts, map[string]string{"text": text})
|
||||||
|
current += len(text) + 20 // 估算 JSON 编码开销
|
||||||
|
}
|
||||||
|
inner := map[string]any{
|
||||||
|
"candidates": []map[string]any{
|
||||||
|
{"content": map[string]any{"parts": parts}},
|
||||||
|
},
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": 100,
|
||||||
|
"candidatesTokenCount": 50,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
outer := map[string]any{"response": inner}
|
||||||
|
b, _ := json.Marshal(outer)
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
// SessionContext 粘性会话上下文,用于区分不同来源的请求。
|
||||||
@@ -48,38 +49,58 @@ type ParsedRequest struct {
|
|||||||
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
// protocol 指定请求协议格式(domain.PlatformAnthropic / domain.PlatformGemini),
|
||||||
// 不同协议使用不同的 system/messages 字段名。
|
// 不同协议使用不同的 system/messages 字段名。
|
||||||
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
||||||
var req map[string]any
|
|
||||||
if err := json.Unmarshal(body, &req); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
parsed := &ParsedRequest{
|
parsed := &ParsedRequest{
|
||||||
Body: body,
|
Body: body,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rawModel, exists := req["model"]; exists {
|
// --- gjson 提取简单字段(避免完整 Unmarshal) ---
|
||||||
model, ok := rawModel.(string)
|
|
||||||
if !ok {
|
// model: 需要严格类型校验,非 string 返回错误
|
||||||
|
modelResult := gjson.GetBytes(body, "model")
|
||||||
|
if modelResult.Exists() {
|
||||||
|
if modelResult.Type != gjson.String {
|
||||||
return nil, fmt.Errorf("invalid model field type")
|
return nil, fmt.Errorf("invalid model field type")
|
||||||
}
|
}
|
||||||
parsed.Model = model
|
parsed.Model = modelResult.String()
|
||||||
}
|
}
|
||||||
if rawStream, exists := req["stream"]; exists {
|
|
||||||
stream, ok := rawStream.(bool)
|
// stream: 需要严格类型校验,非 bool 返回错误
|
||||||
if !ok {
|
streamResult := gjson.GetBytes(body, "stream")
|
||||||
|
if streamResult.Exists() {
|
||||||
|
if streamResult.Type != gjson.True && streamResult.Type != gjson.False {
|
||||||
return nil, fmt.Errorf("invalid stream field type")
|
return nil, fmt.Errorf("invalid stream field type")
|
||||||
}
|
}
|
||||||
parsed.Stream = stream
|
parsed.Stream = streamResult.Bool()
|
||||||
}
|
}
|
||||||
if metadata, ok := req["metadata"].(map[string]any); ok {
|
|
||||||
if userID, ok := metadata["user_id"].(string); ok {
|
// metadata.user_id: 直接路径提取,不需要严格类型校验
|
||||||
parsed.MetadataUserID = userID
|
parsed.MetadataUserID = gjson.GetBytes(body, "metadata.user_id").String()
|
||||||
|
|
||||||
|
// thinking.type: 直接路径提取
|
||||||
|
if gjson.GetBytes(body, "thinking.type").String() == "enabled" {
|
||||||
|
parsed.ThinkingEnabled = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// max_tokens: 仅接受整数值
|
||||||
|
maxTokensResult := gjson.GetBytes(body, "max_tokens")
|
||||||
|
if maxTokensResult.Exists() && maxTokensResult.Type == gjson.Number {
|
||||||
|
f := maxTokensResult.Float()
|
||||||
|
if !math.IsNaN(f) && !math.IsInf(f, 0) && f == math.Trunc(f) &&
|
||||||
|
f <= float64(math.MaxInt) && f >= float64(math.MinInt) {
|
||||||
|
parsed.MaxTokens = int(f)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- 保留 Unmarshal 用于 system/messages 提取 ---
|
||||||
|
// 这些字段需要作为 any/[]any 传递给下游消费者,无法用 gjson 替代
|
||||||
|
|
||||||
switch protocol {
|
switch protocol {
|
||||||
case domain.PlatformGemini:
|
case domain.PlatformGemini:
|
||||||
// Gemini 原生格式: systemInstruction.parts / contents
|
// Gemini 原生格式: systemInstruction.parts / contents
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
||||||
if parts, ok := sysInst["parts"].([]any); ok {
|
if parts, ok := sysInst["parts"].([]any); ok {
|
||||||
parsed.System = parts
|
parsed.System = parts
|
||||||
@@ -92,6 +113,10 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
// Anthropic / OpenAI 格式: system / messages
|
// Anthropic / OpenAI 格式: system / messages
|
||||||
// system 字段只要存在就视为显式提供(即使为 null),
|
// system 字段只要存在就视为显式提供(即使为 null),
|
||||||
// 以避免客户端传 null 时被默认 system 误注入。
|
// 以避免客户端传 null 时被默认 system 误注入。
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
if system, ok := req["system"]; ok {
|
if system, ok := req["system"]; ok {
|
||||||
parsed.HasSystem = true
|
parsed.HasSystem = true
|
||||||
parsed.System = system
|
parsed.System = system
|
||||||
@@ -101,20 +126,6 @@ func ParseGatewayRequest(body []byte, protocol string) (*ParsedRequest, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// thinking: {type: "enabled"}
|
|
||||||
if rawThinking, ok := req["thinking"].(map[string]any); ok {
|
|
||||||
if t, ok := rawThinking["type"].(string); ok && t == "enabled" {
|
|
||||||
parsed.ThinkingEnabled = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// max_tokens
|
|
||||||
if rawMaxTokens, exists := req["max_tokens"]; exists {
|
|
||||||
if maxTokens, ok := parseIntegralNumber(rawMaxTokens); ok {
|
|
||||||
parsed.MaxTokens = maxTokens
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||||
@@ -416,3 +420,341 @@ func TestFilterSignatureSensitiveBlocksForRetry_DowngradesTools(t *testing.T) {
|
|||||||
require.Contains(t, content0["text"], "tool_use")
|
require.Contains(t, content0["text"], "tool_use")
|
||||||
require.Contains(t, content1["text"], "tool_result")
|
require.Contains(t, content1["text"], "tool_result")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ============ Group 7: ParseGatewayRequest 补充单元测试 ============
|
||||||
|
|
||||||
|
// Task 7.1 — 类型校验边界测试
|
||||||
|
func TestParseGatewayRequest_TypeValidation(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantErr bool
|
||||||
|
errSubstr string // 期望的错误信息子串(为空则不检查)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "model 为 int",
|
||||||
|
body: `{"model":123}`,
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "invalid model field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 为 array",
|
||||||
|
body: `{"model":[]}`,
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "invalid model field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 为 bool",
|
||||||
|
body: `{"model":true}`,
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "invalid model field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 为 null — gjson Null 类型触发类型校验错误",
|
||||||
|
body: `{"model":null}`,
|
||||||
|
wantErr: true, // gjson: Exists()=true, Type=Null != String → 返回错误
|
||||||
|
errSubstr: "invalid model field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stream 为 string",
|
||||||
|
body: `{"stream":"true"}`,
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "invalid stream field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stream 为 int",
|
||||||
|
body: `{"stream":1}`,
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "invalid stream field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "stream 为 null — gjson Null 类型触发类型校验错误",
|
||||||
|
body: `{"stream":null}`,
|
||||||
|
wantErr: true, // gjson: Exists()=true, Type=Null != True && != False → 返回错误
|
||||||
|
errSubstr: "invalid stream field type",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model 为 object",
|
||||||
|
body: `{"model":{}}`,
|
||||||
|
wantErr: true,
|
||||||
|
errSubstr: "invalid model field type",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
_, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
if tt.errSubstr != "" {
|
||||||
|
require.Contains(t, err.Error(), tt.errSubstr)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task 7.2 — 可选字段缺失测试
|
||||||
|
func TestParseGatewayRequest_OptionalFieldsMissing(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantModel string
|
||||||
|
wantStream bool
|
||||||
|
wantMetadataUID string
|
||||||
|
wantHasSystem bool
|
||||||
|
wantThinking bool
|
||||||
|
wantMaxTokens int
|
||||||
|
wantMessagesNil bool
|
||||||
|
wantMessagesLen int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "完全空 JSON — 所有字段零值",
|
||||||
|
body: `{}`,
|
||||||
|
wantModel: "",
|
||||||
|
wantStream: false,
|
||||||
|
wantMetadataUID: "",
|
||||||
|
wantHasSystem: false,
|
||||||
|
wantThinking: false,
|
||||||
|
wantMaxTokens: 0,
|
||||||
|
wantMessagesNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "metadata 无 user_id",
|
||||||
|
body: `{"model":"test"}`,
|
||||||
|
wantModel: "test",
|
||||||
|
wantMetadataUID: "",
|
||||||
|
wantHasSystem: false,
|
||||||
|
wantThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking 非 enabled(type=disabled)",
|
||||||
|
body: `{"model":"test","thinking":{"type":"disabled"}}`,
|
||||||
|
wantModel: "test",
|
||||||
|
wantThinking: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "thinking 字段缺失",
|
||||||
|
body: `{"model":"test"}`,
|
||||||
|
wantModel: "test",
|
||||||
|
wantThinking: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, tt.wantModel, parsed.Model)
|
||||||
|
require.Equal(t, tt.wantStream, parsed.Stream)
|
||||||
|
require.Equal(t, tt.wantMetadataUID, parsed.MetadataUserID)
|
||||||
|
require.Equal(t, tt.wantHasSystem, parsed.HasSystem)
|
||||||
|
require.Equal(t, tt.wantThinking, parsed.ThinkingEnabled)
|
||||||
|
require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
|
||||||
|
|
||||||
|
if tt.wantMessagesNil {
|
||||||
|
require.Nil(t, parsed.Messages)
|
||||||
|
}
|
||||||
|
if tt.wantMessagesLen > 0 {
|
||||||
|
require.Len(t, parsed.Messages, tt.wantMessagesLen)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Task 7.3 — Gemini 协议分支测试
|
||||||
|
// 已有测试覆盖:
|
||||||
|
// - TestParseGatewayRequest_GeminiSystemInstruction: 正常 systemInstruction+contents
|
||||||
|
// - TestParseGatewayRequest_GeminiNoContents: 缺失 contents
|
||||||
|
// - TestParseGatewayRequest_GeminiContents: 正常 contents(无 systemInstruction)
|
||||||
|
// 因此跳过。
|
||||||
|
|
||||||
|
// Task 7.4 — max_tokens 边界测试
|
||||||
|
func TestParseGatewayRequest_MaxTokensBoundary(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantMaxTokens int
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常整数",
|
||||||
|
body: `{"max_tokens":1024}`,
|
||||||
|
wantMaxTokens: 1024,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "浮点数(非整数)被忽略",
|
||||||
|
body: `{"max_tokens":10.5}`,
|
||||||
|
wantMaxTokens: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "负整数可以通过",
|
||||||
|
body: `{"max_tokens":-1}`,
|
||||||
|
wantMaxTokens: -1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "超大值不 panic",
|
||||||
|
body: `{"max_tokens":9999999999999999}`,
|
||||||
|
wantMaxTokens: 10000000000000000, // float64 精度导致 9999999999999999 → 1e16
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null 值被忽略",
|
||||||
|
body: `{"max_tokens":null}`,
|
||||||
|
wantMaxTokens: 0, // gjson Type=Null != Number → 条件不满足,跳过
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
parsed, err := ParseGatewayRequest([]byte(tt.body), "")
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.wantMaxTokens, parsed.MaxTokens)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============ Task 7.5: Benchmark 测试 ============
|
||||||
|
|
||||||
|
// parseGatewayRequestOld 是基于完整 json.Unmarshal 的旧实现,用于 benchmark 对比基线。
|
||||||
|
// 核心路径:先 Unmarshal 到 map[string]any,再逐字段提取。
|
||||||
|
func parseGatewayRequestOld(body []byte, protocol string) (*ParsedRequest, error) {
|
||||||
|
parsed := &ParsedRequest{
|
||||||
|
Body: body,
|
||||||
|
}
|
||||||
|
|
||||||
|
var req map[string]any
|
||||||
|
if err := json.Unmarshal(body, &req); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// model
|
||||||
|
if raw, ok := req["model"]; ok {
|
||||||
|
s, ok := raw.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid model field type")
|
||||||
|
}
|
||||||
|
parsed.Model = s
|
||||||
|
}
|
||||||
|
|
||||||
|
// stream
|
||||||
|
if raw, ok := req["stream"]; ok {
|
||||||
|
b, ok := raw.(bool)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid stream field type")
|
||||||
|
}
|
||||||
|
parsed.Stream = b
|
||||||
|
}
|
||||||
|
|
||||||
|
// metadata.user_id
|
||||||
|
if meta, ok := req["metadata"].(map[string]any); ok {
|
||||||
|
if uid, ok := meta["user_id"].(string); ok {
|
||||||
|
parsed.MetadataUserID = uid
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// thinking.type
|
||||||
|
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||||
|
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||||
|
parsed.ThinkingEnabled = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// max_tokens
|
||||||
|
if raw, ok := req["max_tokens"]; ok {
|
||||||
|
if n, ok := parseIntegralNumber(raw); ok {
|
||||||
|
parsed.MaxTokens = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// system / messages(按协议分支)
|
||||||
|
switch protocol {
|
||||||
|
case domain.PlatformGemini:
|
||||||
|
if sysInst, ok := req["systemInstruction"].(map[string]any); ok {
|
||||||
|
if parts, ok := sysInst["parts"].([]any); ok {
|
||||||
|
parsed.System = parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if contents, ok := req["contents"].([]any); ok {
|
||||||
|
parsed.Messages = contents
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if system, ok := req["system"]; ok {
|
||||||
|
parsed.HasSystem = true
|
||||||
|
parsed.System = system
|
||||||
|
}
|
||||||
|
if messages, ok := req["messages"].([]any); ok {
|
||||||
|
parsed.Messages = messages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return parsed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildSmallJSON 构建 ~500B 的小型测试 JSON
|
||||||
|
func buildSmallJSON() []byte {
|
||||||
|
return []byte(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":4096,"metadata":{"user_id":"user-abc123"},"thinking":{"type":"enabled","budget_tokens":2048},"system":"You are a helpful assistant.","messages":[{"role":"user","content":"What is the meaning of life?"},{"role":"assistant","content":"The meaning of life is a philosophical question."},{"role":"user","content":"Can you elaborate?"}]}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLargeJSON 构建 ~50KB 的大型测试 JSON(大量 messages)
|
||||||
|
func buildLargeJSON() []byte {
|
||||||
|
var b strings.Builder
|
||||||
|
b.WriteString(`{"model":"claude-sonnet-4-5","stream":true,"max_tokens":8192,"metadata":{"user_id":"user-xyz789"},"system":[{"type":"text","text":"You are a detailed assistant.","cache_control":{"type":"ephemeral"}}],"messages":[`)
|
||||||
|
|
||||||
|
msgCount := 200
|
||||||
|
for i := 0; i < msgCount; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
b.WriteByte(',')
|
||||||
|
}
|
||||||
|
if i%2 == 0 {
|
||||||
|
b.WriteString(fmt.Sprintf(`{"role":"user","content":"This is user message number %d with some extra padding text to make the message reasonably long for benchmarking purposes. Lorem ipsum dolor sit amet."}`, i))
|
||||||
|
} else {
|
||||||
|
b.WriteString(fmt.Sprintf(`{"role":"assistant","content":[{"type":"text","text":"This is assistant response number %d. I will provide a detailed answer with multiple sentences to simulate real conversation content for benchmark testing."}]}`, i))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(`]}`)
|
||||||
|
return []byte(b.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseGatewayRequest_Old_Small(b *testing.B) {
|
||||||
|
data := buildSmallJSON()
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = parseGatewayRequestOld(data, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseGatewayRequest_New_Small(b *testing.B) {
|
||||||
|
data := buildSmallJSON()
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ParseGatewayRequest(data, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseGatewayRequest_Old_Large(b *testing.B) {
|
||||||
|
data := buildLargeJSON()
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = parseGatewayRequestOld(data, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkParseGatewayRequest_New_Large(b *testing.B) {
|
||||||
|
data := buildLargeJSON()
|
||||||
|
b.SetBytes(int64(len(data)))
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_, _ = ParseGatewayRequest(data, "")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ import (
|
|||||||
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
const geminiStickySessionTTL = time.Hour
|
const geminiStickySessionTTL = time.Hour
|
||||||
@@ -929,7 +930,8 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream stream")
|
||||||
}
|
}
|
||||||
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel)
|
collectedBytes, _ := json.Marshal(collected)
|
||||||
|
claudeResp, usageObj2 := convertGeminiToClaudeMessage(collected, originalModel, collectedBytes)
|
||||||
c.JSON(http.StatusOK, claudeResp)
|
c.JSON(http.StatusOK, claudeResp)
|
||||||
usage = usageObj2
|
usage = usageObj2
|
||||||
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
|
if usageObj != nil && (usageObj.InputTokens > 0 || usageObj.OutputTokens > 0) {
|
||||||
@@ -1726,12 +1728,17 @@ func (s *GeminiMessagesCompatService) handleNonStreamingResponse(c *gin.Context,
|
|||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to read upstream response")
|
||||||
}
|
}
|
||||||
|
|
||||||
geminiResp, err := unwrapGeminiResponse(body)
|
unwrappedBody, err := unwrapGeminiResponse(body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||||
}
|
}
|
||||||
|
|
||||||
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel)
|
var geminiResp map[string]any
|
||||||
|
if err := json.Unmarshal(unwrappedBody, &geminiResp); err != nil {
|
||||||
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Failed to parse upstream response")
|
||||||
|
}
|
||||||
|
|
||||||
|
claudeResp, usage := convertGeminiToClaudeMessage(geminiResp, originalModel, unwrappedBody)
|
||||||
c.JSON(http.StatusOK, claudeResp)
|
c.JSON(http.StatusOK, claudeResp)
|
||||||
|
|
||||||
return usage, nil
|
return usage, nil
|
||||||
@@ -1804,11 +1811,16 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
geminiResp, err := unwrapGeminiResponse([]byte(payload))
|
unwrappedBytes, err := unwrapGeminiResponse([]byte(payload))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var geminiResp map[string]any
|
||||||
|
if err := json.Unmarshal(unwrappedBytes, &geminiResp); err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
|
if fr := extractGeminiFinishReason(geminiResp); fr != "" {
|
||||||
finishReason = fr
|
finishReason = fr
|
||||||
}
|
}
|
||||||
@@ -1935,7 +1947,7 @@ func (s *GeminiMessagesCompatService) handleStreamingResponse(c *gin.Context, re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if u := extractGeminiUsage(geminiResp); u != nil {
|
if u := extractGeminiUsage(unwrappedBytes); u != nil {
|
||||||
usage = *u
|
usage = *u
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2026,11 +2038,7 @@ func unwrapIfNeeded(isOAuth bool, raw []byte) []byte {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
b, err := json.Marshal(inner)
|
return inner
|
||||||
if err != nil {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
return b
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
|
func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsage, error) {
|
||||||
@@ -2054,17 +2062,20 @@ func collectGeminiSSE(body io.Reader, isOAuth bool) (map[string]any, *ClaudeUsag
|
|||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
var parsed map[string]any
|
var parsed map[string]any
|
||||||
|
var rawBytes []byte
|
||||||
if isOAuth {
|
if isOAuth {
|
||||||
inner, err := unwrapGeminiResponse([]byte(payload))
|
innerBytes, err := unwrapGeminiResponse([]byte(payload))
|
||||||
if err == nil && inner != nil {
|
if err == nil {
|
||||||
parsed = inner
|
rawBytes = innerBytes
|
||||||
|
_ = json.Unmarshal(innerBytes, &parsed)
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_ = json.Unmarshal([]byte(payload), &parsed)
|
rawBytes = []byte(payload)
|
||||||
|
_ = json.Unmarshal(rawBytes, &parsed)
|
||||||
}
|
}
|
||||||
if parsed != nil {
|
if parsed != nil {
|
||||||
last = parsed
|
last = parsed
|
||||||
if u := extractGeminiUsage(parsed); u != nil {
|
if u := extractGeminiUsage(rawBytes); u != nil {
|
||||||
usage = u
|
usage = u
|
||||||
}
|
}
|
||||||
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
if parts := extractGeminiParts(parsed); len(parts) > 0 {
|
||||||
@@ -2193,53 +2204,27 @@ func isGeminiInsufficientScope(headers http.Header, body []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func estimateGeminiCountTokens(reqBody []byte) int {
|
func estimateGeminiCountTokens(reqBody []byte) int {
|
||||||
var obj map[string]any
|
total := 0
|
||||||
if err := json.Unmarshal(reqBody, &obj); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
var texts []string
|
|
||||||
|
|
||||||
// systemInstruction.parts[].text
|
// systemInstruction.parts[].text
|
||||||
if si, ok := obj["systemInstruction"].(map[string]any); ok {
|
gjson.GetBytes(reqBody, "systemInstruction.parts").ForEach(func(_, part gjson.Result) bool {
|
||||||
if parts, ok := si["parts"].([]any); ok {
|
if t := strings.TrimSpace(part.Get("text").String()); t != "" {
|
||||||
for _, p := range parts {
|
total += estimateTokensForText(t)
|
||||||
if pm, ok := p.(map[string]any); ok {
|
|
||||||
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
|
|
||||||
texts = append(texts, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
// contents[].parts[].text
|
// contents[].parts[].text
|
||||||
if contents, ok := obj["contents"].([]any); ok {
|
gjson.GetBytes(reqBody, "contents").ForEach(func(_, content gjson.Result) bool {
|
||||||
for _, c := range contents {
|
content.Get("parts").ForEach(func(_, part gjson.Result) bool {
|
||||||
cm, ok := c.(map[string]any)
|
if t := strings.TrimSpace(part.Get("text").String()); t != "" {
|
||||||
if !ok {
|
total += estimateTokensForText(t)
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
parts, ok := cm["parts"].([]any)
|
return true
|
||||||
if !ok {
|
})
|
||||||
continue
|
return true
|
||||||
}
|
})
|
||||||
for _, p := range parts {
|
|
||||||
pm, ok := p.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if t, ok := pm["text"].(string); ok && strings.TrimSpace(t) != "" {
|
|
||||||
texts = append(texts, t)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
for _, t := range texts {
|
|
||||||
total += estimateTokensForText(t)
|
|
||||||
}
|
|
||||||
if total < 0 {
|
if total < 0 {
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
@@ -2293,10 +2278,11 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
|||||||
|
|
||||||
var parsed map[string]any
|
var parsed map[string]any
|
||||||
if isOAuth {
|
if isOAuth {
|
||||||
parsed, err = unwrapGeminiResponse(respBody)
|
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||||
if err == nil && parsed != nil {
|
if uwErr == nil {
|
||||||
respBody, _ = json.Marshal(parsed)
|
respBody = unwrappedBody
|
||||||
}
|
}
|
||||||
|
_ = json.Unmarshal(respBody, &parsed)
|
||||||
} else {
|
} else {
|
||||||
_ = json.Unmarshal(respBody, &parsed)
|
_ = json.Unmarshal(respBody, &parsed)
|
||||||
}
|
}
|
||||||
@@ -2309,10 +2295,8 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
|||||||
}
|
}
|
||||||
c.Data(resp.StatusCode, contentType, respBody)
|
c.Data(resp.StatusCode, contentType, respBody)
|
||||||
|
|
||||||
if parsed != nil {
|
if u := extractGeminiUsage(respBody); u != nil {
|
||||||
if u := extractGeminiUsage(parsed); u != nil {
|
return u, nil
|
||||||
return u, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return &ClaudeUsage{}, nil
|
return &ClaudeUsage{}, nil
|
||||||
}
|
}
|
||||||
@@ -2365,23 +2349,19 @@ func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Conte
|
|||||||
var rawToWrite string
|
var rawToWrite string
|
||||||
rawToWrite = payload
|
rawToWrite = payload
|
||||||
|
|
||||||
var parsed map[string]any
|
var rawBytes []byte
|
||||||
if isOAuth {
|
if isOAuth {
|
||||||
inner, err := unwrapGeminiResponse([]byte(payload))
|
innerBytes, err := unwrapGeminiResponse([]byte(payload))
|
||||||
if err == nil && inner != nil {
|
if err == nil {
|
||||||
parsed = inner
|
rawToWrite = string(innerBytes)
|
||||||
if b, err := json.Marshal(inner); err == nil {
|
rawBytes = innerBytes
|
||||||
rawToWrite = string(b)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
_ = json.Unmarshal([]byte(payload), &parsed)
|
rawBytes = []byte(payload)
|
||||||
}
|
}
|
||||||
|
|
||||||
if parsed != nil {
|
if u := extractGeminiUsage(rawBytes); u != nil {
|
||||||
if u := extractGeminiUsage(parsed); u != nil {
|
usage = u
|
||||||
usage = u
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if firstTokenMs == nil {
|
if firstTokenMs == nil {
|
||||||
@@ -2484,19 +2464,18 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func unwrapGeminiResponse(raw []byte) (map[string]any, error) {
|
// unwrapGeminiResponse 解包 Gemini OAuth 响应中的 response 字段
|
||||||
var outer map[string]any
|
// 使用 gjson 零拷贝提取,避免完整 Unmarshal+Marshal
|
||||||
if err := json.Unmarshal(raw, &outer); err != nil {
|
func unwrapGeminiResponse(raw []byte) ([]byte, error) {
|
||||||
return nil, err
|
result := gjson.GetBytes(raw, "response")
|
||||||
|
if result.Exists() && result.Type == gjson.JSON {
|
||||||
|
return []byte(result.Raw), nil
|
||||||
}
|
}
|
||||||
if resp, ok := outer["response"].(map[string]any); ok && resp != nil {
|
return raw, nil
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
return outer, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string) (map[string]any, *ClaudeUsage) {
|
func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel string, rawData []byte) (map[string]any, *ClaudeUsage) {
|
||||||
usage := extractGeminiUsage(geminiResp)
|
usage := extractGeminiUsage(rawData)
|
||||||
if usage == nil {
|
if usage == nil {
|
||||||
usage = &ClaudeUsage{}
|
usage = &ClaudeUsage{}
|
||||||
}
|
}
|
||||||
@@ -2560,14 +2539,14 @@ func convertGeminiToClaudeMessage(geminiResp map[string]any, originalModel strin
|
|||||||
return resp, usage
|
return resp, usage
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractGeminiUsage(geminiResp map[string]any) *ClaudeUsage {
|
func extractGeminiUsage(data []byte) *ClaudeUsage {
|
||||||
usageMeta, ok := geminiResp["usageMetadata"].(map[string]any)
|
usage := gjson.GetBytes(data, "usageMetadata")
|
||||||
if !ok || usageMeta == nil {
|
if !usage.Exists() {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
prompt, _ := asInt(usageMeta["promptTokenCount"])
|
prompt := int(usage.Get("promptTokenCount").Int())
|
||||||
cand, _ := asInt(usageMeta["candidatesTokenCount"])
|
cand := int(usage.Get("candidatesTokenCount").Int())
|
||||||
cached, _ := asInt(usageMeta["cachedContentTokenCount"])
|
cached := int(usage.Get("cachedContentTokenCount").Int())
|
||||||
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
// 注意:Gemini 的 promptTokenCount 包含 cachedContentTokenCount,
|
||||||
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
// 但 Claude 的 input_tokens 不包含 cache_read_input_tokens,需要减去
|
||||||
return &ClaudeUsage{
|
return &ClaudeUsage{
|
||||||
@@ -2646,39 +2625,35 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
|
|||||||
|
|
||||||
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
|
||||||
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
func ParseGeminiRateLimitResetTime(body []byte) *int64 {
|
||||||
// Try to parse metadata.quotaResetDelay like "12.345s"
|
// 第一阶段:gjson 结构化提取
|
||||||
var parsed map[string]any
|
errMsg := gjson.GetBytes(body, "error.message").String()
|
||||||
if err := json.Unmarshal(body, &parsed); err == nil {
|
if looksLikeGeminiDailyQuota(errMsg) {
|
||||||
if errObj, ok := parsed["error"].(map[string]any); ok {
|
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
||||||
if msg, ok := errObj["message"].(string); ok {
|
return ts
|
||||||
if looksLikeGeminiDailyQuota(msg) {
|
|
||||||
if ts := nextGeminiDailyResetUnix(); ts != nil {
|
|
||||||
return ts
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if details, ok := errObj["details"].([]any); ok {
|
|
||||||
for _, d := range details {
|
|
||||||
dm, ok := d.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if meta, ok := dm["metadata"].(map[string]any); ok {
|
|
||||||
if v, ok := meta["quotaResetDelay"].(string); ok {
|
|
||||||
if dur, err := time.ParseDuration(v); err == nil {
|
|
||||||
// Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
|
|
||||||
// which can affect scheduling decisions around thresholds (like 10s).
|
|
||||||
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
|
|
||||||
return &ts
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Match "Please retry in Xs"
|
// 遍历 error.details 查找 quotaResetDelay
|
||||||
|
var found *int64
|
||||||
|
gjson.GetBytes(body, "error.details").ForEach(func(_, detail gjson.Result) bool {
|
||||||
|
v := detail.Get("metadata.quotaResetDelay").String()
|
||||||
|
if v == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if dur, err := time.ParseDuration(v); err == nil {
|
||||||
|
// Use ceil to avoid undercounting fractional seconds (e.g. 10.1s should not become 10s),
|
||||||
|
// which can affect scheduling decisions around thresholds (like 10s).
|
||||||
|
ts := time.Now().Unix() + int64(math.Ceil(dur.Seconds()))
|
||||||
|
found = &ts
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if found != nil {
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第二阶段:regex 回退匹配 "Please retry in Xs"
|
||||||
matches := retryInRegex.FindStringSubmatch(string(body))
|
matches := retryInRegex.FindStringSubmatch(string(body))
|
||||||
if len(matches) == 2 {
|
if len(matches) == 2 {
|
||||||
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
|
if dur, err := time.ParseDuration(matches[1] + "s"); err == nil {
|
||||||
|
|||||||
@@ -2,8 +2,12 @@ package service
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
// TestConvertClaudeToolsToGeminiTools_CustomType 测试custom类型工具转换
|
||||||
@@ -203,3 +207,304 @@ func TestEnsureGeminiFunctionCallThoughtSignatures_InsertsWhenMissing(t *testing
|
|||||||
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
t.Fatalf("expected injected thoughtSignature %q, got: %s", geminiDummyThoughtSignature, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestUnwrapGeminiResponse 测试 unwrapGeminiResponse 的各种输入场景
|
||||||
|
// 关键区别:只有 response 为 JSON 对象/数组时才解包
|
||||||
|
func TestUnwrapGeminiResponse(t *testing.T) {
|
||||||
|
// 构造 >50KB 的大型 JSON 对象
|
||||||
|
largePadding := strings.Repeat("x", 50*1024)
|
||||||
|
largeInput := []byte(fmt.Sprintf(`{"response":{"id":"big","pad":"%s"}}`, largePadding))
|
||||||
|
largeExpected := fmt.Sprintf(`{"id":"big","pad":"%s"}`, largePadding)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expected string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常 response 包装(JSON 对象)",
|
||||||
|
input: []byte(`{"response":{"key":"val"}}`),
|
||||||
|
expected: `{"key":"val"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无包装直接返回",
|
||||||
|
input: []byte(`{"key":"val"}`),
|
||||||
|
expected: `{"key":"val"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 JSON",
|
||||||
|
input: []byte(`{}`),
|
||||||
|
expected: `{}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "null response 返回原始 body",
|
||||||
|
input: []byte(`{"response":null}`),
|
||||||
|
expected: `{"response":null}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非法 JSON 返回原始 body",
|
||||||
|
input: []byte(`not json`),
|
||||||
|
expected: `not json`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "response 为基础类型 string 返回原始 body",
|
||||||
|
input: []byte(`{"response":"hello"}`),
|
||||||
|
expected: `{"response":"hello"}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "嵌套 response 只解一层",
|
||||||
|
input: []byte(`{"response":{"response":{"inner":true}}}`),
|
||||||
|
expected: `{"response":{"inner":true}}`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "大型 JSON >50KB",
|
||||||
|
input: largeInput,
|
||||||
|
expected: largeExpected,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := unwrapGeminiResponse(tt.input)
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.expected, strings.TrimSpace(string(got)))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Task 8.1 — extractGeminiUsage 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestExtractGeminiUsage(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantNil bool
|
||||||
|
wantUsage *ClaudeUsage
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "完整 usageMetadata",
|
||||||
|
input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50,"cachedContentTokenCount":20}}`,
|
||||||
|
wantNil: false,
|
||||||
|
wantUsage: &ClaudeUsage{
|
||||||
|
InputTokens: 80,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheReadInputTokens: 20,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺失 cachedContentTokenCount",
|
||||||
|
input: `{"usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":50}}`,
|
||||||
|
wantNil: false,
|
||||||
|
wantUsage: &ClaudeUsage{
|
||||||
|
InputTokens: 100,
|
||||||
|
OutputTokens: 50,
|
||||||
|
CacheReadInputTokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 usageMetadata",
|
||||||
|
input: `{"candidates":[]}`,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// gjson 对 null 返回 Exists()=true,因此函数不会返回 nil,
|
||||||
|
// 而是返回全零的 ClaudeUsage。
|
||||||
|
name: "null usageMetadata — gjson Exists 为 true",
|
||||||
|
input: `{"usageMetadata":null}`,
|
||||||
|
wantNil: false,
|
||||||
|
wantUsage: &ClaudeUsage{
|
||||||
|
InputTokens: 0,
|
||||||
|
OutputTokens: 0,
|
||||||
|
CacheReadInputTokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "零值字段",
|
||||||
|
input: `{"usageMetadata":{"promptTokenCount":0,"candidatesTokenCount":0,"cachedContentTokenCount":0}}`,
|
||||||
|
wantNil: false,
|
||||||
|
wantUsage: &ClaudeUsage{
|
||||||
|
InputTokens: 0,
|
||||||
|
OutputTokens: 0,
|
||||||
|
CacheReadInputTokens: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := extractGeminiUsage([]byte(tt.input))
|
||||||
|
if tt.wantNil {
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("期望返回 nil,实际返回 %+v", got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("期望返回非 nil,实际返回 nil")
|
||||||
|
}
|
||||||
|
if got.InputTokens != tt.wantUsage.InputTokens {
|
||||||
|
t.Errorf("InputTokens: 期望 %d,实际 %d", tt.wantUsage.InputTokens, got.InputTokens)
|
||||||
|
}
|
||||||
|
if got.OutputTokens != tt.wantUsage.OutputTokens {
|
||||||
|
t.Errorf("OutputTokens: 期望 %d,实际 %d", tt.wantUsage.OutputTokens, got.OutputTokens)
|
||||||
|
}
|
||||||
|
if got.CacheReadInputTokens != tt.wantUsage.CacheReadInputTokens {
|
||||||
|
t.Errorf("CacheReadInputTokens: 期望 %d,实际 %d", tt.wantUsage.CacheReadInputTokens, got.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Task 8.2 — estimateGeminiCountTokens 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestEstimateGeminiCountTokens(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantGt0 bool // 期望结果 > 0
|
||||||
|
wantExact *int // 如果非 nil,期望精确匹配
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "含 systemInstruction 和 contents",
|
||||||
|
input: `{
|
||||||
|
"systemInstruction":{"parts":[{"text":"You are a helpful assistant."}]},
|
||||||
|
"contents":[{"parts":[{"text":"Hello, how are you?"}]}]
|
||||||
|
}`,
|
||||||
|
wantGt0: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "仅 contents,无 systemInstruction",
|
||||||
|
input: `{
|
||||||
|
"contents":[{"parts":[{"text":"Hello, how are you?"}]}]
|
||||||
|
}`,
|
||||||
|
wantGt0: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 parts",
|
||||||
|
input: `{"contents":[{"parts":[]}]}`,
|
||||||
|
wantGt0: false,
|
||||||
|
wantExact: intPtr(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非文本 parts(inlineData)",
|
||||||
|
input: `{"contents":[{"parts":[{"inlineData":{"mimeType":"image/png"}}]}]}`,
|
||||||
|
wantGt0: false,
|
||||||
|
wantExact: intPtr(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空白文本",
|
||||||
|
input: `{"contents":[{"parts":[{"text":" "}]}]}`,
|
||||||
|
wantGt0: false,
|
||||||
|
wantExact: intPtr(0),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := estimateGeminiCountTokens([]byte(tt.input))
|
||||||
|
if tt.wantExact != nil {
|
||||||
|
if got != *tt.wantExact {
|
||||||
|
t.Errorf("期望精确值 %d,实际 %d", *tt.wantExact, got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tt.wantGt0 && got <= 0 {
|
||||||
|
t.Errorf("期望返回 > 0,实际 %d", got)
|
||||||
|
}
|
||||||
|
if !tt.wantGt0 && got != 0 {
|
||||||
|
t.Errorf("期望返回 0,实际 %d", got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
// Task 8.3 — ParseGeminiRateLimitResetTime 测试
|
||||||
|
// ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func TestParseGeminiRateLimitResetTime(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
wantNil bool
|
||||||
|
approxDelta int64 // 预期的 (返回值 - now) 大约是多少秒
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常 quotaResetDelay",
|
||||||
|
input: `{"error":{"details":[{"metadata":{"quotaResetDelay":"12.345s"}}]}}`,
|
||||||
|
wantNil: false,
|
||||||
|
approxDelta: 13, // 向上取整 12.345 -> 13
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "daily quota",
|
||||||
|
input: `{"error":{"message":"quota per day exceeded"}}`,
|
||||||
|
wantNil: false,
|
||||||
|
approxDelta: -1, // 不检查精确 delta,仅检查非 nil
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 details 且无 regex 匹配",
|
||||||
|
input: `{"error":{"message":"rate limit"}}`,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "regex 回退匹配",
|
||||||
|
input: `Please retry in 30s`,
|
||||||
|
wantNil: false,
|
||||||
|
approxDelta: 30,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "完全无匹配",
|
||||||
|
input: `{"error":{"code":429}}`,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非法 JSON 但 regex 回退仍工作",
|
||||||
|
input: `not json but Please retry in 10s`,
|
||||||
|
wantNil: false,
|
||||||
|
approxDelta: 10,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
now := time.Now().Unix()
|
||||||
|
got := ParseGeminiRateLimitResetTime([]byte(tt.input))
|
||||||
|
|
||||||
|
if tt.wantNil {
|
||||||
|
if got != nil {
|
||||||
|
t.Fatalf("期望返回 nil,实际返回 %d", *got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("期望返回非 nil,实际返回 nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
// approxDelta == -1 表示只检查非 nil,不检查具体值(如 daily quota 场景)
|
||||||
|
if tt.approxDelta == -1 {
|
||||||
|
// 仅验证返回的时间戳在合理范围内(未来的某个时间)
|
||||||
|
if *got < now {
|
||||||
|
t.Errorf("期望返回的时间戳 >= now(%d),实际 %d", now, *got)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 +/-2 秒容差进行范围检查
|
||||||
|
delta := *got - now
|
||||||
|
if delta < tt.approxDelta-2 || delta > tt.approxDelta+2 {
|
||||||
|
t.Errorf("期望 delta 约为 %d 秒(+/-2),实际 delta = %d 秒(返回值=%d, now=%d)",
|
||||||
|
tt.approxDelta, delta, *got, now)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ func NewOpenAIGatewayService(
|
|||||||
// 1. Header: session_id
|
// 1. Header: session_id
|
||||||
// 2. Header: conversation_id
|
// 2. Header: conversation_id
|
||||||
// 3. Body: prompt_cache_key (opencode)
|
// 3. Body: prompt_cache_key (opencode)
|
||||||
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[string]any) string {
|
func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, body []byte) string {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -239,10 +239,8 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context, reqBody map[s
|
|||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
sessionID = strings.TrimSpace(c.GetHeader("conversation_id"))
|
||||||
}
|
}
|
||||||
if sessionID == "" && reqBody != nil {
|
if sessionID == "" && len(body) > 0 {
|
||||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
sessionID = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||||
sessionID = strings.TrimSpace(v)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if sessionID == "" {
|
if sessionID == "" {
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -129,17 +129,19 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
|||||||
|
|
||||||
svc := &OpenAIGatewayService{}
|
svc := &OpenAIGatewayService{}
|
||||||
|
|
||||||
|
bodyWithKey := []byte(`{"prompt_cache_key":"ses_aaa"}`)
|
||||||
|
|
||||||
// 1) session_id header wins
|
// 1) session_id header wins
|
||||||
c.Request.Header.Set("session_id", "sess-123")
|
c.Request.Header.Set("session_id", "sess-123")
|
||||||
c.Request.Header.Set("conversation_id", "conv-456")
|
c.Request.Header.Set("conversation_id", "conv-456")
|
||||||
h1 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
h1 := svc.GenerateSessionHash(c, bodyWithKey)
|
||||||
if h1 == "" {
|
if h1 == "" {
|
||||||
t.Fatalf("expected non-empty hash")
|
t.Fatalf("expected non-empty hash")
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2) conversation_id used when session_id absent
|
// 2) conversation_id used when session_id absent
|
||||||
c.Request.Header.Del("session_id")
|
c.Request.Header.Del("session_id")
|
||||||
h2 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
h2 := svc.GenerateSessionHash(c, bodyWithKey)
|
||||||
if h2 == "" {
|
if h2 == "" {
|
||||||
t.Fatalf("expected non-empty hash")
|
t.Fatalf("expected non-empty hash")
|
||||||
}
|
}
|
||||||
@@ -149,7 +151,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
|||||||
|
|
||||||
// 3) prompt_cache_key used when both headers absent
|
// 3) prompt_cache_key used when both headers absent
|
||||||
c.Request.Header.Del("conversation_id")
|
c.Request.Header.Del("conversation_id")
|
||||||
h3 := svc.GenerateSessionHash(c, map[string]any{"prompt_cache_key": "ses_aaa"})
|
h3 := svc.GenerateSessionHash(c, bodyWithKey)
|
||||||
if h3 == "" {
|
if h3 == "" {
|
||||||
t.Fatalf("expected non-empty hash")
|
t.Fatalf("expected non-empty hash")
|
||||||
}
|
}
|
||||||
@@ -158,7 +160,7 @@ func TestOpenAIGatewayService_GenerateSessionHash_Priority(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 4) empty when no signals
|
// 4) empty when no signals
|
||||||
h4 := svc.GenerateSessionHash(c, map[string]any{})
|
h4 := svc.GenerateSessionHash(c, []byte(`{}`))
|
||||||
if h4 != "" {
|
if h4 != "" {
|
||||||
t.Fatalf("expected empty hash when no signals")
|
t.Fatalf("expected empty hash when no signals")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import (
|
|||||||
|
|
||||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/crypto/sha3"
|
"golang.org/x/crypto/sha3"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -219,12 +220,8 @@ func (c *SoraDirectClient) UploadImage(ctx context.Context, account *Account, da
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
var payload map[string]any
|
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||||
if err := json.Unmarshal(respBody, &payload); err != nil {
|
if id == "" {
|
||||||
return "", fmt.Errorf("parse upload response: %w", err)
|
|
||||||
}
|
|
||||||
id, _ := payload["id"].(string)
|
|
||||||
if strings.TrimSpace(id) == "" {
|
|
||||||
return "", errors.New("upload response missing id")
|
return "", errors.New("upload response missing id")
|
||||||
}
|
}
|
||||||
return id, nil
|
return id, nil
|
||||||
@@ -274,12 +271,8 @@ func (c *SoraDirectClient) CreateImageTask(ctx context.Context, account *Account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
var resp map[string]any
|
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
if taskID == "" {
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
taskID, _ := resp["id"].(string)
|
|
||||||
if strings.TrimSpace(taskID) == "" {
|
|
||||||
return "", errors.New("image task response missing id")
|
return "", errors.New("image task response missing id")
|
||||||
}
|
}
|
||||||
return taskID, nil
|
return taskID, nil
|
||||||
@@ -347,12 +340,8 @@ func (c *SoraDirectClient) CreateVideoTask(ctx context.Context, account *Account
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
var resp map[string]any
|
taskID := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
if taskID == "" {
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
taskID, _ := resp["id"].(string)
|
|
||||||
if strings.TrimSpace(taskID) == "" {
|
|
||||||
return "", errors.New("video task response missing id")
|
return "", errors.New("video task response missing id")
|
||||||
}
|
}
|
||||||
return taskID, nil
|
return taskID, nil
|
||||||
@@ -393,41 +382,30 @@ func (c *SoraDirectClient) fetchRecentImageTask(ctx context.Context, account *Ac
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, err
|
return nil, false, err
|
||||||
}
|
}
|
||||||
var resp map[string]any
|
var found *SoraImageTaskStatus
|
||||||
if err := json.Unmarshal(respBody, &resp); err != nil {
|
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
|
||||||
return nil, false, err
|
if item.Get("id").String() != taskID {
|
||||||
}
|
return true // continue
|
||||||
taskResponses, _ := resp["task_responses"].([]any)
|
|
||||||
for _, item := range taskResponses {
|
|
||||||
taskResp, ok := item.(map[string]any)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
if id, _ := taskResp["id"].(string); id == taskID {
|
status := strings.TrimSpace(item.Get("status").String())
|
||||||
status := strings.TrimSpace(fmt.Sprintf("%v", taskResp["status"]))
|
progress := item.Get("progress_pct").Float()
|
||||||
progress := 0.0
|
var urls []string
|
||||||
if v, ok := taskResp["progress_pct"].(float64); ok {
|
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
|
||||||
progress = v
|
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
|
||||||
|
urls = append(urls, u)
|
||||||
}
|
}
|
||||||
urls := []string{}
|
return true
|
||||||
if generations, ok := taskResp["generations"].([]any); ok {
|
})
|
||||||
for _, genItem := range generations {
|
found = &SoraImageTaskStatus{
|
||||||
gen, ok := genItem.(map[string]any)
|
ID: taskID,
|
||||||
if !ok {
|
Status: status,
|
||||||
continue
|
ProgressPct: progress,
|
||||||
}
|
URLs: urls,
|
||||||
if urlStr, ok := gen["url"].(string); ok && strings.TrimSpace(urlStr) != "" {
|
|
||||||
urls = append(urls, urlStr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &SoraImageTaskStatus{
|
|
||||||
ID: taskID,
|
|
||||||
Status: status,
|
|
||||||
ProgressPct: progress,
|
|
||||||
URLs: urls,
|
|
||||||
}, true, nil
|
|
||||||
}
|
}
|
||||||
|
return false // break
|
||||||
|
})
|
||||||
|
if found != nil {
|
||||||
|
return found, true, nil
|
||||||
}
|
}
|
||||||
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
|
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false, nil
|
||||||
}
|
}
|
||||||
@@ -463,27 +441,28 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var pending any
|
// 搜索 pending 列表(JSON 数组)
|
||||||
if err := json.Unmarshal(respBody, &pending); err == nil {
|
pendingResult := gjson.ParseBytes(respBody)
|
||||||
if list, ok := pending.([]any); ok {
|
if pendingResult.IsArray() {
|
||||||
for _, item := range list {
|
var pendingFound *SoraVideoTaskStatus
|
||||||
task, ok := item.(map[string]any)
|
pendingResult.ForEach(func(_, task gjson.Result) bool {
|
||||||
if !ok {
|
if task.Get("id").String() != taskID {
|
||||||
continue
|
return true
|
||||||
}
|
|
||||||
if id, _ := task["id"].(string); id == taskID {
|
|
||||||
progress := 0
|
|
||||||
if v, ok := task["progress_pct"].(float64); ok {
|
|
||||||
progress = int(v * 100)
|
|
||||||
}
|
|
||||||
status := strings.TrimSpace(fmt.Sprintf("%v", task["status"]))
|
|
||||||
return &SoraVideoTaskStatus{
|
|
||||||
ID: taskID,
|
|
||||||
Status: status,
|
|
||||||
ProgressPct: progress,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
progress := 0
|
||||||
|
if v := task.Get("progress_pct"); v.Exists() {
|
||||||
|
progress = int(v.Float() * 100)
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(task.Get("status").String())
|
||||||
|
pendingFound = &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: status,
|
||||||
|
ProgressPct: progress,
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if pendingFound != nil {
|
||||||
|
return pendingFound, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -491,44 +470,42 @@ func (c *SoraDirectClient) GetVideoTask(ctx context.Context, account *Account, t
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
var draftsResp map[string]any
|
var draftFound *SoraVideoTaskStatus
|
||||||
if err := json.Unmarshal(respBody, &draftsResp); err != nil {
|
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
|
||||||
return nil, err
|
if draft.Get("task_id").String() != taskID {
|
||||||
}
|
return true
|
||||||
items, _ := draftsResp["items"].([]any)
|
}
|
||||||
for _, item := range items {
|
kind := strings.TrimSpace(draft.Get("kind").String())
|
||||||
draft, ok := item.(map[string]any)
|
reason := strings.TrimSpace(draft.Get("reason_str").String())
|
||||||
if !ok {
|
if reason == "" {
|
||||||
continue
|
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
|
||||||
|
}
|
||||||
|
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
|
||||||
|
if urlStr == "" {
|
||||||
|
urlStr = strings.TrimSpace(draft.Get("url").String())
|
||||||
}
|
}
|
||||||
if id, _ := draft["task_id"].(string); id == taskID {
|
|
||||||
kind := strings.TrimSpace(fmt.Sprintf("%v", draft["kind"]))
|
|
||||||
reason := strings.TrimSpace(fmt.Sprintf("%v", draft["reason_str"]))
|
|
||||||
if reason == "" {
|
|
||||||
reason = strings.TrimSpace(fmt.Sprintf("%v", draft["markdown_reason_str"]))
|
|
||||||
}
|
|
||||||
urlStr := strings.TrimSpace(fmt.Sprintf("%v", draft["downloadable_url"]))
|
|
||||||
if urlStr == "" {
|
|
||||||
urlStr = strings.TrimSpace(fmt.Sprintf("%v", draft["url"]))
|
|
||||||
}
|
|
||||||
|
|
||||||
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
||||||
msg := reason
|
msg := reason
|
||||||
if msg == "" {
|
if msg == "" {
|
||||||
msg = "Content violates guardrails"
|
msg = "Content violates guardrails"
|
||||||
}
|
|
||||||
return &SoraVideoTaskStatus{
|
|
||||||
ID: taskID,
|
|
||||||
Status: "failed",
|
|
||||||
ErrorMsg: msg,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
return &SoraVideoTaskStatus{
|
draftFound = &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: "failed",
|
||||||
|
ErrorMsg: msg,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
draftFound = &SoraVideoTaskStatus{
|
||||||
ID: taskID,
|
ID: taskID,
|
||||||
Status: "completed",
|
Status: "completed",
|
||||||
URLs: []string{urlStr},
|
URLs: []string{urlStr},
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if draftFound != nil {
|
||||||
|
return draftFound, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil
|
return &SoraVideoTaskStatus{ID: taskID, Status: "processing"}, nil
|
||||||
|
|||||||
515
backend/internal/service/sora_client_gjson_test.go
Normal file
515
backend/internal/service/sora_client_gjson_test.go
Normal file
@@ -0,0 +1,515 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ---------- 辅助解析函数(复制生产代码中的 gjson 解析逻辑,用于单元测试) ----------
|
||||||
|
|
||||||
|
// testParseUploadOrCreateTaskID 模拟 UploadImage / CreateImageTask / CreateVideoTask 中
|
||||||
|
// 用 gjson.GetBytes(respBody, "id") 提取 id 的逻辑。
|
||||||
|
func testParseUploadOrCreateTaskID(respBody []byte) (string, error) {
|
||||||
|
id := strings.TrimSpace(gjson.GetBytes(respBody, "id").String())
|
||||||
|
if id == "" {
|
||||||
|
return "", assert.AnError // 占位错误,表示 "missing id"
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// testParseFetchRecentImageTask 模拟 fetchRecentImageTask 中的 gjson.ForEach 解析逻辑。
|
||||||
|
func testParseFetchRecentImageTask(respBody []byte, taskID string) (*SoraImageTaskStatus, bool) {
|
||||||
|
var found *SoraImageTaskStatus
|
||||||
|
gjson.GetBytes(respBody, "task_responses").ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("id").String() != taskID {
|
||||||
|
return true // continue
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(item.Get("status").String())
|
||||||
|
progress := item.Get("progress_pct").Float()
|
||||||
|
var urls []string
|
||||||
|
item.Get("generations").ForEach(func(_, gen gjson.Result) bool {
|
||||||
|
if u := strings.TrimSpace(gen.Get("url").String()); u != "" {
|
||||||
|
urls = append(urls, u)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
found = &SoraImageTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: status,
|
||||||
|
ProgressPct: progress,
|
||||||
|
URLs: urls,
|
||||||
|
}
|
||||||
|
return false // break
|
||||||
|
})
|
||||||
|
if found != nil {
|
||||||
|
return found, true
|
||||||
|
}
|
||||||
|
return &SoraImageTaskStatus{ID: taskID, Status: "processing"}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// testParseGetVideoTaskPending 模拟 GetVideoTask 中解析 pending 列表的逻辑。
|
||||||
|
func testParseGetVideoTaskPending(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||||||
|
pendingResult := gjson.ParseBytes(respBody)
|
||||||
|
if !pendingResult.IsArray() {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
var pendingFound *SoraVideoTaskStatus
|
||||||
|
pendingResult.ForEach(func(_, task gjson.Result) bool {
|
||||||
|
if task.Get("id").String() != taskID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
progress := 0
|
||||||
|
if v := task.Get("progress_pct"); v.Exists() {
|
||||||
|
progress = int(v.Float() * 100)
|
||||||
|
}
|
||||||
|
status := strings.TrimSpace(task.Get("status").String())
|
||||||
|
pendingFound = &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: status,
|
||||||
|
ProgressPct: progress,
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if pendingFound != nil {
|
||||||
|
return pendingFound, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// testParseGetVideoTaskDrafts 模拟 GetVideoTask 中解析 drafts 列表的逻辑。
|
||||||
|
func testParseGetVideoTaskDrafts(respBody []byte, taskID string) (*SoraVideoTaskStatus, bool) {
|
||||||
|
var draftFound *SoraVideoTaskStatus
|
||||||
|
gjson.GetBytes(respBody, "items").ForEach(func(_, draft gjson.Result) bool {
|
||||||
|
if draft.Get("task_id").String() != taskID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
kind := strings.TrimSpace(draft.Get("kind").String())
|
||||||
|
reason := strings.TrimSpace(draft.Get("reason_str").String())
|
||||||
|
if reason == "" {
|
||||||
|
reason = strings.TrimSpace(draft.Get("markdown_reason_str").String())
|
||||||
|
}
|
||||||
|
urlStr := strings.TrimSpace(draft.Get("downloadable_url").String())
|
||||||
|
if urlStr == "" {
|
||||||
|
urlStr = strings.TrimSpace(draft.Get("url").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if kind == "sora_content_violation" || reason != "" || urlStr == "" {
|
||||||
|
msg := reason
|
||||||
|
if msg == "" {
|
||||||
|
msg = "Content violates guardrails"
|
||||||
|
}
|
||||||
|
draftFound = &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: "failed",
|
||||||
|
ErrorMsg: msg,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
draftFound = &SoraVideoTaskStatus{
|
||||||
|
ID: taskID,
|
||||||
|
Status: "completed",
|
||||||
|
URLs: []string{urlStr},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
if draftFound != nil {
|
||||||
|
return draftFound, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===================== Test 1: TestSoraParseUploadResponse =====================
|
||||||
|
|
||||||
|
func TestSoraParseUploadResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantID string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常 id",
|
||||||
|
body: `{"id":"file-abc123","status":"uploaded"}`,
|
||||||
|
wantID: "file-abc123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 id",
|
||||||
|
body: `{"id":"","status":"uploaded"}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 id 字段",
|
||||||
|
body: `{"status":"uploaded"}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "id 全为空白",
|
||||||
|
body: `{"id":" ","status":"uploaded"}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "id 前后有空白",
|
||||||
|
body: `{"id":" file-trimmed ","status":"uploaded"}`,
|
||||||
|
wantID: "file-trimmed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 JSON 对象",
|
||||||
|
body: `{}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err, "应返回错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.wantID, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===================== Test 2: TestSoraParseCreateTaskResponse =====================
|
||||||
|
|
||||||
|
func TestSoraParseCreateTaskResponse(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
wantID string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常任务 id",
|
||||||
|
body: `{"id":"task-123"}`,
|
||||||
|
wantID: "task-123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺失 id",
|
||||||
|
body: `{"status":"created"}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 id",
|
||||||
|
body: `{"id":" "}`,
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "id 为数字(gjson 转字符串)",
|
||||||
|
body: `{"id":123}`,
|
||||||
|
wantID: "123",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "id 含特殊字符",
|
||||||
|
body: `{"id":"task-abc-def-456-ghi"}`,
|
||||||
|
wantID: "task-abc-def-456-ghi",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "额外字段不影响解析",
|
||||||
|
body: `{"id":"task-999","type":"image_gen","extra":"data"}`,
|
||||||
|
wantID: "task-999",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
id, err := testParseUploadOrCreateTaskID([]byte(tt.body))
|
||||||
|
if tt.wantErr {
|
||||||
|
require.Error(t, err, "应返回错误")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, tt.wantID, id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===================== Test 3: TestSoraParseFetchRecentImageTask =====================
|
||||||
|
|
||||||
|
func TestSoraParseFetchRecentImageTask(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
taskID string
|
||||||
|
wantFound bool
|
||||||
|
wantStatus string
|
||||||
|
wantProgress float64
|
||||||
|
wantURLs []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "匹配已完成任务",
|
||||||
|
body: `{"task_responses":[{"id":"task-1","status":"completed","progress_pct":1.0,"generations":[{"url":"https://example.com/img.png"}]}]}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "completed",
|
||||||
|
wantProgress: 1.0,
|
||||||
|
wantURLs: []string{"https://example.com/img.png"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "匹配处理中任务",
|
||||||
|
body: `{"task_responses":[{"id":"task-2","status":"processing","progress_pct":0.5,"generations":[]}]}`,
|
||||||
|
taskID: "task-2",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "processing",
|
||||||
|
wantProgress: 0.5,
|
||||||
|
wantURLs: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无匹配任务",
|
||||||
|
body: `{"task_responses":[{"id":"other","status":"completed"}]}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
wantStatus: "processing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 task_responses",
|
||||||
|
body: `{"task_responses":[]}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
wantStatus: "processing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺少 task_responses 字段",
|
||||||
|
body: `{"other":"data"}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
wantStatus: "processing",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "多个任务中精准匹配",
|
||||||
|
body: `{"task_responses":[{"id":"task-a","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"}]},{"id":"task-b","status":"processing","progress_pct":0.3,"generations":[]},{"id":"task-c","status":"failed","progress_pct":0}]}`,
|
||||||
|
taskID: "task-b",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "processing",
|
||||||
|
wantProgress: 0.3,
|
||||||
|
wantURLs: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "多个 generations",
|
||||||
|
body: `{"task_responses":[{"id":"task-m","status":"completed","progress_pct":1.0,"generations":[{"url":"https://a.com/1.png"},{"url":"https://a.com/2.png"},{"url":""}]}]}`,
|
||||||
|
taskID: "task-m",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "completed",
|
||||||
|
wantProgress: 1.0,
|
||||||
|
wantURLs: []string{"https://a.com/1.png", "https://a.com/2.png"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
status, found := testParseFetchRecentImageTask([]byte(tt.body), tt.taskID)
|
||||||
|
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||||
|
require.NotNil(t, status)
|
||||||
|
require.Equal(t, tt.taskID, status.ID)
|
||||||
|
require.Equal(t, tt.wantStatus, status.Status)
|
||||||
|
if tt.wantFound {
|
||||||
|
require.InDelta(t, tt.wantProgress, status.ProgressPct, 0.001, "进度不匹配")
|
||||||
|
require.Equal(t, tt.wantURLs, status.URLs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===================== Test 4: TestSoraParseGetVideoTaskPending =====================
|
||||||
|
|
||||||
|
func TestSoraParseGetVideoTaskPending(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
taskID string
|
||||||
|
wantFound bool
|
||||||
|
wantStatus string
|
||||||
|
wantProgress int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "匹配 pending 任务",
|
||||||
|
body: `[{"id":"task-1","status":"processing","progress_pct":0.5}]`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "processing",
|
||||||
|
wantProgress: 50,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "进度为 0",
|
||||||
|
body: `[{"id":"task-2","status":"queued","progress_pct":0}]`,
|
||||||
|
taskID: "task-2",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "queued",
|
||||||
|
wantProgress: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "进度为 1(100%)",
|
||||||
|
body: `[{"id":"task-3","status":"completing","progress_pct":1.0}]`,
|
||||||
|
taskID: "task-3",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "completing",
|
||||||
|
wantProgress: 100,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空数组",
|
||||||
|
body: `[]`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无匹配 id",
|
||||||
|
body: `[{"id":"task-other","status":"processing","progress_pct":0.3}]`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "多个任务精准匹配",
|
||||||
|
body: `[{"id":"task-a","status":"processing","progress_pct":0.2},{"id":"task-b","status":"queued","progress_pct":0},{"id":"task-c","status":"processing","progress_pct":0.8}]`,
|
||||||
|
taskID: "task-c",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "processing",
|
||||||
|
wantProgress: 80,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "非数组 JSON",
|
||||||
|
body: `{"id":"task-1","status":"processing"}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无 progress_pct 字段",
|
||||||
|
body: `[{"id":"task-4","status":"pending"}]`,
|
||||||
|
taskID: "task-4",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "pending",
|
||||||
|
wantProgress: 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
status, found := testParseGetVideoTaskPending([]byte(tt.body), tt.taskID)
|
||||||
|
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||||
|
if tt.wantFound {
|
||||||
|
require.NotNil(t, status)
|
||||||
|
require.Equal(t, tt.taskID, status.ID)
|
||||||
|
require.Equal(t, tt.wantStatus, status.Status)
|
||||||
|
require.Equal(t, tt.wantProgress, status.ProgressPct)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ===================== Test 5: TestSoraParseGetVideoTaskDrafts =====================
|
||||||
|
|
||||||
|
func TestSoraParseGetVideoTaskDrafts(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
taskID string
|
||||||
|
wantFound bool
|
||||||
|
wantStatus string
|
||||||
|
wantURLs []string
|
||||||
|
wantErr string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "正常完成的视频",
|
||||||
|
body: `{"items":[{"task_id":"task-1","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "completed",
|
||||||
|
wantURLs: []string{"https://example.com/video.mp4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "使用 url 字段回退",
|
||||||
|
body: `{"items":[{"task_id":"task-2","kind":"video","url":"https://example.com/fallback.mp4"}]}`,
|
||||||
|
taskID: "task-2",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "completed",
|
||||||
|
wantURLs: []string{"https://example.com/fallback.mp4"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "内容违规",
|
||||||
|
body: `{"items":[{"task_id":"task-3","kind":"sora_content_violation","reason_str":"Content policy violation"}]}`,
|
||||||
|
taskID: "task-3",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "failed",
|
||||||
|
wantErr: "Content policy violation",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "内容违规 - markdown_reason_str 回退",
|
||||||
|
body: `{"items":[{"task_id":"task-4","kind":"sora_content_violation","markdown_reason_str":"Markdown reason"}]}`,
|
||||||
|
taskID: "task-4",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "failed",
|
||||||
|
wantErr: "Markdown reason",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "内容违规 - 无 reason 使用默认消息",
|
||||||
|
body: `{"items":[{"task_id":"task-5","kind":"sora_content_violation"}]}`,
|
||||||
|
taskID: "task-5",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "failed",
|
||||||
|
wantErr: "Content violates guardrails",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "有 reason_str 但非 violation kind(仍判定失败)",
|
||||||
|
body: `{"items":[{"task_id":"task-6","kind":"video","reason_str":"Some error occurred"}]}`,
|
||||||
|
taskID: "task-6",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "failed",
|
||||||
|
wantErr: "Some error occurred",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 URL 判定为失败",
|
||||||
|
body: `{"items":[{"task_id":"task-7","kind":"video","downloadable_url":"","url":""}]}`,
|
||||||
|
taskID: "task-7",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "failed",
|
||||||
|
wantErr: "Content violates guardrails",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "无匹配 task_id",
|
||||||
|
body: `{"items":[{"task_id":"task-other","kind":"video","downloadable_url":"https://example.com/video.mp4"}]}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空 items",
|
||||||
|
body: `{"items":[]}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺少 items 字段",
|
||||||
|
body: `{"other":"data"}`,
|
||||||
|
taskID: "task-1",
|
||||||
|
wantFound: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "多个 items 精准匹配",
|
||||||
|
body: `{"items":[{"task_id":"task-a","kind":"video","downloadable_url":"https://a.com/a.mp4"},{"task_id":"task-b","kind":"sora_content_violation","reason_str":"Bad content"},{"task_id":"task-c","kind":"video","downloadable_url":"https://c.com/c.mp4"}]}`,
|
||||||
|
taskID: "task-b",
|
||||||
|
wantFound: true,
|
||||||
|
wantStatus: "failed",
|
||||||
|
wantErr: "Bad content",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
status, found := testParseGetVideoTaskDrafts([]byte(tt.body), tt.taskID)
|
||||||
|
require.Equal(t, tt.wantFound, found, "found 不匹配")
|
||||||
|
if !tt.wantFound {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, status)
|
||||||
|
require.Equal(t, tt.taskID, status.ID)
|
||||||
|
require.Equal(t, tt.wantStatus, status.Status)
|
||||||
|
if tt.wantErr != "" {
|
||||||
|
require.Equal(t, tt.wantErr, status.ErrorMsg)
|
||||||
|
}
|
||||||
|
if tt.wantURLs != nil {
|
||||||
|
require.Equal(t, tt.wantURLs, status.URLs)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user