fix(sora): 修复令牌刷新请求格式与流式错误转义

- 将 refresh_token 恢复请求改为表单编码并匹配 OAuth 约定
- 流式错误改为 JSON 序列化,避免消息含引号或换行导致 SSE 非法
- 补充 Sora token 恢复与 failover 流式错误透传回归测试

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-02-19 08:23:00 +08:00
parent 900cce20a1
commit 5d2219d299
4 changed files with 107 additions and 13 deletions

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/sha256" "crypto/sha256"
"encoding/hex" "encoding/hex"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -442,7 +443,18 @@ func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status in
if streamStarted { if streamStarted {
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if ok { if ok {
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) errorData := map[string]any{
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err) _ = c.Error(err)
} }

View File

@@ -498,3 +498,84 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
require.NotEmpty(t, hash3) require.NotEmpty(t, hash3)
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
} }
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
tests := []struct {
name string
errType string
message string
}{
{
name: "包含双引号",
errType: "upstream_error",
message: `upstream returned "invalid" payload`,
},
{
name: "包含换行和制表符",
errType: "rate_limit_error",
message: "line1\nline2\ttab",
},
{
name: "包含反斜杠",
errType: "upstream_error",
message: `path C:\Users\test\file.txt not found`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
require.Equal(t, "event: error", lines[0])
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok, "JSON 中应包含 error 对象")
require.Equal(t, tt.errType, errorObj["type"])
require.Equal(t, tt.message, errorObj["message"])
})
}
}
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
gin.SetMode(gin.TestMode)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
h := &SoraGatewayHandler{}
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true)
body := w.Body.String()
require.True(t, strings.HasPrefix(body, "event: error\n"))
require.True(t, strings.HasSuffix(body, "\n\n"))
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
require.Len(t, lines, 2)
jsonStr := strings.TrimPrefix(lines[1], "data: ")
var parsed map[string]any
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
errorObj, ok := parsed["error"].(map[string]any)
require.True(t, ok)
require.Equal(t, "upstream_error", errorObj["type"])
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
}

View File

@@ -779,22 +779,17 @@ func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Ac
} }
tried[clientID] = struct{}{} tried[clientID] = struct{}{}
payload := map[string]any{ formData := url.Values{}
"client_id": clientID, formData.Set("client_id", clientID)
"grant_type": "refresh_token", formData.Set("grant_type", "refresh_token")
"refresh_token": refreshToken, formData.Set("refresh_token", refreshToken)
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback")
}
bodyBytes, err := json.Marshal(payload)
if err != nil {
return "", "", "", err
}
headers := http.Header{} headers := http.Header{}
headers.Set("Accept", "application/json") headers.Set("Accept", "application/json")
headers.Set("Content-Type", "application/json") headers.Set("Content-Type", "application/x-www-form-urlencoded")
headers.Set("User-Agent", c.defaultUserAgent()) headers.Set("User-Agent", c.defaultUserAgent())
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false) respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false)
if err != nil { if err != nil {
lastErr = err lastErr = err
if c.debugEnabled() { if c.debugEnabled() {

View File

@@ -281,6 +281,12 @@ func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.Equal(t, http.MethodPost, r.Method) require.Equal(t, http.MethodPost, r.Method)
require.Equal(t, "/oauth/token", r.URL.Path) require.Equal(t, "/oauth/token", r.URL.Path)
require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
require.NoError(t, r.ParseForm())
require.Equal(t, "refresh_token", r.FormValue("grant_type"))
require.Equal(t, "refresh-token-old", r.FormValue("refresh_token"))
require.NotEmpty(t, r.FormValue("client_id"))
require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri"))
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(map[string]any{ _ = json.NewEncoder(w).Encode(map[string]any{
"access_token": "refresh-access-token", "access_token": "refresh-access-token",