diff --git a/backend/internal/handler/sora_gateway_handler.go b/backend/internal/handler/sora_gateway_handler.go index 9c9f53b1..3a5ddcb0 100644 --- a/backend/internal/handler/sora_gateway_handler.go +++ b/backend/internal/handler/sora_gateway_handler.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/hex" + "encoding/json" "errors" "fmt" "io" @@ -442,7 +443,18 @@ func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status in if streamStarted { flusher, ok := c.Writer.(http.Flusher) if ok { - errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + errorData := map[string]any{ + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } diff --git a/backend/internal/handler/sora_gateway_handler_test.go b/backend/internal/handler/sora_gateway_handler_test.go index 39e2eed6..d80b959c 100644 --- a/backend/internal/handler/sora_gateway_handler_test.go +++ b/backend/internal/handler/sora_gateway_handler_test.go @@ -498,3 +498,84 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) { require.NotEmpty(t, hash3) require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash } + +func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) { + tests := []struct { + name string + errType string + message string + }{ + { + name: "包含双引号", + errType: "upstream_error", + message: `upstream returned "invalid" payload`, + }, + { + name: "包含换行和制表符", + errType: "rate_limit_error", + message: "line1\nline2\ttab", + }, + { + name: "包含反斜杠", + errType: "upstream_error", + message: `path C:\Users\test\file.txt not found`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &SoraGatewayHandler{} + h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true) + + body := w.Body.String() + require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头") + require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾") + + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行") + require.Equal(t, "event: error", lines[0]) + require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀") + + jsonStr := strings.TrimPrefix(lines[1], "data: ") + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON") + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok, "JSON 中应包含 error 对象") + require.Equal(t, tt.errType, errorObj["type"]) + require.Equal(t, tt.message, errorObj["message"]) + }) + } +} + +func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) { + gin.SetMode(gin.TestMode) + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Request = httptest.NewRequest(http.MethodGet, "/", nil) + + h := &SoraGatewayHandler{} + resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`) + h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true) + + body := w.Body.String() + require.True(t, strings.HasPrefix(body, "event: error\n")) + require.True(t, strings.HasSuffix(body, "\n\n")) + + lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n") + require.Len(t, lines, 2) + jsonStr := strings.TrimPrefix(lines[1], "data: ") + + var parsed map[string]any + require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed)) + + errorObj, ok := parsed["error"].(map[string]any) + require.True(t, ok) + require.Equal(t, "upstream_error", errorObj["type"]) + require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"]) +} diff --git a/backend/internal/service/sora_client.go b/backend/internal/service/sora_client.go index 38be7a04..38c1b3cc 100644 --- a/backend/internal/service/sora_client.go +++ b/backend/internal/service/sora_client.go @@ -779,22 +779,17 @@ func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Ac } tried[clientID] = struct{}{} - payload := map[string]any{ - "client_id": clientID, - "grant_type": "refresh_token", - "refresh_token": refreshToken, - "redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", - } - bodyBytes, err := json.Marshal(payload) - if err != nil { - return "", "", "", err - } + formData := url.Values{} + formData.Set("client_id", clientID) + formData.Set("grant_type", "refresh_token") + formData.Set("refresh_token", refreshToken) + formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback") headers := http.Header{} headers.Set("Accept", "application/json") - headers.Set("Content-Type", "application/json") + headers.Set("Content-Type", "application/x-www-form-urlencoded") headers.Set("User-Agent", c.defaultUserAgent()) - respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false) + respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false) if err != nil { lastErr = err if c.debugEnabled() { diff --git a/backend/internal/service/sora_client_test.go b/backend/internal/service/sora_client_test.go index 3e88c9f9..e566f06b 100644 --- a/backend/internal/service/sora_client_test.go +++ b/backend/internal/service/sora_client_test.go @@ -281,6 +281,12 @@ func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { require.Equal(t, http.MethodPost, r.Method) require.Equal(t, "/oauth/token", r.URL.Path) + require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type")) + require.NoError(t, r.ParseForm()) + require.Equal(t, "refresh_token", r.FormValue("grant_type")) + require.Equal(t, "refresh-token-old", r.FormValue("refresh_token")) + require.NotEmpty(t, r.FormValue("client_id")) + require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri")) w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(map[string]any{ "access_token": "refresh-access-token",