Merge pull request #1391 from Zqysl/qingyu/fix-openai-passthrough-failover-429-529
fix(openai): fail over passthrough 429 and 529
This commit is contained in:
@@ -2544,7 +2544,11 @@ func (s *OpenAIGatewayService) forwardOpenAIPassthrough(
|
|||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
// 透传模式不做 failover(避免改变原始上游语义),按上游原样返回错误响应。
|
// 透传模式默认保持原样代理;但 429/529 属于网关必须兜底的
|
||||||
|
// 上游容量类错误,应先触发多账号 failover 以维持基础 SLA。
|
||||||
|
if shouldFailoverOpenAIPassthroughResponse(resp.StatusCode) {
|
||||||
|
return nil, s.handleFailoverErrorResponsePassthrough(ctx, resp, c, account, body)
|
||||||
|
}
|
||||||
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body)
|
return nil, s.handleErrorResponsePassthrough(ctx, resp, c, account, body)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2727,6 +2731,58 @@ func (s *OpenAIGatewayService) buildUpstreamRequestOpenAIPassthrough(
|
|||||||
return req, nil
|
return req, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// shouldFailoverOpenAIPassthroughResponse reports whether an upstream status
// code received in passthrough mode should be turned into an account failover
// rather than relayed to the client as-is.
//
// Only 429 (Too Many Requests) and 529 qualify; every other status code,
// including other 4xx/5xx values, returns false.
func shouldFailoverOpenAIPassthroughResponse(statusCode int) bool {
	// 529 has no constant in net/http; name it here so the switch below does
	// not carry a bare magic number. It is handled as an upstream
	// "overloaded" signal alongside 429.
	const statusUpstreamOverloaded = 529

	switch statusCode {
	case http.StatusTooManyRequests, statusUpstreamOverloaded:
		return true
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
func (s *OpenAIGatewayService) handleFailoverErrorResponsePassthrough(
|
||||||
|
ctx context.Context,
|
||||||
|
resp *http.Response,
|
||||||
|
c *gin.Context,
|
||||||
|
account *Account,
|
||||||
|
requestBody []byte,
|
||||||
|
) error {
|
||||||
|
body, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
|
upstreamMsg := strings.TrimSpace(extractUpstreamErrorMessage(body))
|
||||||
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
|
upstreamDetail := ""
|
||||||
|
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
|
||||||
|
maxBytes := s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes
|
||||||
|
if maxBytes <= 0 {
|
||||||
|
maxBytes = 2048
|
||||||
|
}
|
||||||
|
upstreamDetail = truncateString(string(body), maxBytes)
|
||||||
|
}
|
||||||
|
setOpsUpstreamError(c, resp.StatusCode, upstreamMsg, upstreamDetail)
|
||||||
|
logOpenAIInstructionsRequiredDebug(ctx, c, account, resp.StatusCode, upstreamMsg, requestBody, body)
|
||||||
|
if s.rateLimitService != nil {
|
||||||
|
_ = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||||
|
}
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Platform: account.Platform,
|
||||||
|
AccountID: account.ID,
|
||||||
|
AccountName: account.Name,
|
||||||
|
UpstreamStatusCode: resp.StatusCode,
|
||||||
|
UpstreamRequestID: resp.Header.Get("x-request-id"),
|
||||||
|
Passthrough: true,
|
||||||
|
Kind: "failover",
|
||||||
|
Message: upstreamMsg,
|
||||||
|
Detail: upstreamDetail,
|
||||||
|
UpstreamResponseBody: upstreamDetail,
|
||||||
|
})
|
||||||
|
return &UpstreamFailoverError{
|
||||||
|
StatusCode: resp.StatusCode,
|
||||||
|
ResponseBody: body,
|
||||||
|
ResponseHeaders: resp.Header.Clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
func (s *OpenAIGatewayService) handleErrorResponsePassthrough(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
resp *http.Response,
|
resp *http.Response,
|
||||||
|
|||||||
@@ -48,6 +48,22 @@ func (u *httpUpstreamRecorder) DoWithTLS(req *http.Request, proxyURL string, acc
|
|||||||
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
return u.Do(req, proxyURL, accountID, accountConcurrency)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type openAIPassthroughFailoverRepo struct {
|
||||||
|
stubOpenAIAccountRepo
|
||||||
|
rateLimitCalls []time.Time
|
||||||
|
overloadCalls []time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *openAIPassthroughFailoverRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||||
|
r.rateLimitCalls = append(r.rateLimitCalls, resetAt)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *openAIPassthroughFailoverRepo) SetOverloaded(_ context.Context, _ int64, until time.Time) error {
|
||||||
|
r.overloadCalls = append(r.overloadCalls, until)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
var structuredLogCaptureMu sync.Mutex
|
var structuredLogCaptureMu sync.Mutex
|
||||||
|
|
||||||
type inMemoryLogSink struct {
|
type inMemoryLogSink struct {
|
||||||
@@ -527,6 +543,8 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
|
|||||||
|
|
||||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
|
require.True(t, c.Writer.Written(), "非 429/529 的 passthrough 错误应继续原样写回客户端")
|
||||||
|
require.Equal(t, http.StatusBadRequest, rec.Code)
|
||||||
|
|
||||||
// should append an upstream error event with passthrough=true
|
// should append an upstream error event with passthrough=true
|
||||||
v, ok := c.Get(OpsUpstreamErrorsKey)
|
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
@@ -535,29 +553,116 @@ func TestOpenAIGatewayService_OAuthPassthrough_UpstreamErrorIncludesPassthroughF
|
|||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.NotEmpty(t, arr)
|
require.NotEmpty(t, arr)
|
||||||
require.True(t, arr[len(arr)-1].Passthrough)
|
require.True(t, arr[len(arr)-1].Passthrough)
|
||||||
|
require.Equal(t, "http_error", arr[len(arr)-1].Kind)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T) {
|
func TestOpenAIGatewayService_OpenAIPassthrough_429And529TriggerFailover(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
||||||
|
|
||||||
|
newAccount := func(accountType string) *Account {
|
||||||
|
account := &Account{
|
||||||
|
ID: 123,
|
||||||
|
Name: "acc",
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: accountType,
|
||||||
|
Concurrency: 1,
|
||||||
|
Extra: map[string]any{"openai_passthrough": true},
|
||||||
|
Status: StatusActive,
|
||||||
|
Schedulable: true,
|
||||||
|
RateMultiplier: f64p(1),
|
||||||
|
}
|
||||||
|
switch accountType {
|
||||||
|
case AccountTypeOAuth:
|
||||||
|
account.Credentials = map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"}
|
||||||
|
case AccountTypeAPIKey:
|
||||||
|
account.Credentials = map[string]any{"api_key": "sk-test"}
|
||||||
|
}
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
accountType string
|
||||||
|
statusCode int
|
||||||
|
body string
|
||||||
|
assertRepo func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "oauth_429_rate_limit",
|
||||||
|
accountType: AccountTypeOAuth,
|
||||||
|
statusCode: http.StatusTooManyRequests,
|
||||||
|
body: func() string {
|
||||||
|
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||||
|
return fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt)
|
||||||
|
}(),
|
||||||
|
assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, _ time.Time) {
|
||||||
|
require.Len(t, repo.rateLimitCalls, 1)
|
||||||
|
require.Empty(t, repo.overloadCalls)
|
||||||
|
require.True(t, time.Until(repo.rateLimitCalls[0]) > 24*time.Hour)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "oauth_529_overload",
|
||||||
|
accountType: AccountTypeOAuth,
|
||||||
|
statusCode: 529,
|
||||||
|
body: `{"error":{"message":"server overloaded","type":"server_error"}}`,
|
||||||
|
assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) {
|
||||||
|
require.Empty(t, repo.rateLimitCalls)
|
||||||
|
require.Len(t, repo.overloadCalls, 1)
|
||||||
|
require.WithinDuration(t, start.Add(10*time.Minute), repo.overloadCalls[0], 5*time.Second)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "apikey_429_rate_limit",
|
||||||
|
accountType: AccountTypeAPIKey,
|
||||||
|
statusCode: http.StatusTooManyRequests,
|
||||||
|
body: func() string {
|
||||||
|
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
|
||||||
|
return fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt)
|
||||||
|
}(),
|
||||||
|
assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, _ time.Time) {
|
||||||
|
require.Len(t, repo.rateLimitCalls, 1)
|
||||||
|
require.Empty(t, repo.overloadCalls)
|
||||||
|
require.True(t, time.Until(repo.rateLimitCalls[0]) > 24*time.Hour)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "apikey_529_overload",
|
||||||
|
accountType: AccountTypeAPIKey,
|
||||||
|
statusCode: 529,
|
||||||
|
body: `{"error":{"message":"server overloaded","type":"server_error"}}`,
|
||||||
|
assertRepo: func(t *testing.T, repo *openAIPassthroughFailoverRepo, start time.Time) {
|
||||||
|
require.Empty(t, repo.rateLimitCalls)
|
||||||
|
require.Len(t, repo.overloadCalls, 1)
|
||||||
|
require.WithinDuration(t, start.Add(10*time.Minute), repo.overloadCalls[0], 5*time.Second)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rec := httptest.NewRecorder()
|
rec := httptest.NewRecorder()
|
||||||
c, _ := gin.CreateTestContext(rec)
|
c, _ := gin.CreateTestContext(rec)
|
||||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", bytes.NewReader(nil))
|
||||||
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
c.Request.Header.Set("User-Agent", "codex_cli_rs/0.1.0")
|
||||||
|
|
||||||
originalBody := []byte(`{"model":"gpt-5.2","stream":false,"instructions":"local-test-instructions","input":[{"type":"text","text":"hi"}]}`)
|
|
||||||
resetAt := time.Now().Add(7 * 24 * time.Hour).Unix()
|
|
||||||
resp := &http.Response{
|
resp := &http.Response{
|
||||||
StatusCode: http.StatusTooManyRequests,
|
StatusCode: tc.statusCode,
|
||||||
Header: http.Header{
|
Header: http.Header{
|
||||||
"Content-Type": []string{"application/json"},
|
"Content-Type": []string{"application/json"},
|
||||||
"x-request-id": []string{"rid-rate-limit"},
|
"x-request-id": []string{"rid-failover"},
|
||||||
},
|
},
|
||||||
Body: io.NopCloser(strings.NewReader(fmt.Sprintf(`{"error":{"message":"The usage limit has been reached","type":"usage_limit_reached","resets_at":%d}}`, resetAt))),
|
Body: io.NopCloser(strings.NewReader(tc.body)),
|
||||||
}
|
}
|
||||||
upstream := &httpUpstreamRecorder{resp: resp}
|
upstream := &httpUpstreamRecorder{resp: resp}
|
||||||
repo := &openAIWSRateLimitSignalRepo{}
|
repo := &openAIPassthroughFailoverRepo{}
|
||||||
rateSvc := &RateLimitService{accountRepo: repo}
|
rateSvc := &RateLimitService{
|
||||||
|
accountRepo: repo,
|
||||||
|
cfg: &config.Config{
|
||||||
|
RateLimit: config.RateLimitConfig{OverloadCooldownMinutes: 10},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
svc := &OpenAIGatewayService{
|
svc := &OpenAIGatewayService{
|
||||||
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
cfg: &config.Config{Gateway: config.GatewayConfig{ForceCodexCLI: false}},
|
||||||
@@ -565,25 +670,28 @@ func TestOpenAIGatewayService_OAuthPassthrough_429PersistsRateLimit(t *testing.T
|
|||||||
rateLimitService: rateSvc,
|
rateLimitService: rateSvc,
|
||||||
}
|
}
|
||||||
|
|
||||||
account := &Account{
|
account := newAccount(tc.accountType)
|
||||||
ID: 123,
|
start := time.Now()
|
||||||
Name: "acc",
|
|
||||||
Platform: PlatformOpenAI,
|
|
||||||
Type: AccountTypeOAuth,
|
|
||||||
Concurrency: 1,
|
|
||||||
Credentials: map[string]any{"access_token": "oauth-token", "chatgpt_account_id": "chatgpt-acc"},
|
|
||||||
Extra: map[string]any{"openai_passthrough": true},
|
|
||||||
Status: StatusActive,
|
|
||||||
Schedulable: true,
|
|
||||||
RateMultiplier: f64p(1),
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
_, err := svc.Forward(context.Background(), c, account, originalBody)
|
||||||
require.Error(t, err)
|
require.Error(t, err)
|
||||||
require.Equal(t, http.StatusTooManyRequests, rec.Code)
|
|
||||||
require.Contains(t, rec.Body.String(), "usage_limit_reached")
|
var failoverErr *UpstreamFailoverError
|
||||||
require.Len(t, repo.rateLimitCalls, 1)
|
require.ErrorAs(t, err, &failoverErr)
|
||||||
require.WithinDuration(t, time.Unix(resetAt, 0), repo.rateLimitCalls[0], 2*time.Second)
|
require.Equal(t, tc.statusCode, failoverErr.StatusCode)
|
||||||
|
require.False(t, c.Writer.Written(), "429/529 passthrough 应返回 failover 错误给上层换号,而不是直接向客户端写响应")
|
||||||
|
|
||||||
|
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
arr, ok := v.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.NotEmpty(t, arr)
|
||||||
|
require.True(t, arr[len(arr)-1].Passthrough)
|
||||||
|
require.Equal(t, "failover", arr[len(arr)-1].Kind)
|
||||||
|
require.Equal(t, tc.statusCode, arr[len(arr)-1].UpstreamStatusCode)
|
||||||
|
|
||||||
|
tc.assertRepo(t, repo, start)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
|
func TestOpenAIGatewayService_OAuthPassthrough_NonCodexUAFallbackToCodexUA(t *testing.T) {
|
||||||
|
|||||||
Reference in New Issue
Block a user