fix: 修复 OpenAI WS 限流状态与调度同步

This commit is contained in:
神乐
2026-03-07 23:59:39 +08:00
parent 0c1dcad429
commit 45d57018eb
3 changed files with 471 additions and 7 deletions

View File

@@ -1853,6 +1853,10 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
var dialErr *openAIWSDialError
if errors.As(err, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(err.Error()))
}
return nil, wrapOpenAIWSFallback(classifyOpenAIWSAcquireError(err), err)
}
defer lease.Release()
@@ -2136,6 +2140,7 @@ func (s *OpenAIGatewayService) forwardOpenAIWSV2(
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "Upstream websocket error"
@@ -2639,6 +2644,10 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
wsPath,
account.ProxyID != nil && account.Proxy != nil,
)
var dialErr *openAIWSDialError
if errors.As(acquireErr, &dialErr) && dialErr != nil && dialErr.StatusCode == http.StatusTooManyRequests {
s.persistOpenAIWSRateLimitSignal(ctx, account, dialErr.ResponseHeaders, nil, "rate_limit_exceeded", "rate_limit_error", strings.TrimSpace(acquireErr.Error()))
}
if errors.Is(acquireErr, errOpenAIWSPreferredConnUnavailable) {
return nil, NewOpenAIWSClientCloseError(
coderws.StatusPolicyViolation,
@@ -2777,6 +2786,7 @@ func (s *OpenAIGatewayService) ProxyResponsesWebSocketFromClient(
}
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(upstreamMessage)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), upstreamMessage, errCodeRaw, errTypeRaw, errMsgRaw)
fallbackReason, _ := classifyOpenAIWSErrorEventFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
errCode, errType, errMessage := summarizeOpenAIWSErrorEventFieldsFromRaw(errCodeRaw, errTypeRaw, errMsgRaw)
recoverablePrevNotFound := fallbackReason == openAIWSIngressStagePreviousResponseNotFound &&
@@ -3604,6 +3614,7 @@ func (s *OpenAIGatewayService) performOpenAIWSGeneratePrewarm(
if eventType == "error" {
errCodeRaw, errTypeRaw, errMsgRaw := parseOpenAIWSErrorEventFields(message)
s.persistOpenAIWSRateLimitSignal(ctx, account, lease.HandshakeHeaders(), message, errCodeRaw, errTypeRaw, errMsgRaw)
errMsg := strings.TrimSpace(errMsgRaw)
if errMsg == "" {
errMsg = "OpenAI websocket prewarm error"
@@ -3867,6 +3878,36 @@ func classifyOpenAIWSAcquireError(err error) string {
return "acquire_conn"
}
// isOpenAIWSRateLimitError reports whether the fields of an upstream
// websocket error event describe a rate-limit or usage/quota-limit condition.
//
// All three inputs are trimmed and lower-cased before matching. The result is
// true when any of the following holds:
//   - the error type contains "rate_limit" or "usage_limit";
//   - the error code contains "rate_limit", "usage_limit" or "insufficient_quota";
//   - the message contains "usage limit" together with "reached";
//   - the message contains "rate limit" together with "reached" or "exceeded".
func isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw string) bool {
	normalize := func(raw string) string {
		return strings.ToLower(strings.TrimSpace(raw))
	}
	containsAny := func(haystack string, needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(haystack, needle) {
				return true
			}
		}
		return false
	}

	// Structured fields are checked first: type, then code.
	if containsAny(normalize(errTypeRaw), "rate_limit", "usage_limit") {
		return true
	}
	if containsAny(normalize(codeRaw), "rate_limit", "usage_limit", "insufficient_quota") {
		return true
	}

	// Fall back to free-text heuristics on the message. A bare mention of a
	// limit is not enough; it must be paired with a "hit the limit" keyword.
	text := normalize(msgRaw)
	reached := strings.Contains(text, "reached")
	switch {
	case strings.Contains(text, "usage limit") && reached:
		return true
	case strings.Contains(text, "rate limit") && (reached || strings.Contains(text, "exceeded")):
		return true
	}
	return false
}
// persistOpenAIWSRateLimitSignal forwards a websocket-level rate-limit signal
// to the rate limit service.
//
// It is a no-op when the receiver or its rateLimitService is nil, when account
// is nil or is not on the OpenAI platform, or when codeRaw/errTypeRaw/msgRaw
// are not recognised by isOpenAIWSRateLimitError. Otherwise it calls
// HandleUpstreamError with a fixed HTTP 429 status, passing headers and
// responseBody through unchanged.
func (s *OpenAIGatewayService) persistOpenAIWSRateLimitSignal(ctx context.Context, account *Account, headers http.Header, responseBody []byte, codeRaw, errTypeRaw, msgRaw string) {
	// Nothing to record without a service to record it in.
	if s == nil || s.rateLimitService == nil {
		return
	}
	// Only OpenAI accounts are handled here.
	if account == nil || account.Platform != PlatformOpenAI {
		return
	}
	if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
		// The status is hard-coded to 429 regardless of how the signal arrived.
		s.rateLimitService.HandleUpstreamError(ctx, account, http.StatusTooManyRequests, headers, responseBody)
	}
}
func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (string, bool) {
code := strings.ToLower(strings.TrimSpace(codeRaw))
errType := strings.ToLower(strings.TrimSpace(errTypeRaw))
@@ -3882,6 +3923,9 @@ func classifyOpenAIWSErrorEventFromRaw(codeRaw, errTypeRaw, msgRaw string) (stri
case "previous_response_not_found":
return "previous_response_not_found", true
}
if isOpenAIWSRateLimitError(codeRaw, errTypeRaw, msgRaw) {
return "upstream_rate_limited", false
}
if strings.Contains(msg, "upgrade required") || strings.Contains(msg, "status 426") {
return "upgrade_required", true
}
@@ -3927,9 +3971,7 @@ func openAIWSErrorHTTPStatusFromRaw(codeRaw, errTypeRaw string) int {
case strings.Contains(errType, "permission"),
strings.Contains(code, "forbidden"):
return http.StatusForbidden
case strings.Contains(errType, "rate_limit"),
strings.Contains(code, "rate_limit"),
strings.Contains(code, "insufficient_quota"):
case isOpenAIWSRateLimitError(codeRaw, errTypeRaw, ""):
return http.StatusTooManyRequests
default:
return http.StatusBadGateway