fix(sora): 修复令牌刷新请求格式与流式错误转义
- 将 refresh_token 恢复请求改为表单编码并匹配 OAuth 约定 - 流式错误改为 JSON 序列化,避免消息含引号或换行导致 SSE 非法 - 补充 Sora token 恢复与 failover 流式错误透传回归测试 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"crypto/sha256"
|
"crypto/sha256"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -442,7 +443,18 @@ func (h *SoraGatewayHandler) handleStreamingAwareError(c *gin.Context, status in
|
|||||||
if streamStarted {
|
if streamStarted {
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if ok {
|
if ok {
|
||||||
errorEvent := fmt.Sprintf(`event: error`+"\n"+`data: {"error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorData := map[string]any{
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(errorData)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -498,3 +498,84 @@ func TestGenerateOpenAISessionHash_WithBody(t *testing.T) {
|
|||||||
require.NotEmpty(t, hash3)
|
require.NotEmpty(t, hash3)
|
||||||
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
require.NotEqual(t, hash, hash3) // 不同来源应产生不同 hash
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleStreamingAwareError_JSONEscaping(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
errType string
|
||||||
|
message string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "包含双引号",
|
||||||
|
errType: "upstream_error",
|
||||||
|
message: `upstream returned "invalid" payload`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含换行和制表符",
|
||||||
|
errType: "rate_limit_error",
|
||||||
|
message: "line1\nline2\ttab",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "包含反斜杠",
|
||||||
|
errType: "upstream_error",
|
||||||
|
message: `path C:\Users\test\file.txt not found`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
h.handleStreamingAwareError(c, http.StatusBadGateway, tt.errType, tt.message, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
require.True(t, strings.HasPrefix(body, "event: error\n"), "应以 SSE error 事件开头")
|
||||||
|
require.True(t, strings.HasSuffix(body, "\n\n"), "应以 SSE 结束分隔符结尾")
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2, "SSE 错误事件应包含 event 行和 data 行")
|
||||||
|
require.Equal(t, "event: error", lines[0])
|
||||||
|
require.True(t, strings.HasPrefix(lines[1], "data: "), "第二行应为 data 前缀")
|
||||||
|
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed), "data 行必须是合法 JSON")
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok, "JSON 中应包含 error 对象")
|
||||||
|
require.Equal(t, tt.errType, errorObj["type"])
|
||||||
|
require.Equal(t, tt.message, errorObj["message"])
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSoraHandleFailoverExhausted_StreamPassesUpstreamMessage(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
|
h := &SoraGatewayHandler{}
|
||||||
|
resp := []byte(`{"error":{"message":"invalid \"prompt\"\nline2","code":"bad_request"}}`)
|
||||||
|
h.handleFailoverExhausted(c, http.StatusBadGateway, resp, true)
|
||||||
|
|
||||||
|
body := w.Body.String()
|
||||||
|
require.True(t, strings.HasPrefix(body, "event: error\n"))
|
||||||
|
require.True(t, strings.HasSuffix(body, "\n\n"))
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSuffix(body, "\n\n"), "\n")
|
||||||
|
require.Len(t, lines, 2)
|
||||||
|
jsonStr := strings.TrimPrefix(lines[1], "data: ")
|
||||||
|
|
||||||
|
var parsed map[string]any
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(jsonStr), &parsed))
|
||||||
|
|
||||||
|
errorObj, ok := parsed["error"].(map[string]any)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, "upstream_error", errorObj["type"])
|
||||||
|
require.Equal(t, "invalid \"prompt\"\nline2", errorObj["message"])
|
||||||
|
}
|
||||||
|
|||||||
@@ -779,22 +779,17 @@ func (c *SoraDirectClient) exchangeRefreshToken(ctx context.Context, account *Ac
|
|||||||
}
|
}
|
||||||
tried[clientID] = struct{}{}
|
tried[clientID] = struct{}{}
|
||||||
|
|
||||||
payload := map[string]any{
|
formData := url.Values{}
|
||||||
"client_id": clientID,
|
formData.Set("client_id", clientID)
|
||||||
"grant_type": "refresh_token",
|
formData.Set("grant_type", "refresh_token")
|
||||||
"refresh_token": refreshToken,
|
formData.Set("refresh_token", refreshToken)
|
||||||
"redirect_uri": "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback",
|
formData.Set("redirect_uri", "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback")
|
||||||
}
|
|
||||||
bodyBytes, err := json.Marshal(payload)
|
|
||||||
if err != nil {
|
|
||||||
return "", "", "", err
|
|
||||||
}
|
|
||||||
headers := http.Header{}
|
headers := http.Header{}
|
||||||
headers.Set("Accept", "application/json")
|
headers.Set("Accept", "application/json")
|
||||||
headers.Set("Content-Type", "application/json")
|
headers.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
headers.Set("User-Agent", c.defaultUserAgent())
|
headers.Set("User-Agent", c.defaultUserAgent())
|
||||||
|
|
||||||
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, bytes.NewReader(bodyBytes), false)
|
respBody, _, err := c.doRequest(ctx, account, http.MethodPost, soraOAuthTokenURL, headers, strings.NewReader(formData.Encode()), false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
lastErr = err
|
lastErr = err
|
||||||
if c.debugEnabled() {
|
if c.debugEnabled() {
|
||||||
|
|||||||
@@ -281,6 +281,12 @@ func TestSoraDirectClient_GetAccessToken_FromRefreshToken(t *testing.T) {
|
|||||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
require.Equal(t, http.MethodPost, r.Method)
|
require.Equal(t, http.MethodPost, r.Method)
|
||||||
require.Equal(t, "/oauth/token", r.URL.Path)
|
require.Equal(t, "/oauth/token", r.URL.Path)
|
||||||
|
require.Equal(t, "application/x-www-form-urlencoded", r.Header.Get("Content-Type"))
|
||||||
|
require.NoError(t, r.ParseForm())
|
||||||
|
require.Equal(t, "refresh_token", r.FormValue("grant_type"))
|
||||||
|
require.Equal(t, "refresh-token-old", r.FormValue("refresh_token"))
|
||||||
|
require.NotEmpty(t, r.FormValue("client_id"))
|
||||||
|
require.Equal(t, "com.openai.chat://auth0.openai.com/ios/com.openai.chat/callback", r.FormValue("redirect_uri"))
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
_ = json.NewEncoder(w).Encode(map[string]any{
|
_ = json.NewEncoder(w).Encode(map[string]any{
|
||||||
"access_token": "refresh-access-token",
|
"access_token": "refresh-access-token",
|
||||||
|
|||||||
Reference in New Issue
Block a user