feat(sync): full code sync from release
This commit is contained in:
251
backend/internal/service/openai_ws_fallback_test.go
Normal file
251
backend/internal/service/openai_ws_fallback_test.go
Normal file
@@ -0,0 +1,251 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyOpenAIWSAcquireError(t *testing.T) {
|
||||
t.Run("dial_426_upgrade_required", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 426, Err: errors.New("upgrade required")}
|
||||
require.Equal(t, "upgrade_required", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("queue_full", func(t *testing.T) {
|
||||
require.Equal(t, "conn_queue_full", classifyOpenAIWSAcquireError(errOpenAIWSConnQueueFull))
|
||||
})
|
||||
|
||||
t.Run("preferred_conn_unavailable", func(t *testing.T) {
|
||||
require.Equal(t, "preferred_conn_unavailable", classifyOpenAIWSAcquireError(errOpenAIWSPreferredConnUnavailable))
|
||||
})
|
||||
|
||||
t.Run("acquire_timeout", func(t *testing.T) {
|
||||
require.Equal(t, "acquire_timeout", classifyOpenAIWSAcquireError(context.DeadlineExceeded))
|
||||
})
|
||||
|
||||
t.Run("auth_failed_401", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 401, Err: errors.New("unauthorized")}
|
||||
require.Equal(t, "auth_failed", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("upstream_rate_limited", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 429, Err: errors.New("rate limited")}
|
||||
require.Equal(t, "upstream_rate_limited", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("upstream_5xx", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 502, Err: errors.New("bad gateway")}
|
||||
require.Equal(t, "upstream_5xx", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("dial_failed_other_status", func(t *testing.T) {
|
||||
err := &openAIWSDialError{StatusCode: 418, Err: errors.New("teapot")}
|
||||
require.Equal(t, "dial_failed", classifyOpenAIWSAcquireError(err))
|
||||
})
|
||||
|
||||
t.Run("other", func(t *testing.T) {
|
||||
require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(errors.New("x")))
|
||||
})
|
||||
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
require.Equal(t, "acquire_conn", classifyOpenAIWSAcquireError(nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSDialError(t *testing.T) {
|
||||
t.Run("handshake_not_finished", func(t *testing.T) {
|
||||
err := &openAIWSDialError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
Err: errors.New("WebSocket protocol error: Handshake not finished"),
|
||||
}
|
||||
require.Equal(t, "handshake_not_finished", classifyOpenAIWSDialError(err))
|
||||
})
|
||||
|
||||
t.Run("context_deadline", func(t *testing.T) {
|
||||
err := &openAIWSDialError{
|
||||
StatusCode: 0,
|
||||
Err: context.DeadlineExceeded,
|
||||
}
|
||||
require.Equal(t, "ctx_deadline_exceeded", classifyOpenAIWSDialError(err))
|
||||
})
|
||||
}
|
||||
|
||||
func TestSummarizeOpenAIWSDialError(t *testing.T) {
|
||||
err := &openAIWSDialError{
|
||||
StatusCode: http.StatusBadGateway,
|
||||
ResponseHeaders: http.Header{
|
||||
"Server": []string{"cloudflare"},
|
||||
"Via": []string{"1.1 example"},
|
||||
"Cf-Ray": []string{"abcd1234"},
|
||||
"X-Request-Id": []string{"req_123"},
|
||||
},
|
||||
Err: errors.New("WebSocket protocol error: Handshake not finished"),
|
||||
}
|
||||
|
||||
status, class, closeStatus, closeReason, server, via, cfRay, reqID := summarizeOpenAIWSDialError(err)
|
||||
require.Equal(t, http.StatusBadGateway, status)
|
||||
require.Equal(t, "handshake_not_finished", class)
|
||||
require.Equal(t, "-", closeStatus)
|
||||
require.Equal(t, "-", closeReason)
|
||||
require.Equal(t, "cloudflare", server)
|
||||
require.Equal(t, "1.1 example", via)
|
||||
require.Equal(t, "abcd1234", cfRay)
|
||||
require.Equal(t, "req_123", reqID)
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSErrorEvent(t *testing.T) {
|
||||
reason, recoverable := classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"upgrade_required","message":"Upgrade required"}}`))
|
||||
require.Equal(t, "upgrade_required", reason)
|
||||
require.True(t, recoverable)
|
||||
|
||||
reason, recoverable = classifyOpenAIWSErrorEvent([]byte(`{"type":"error","error":{"code":"previous_response_not_found","message":"not found"}}`))
|
||||
require.Equal(t, "previous_response_not_found", reason)
|
||||
require.True(t, recoverable)
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSReconnectReason(t *testing.T) {
|
||||
reason, retryable := classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("policy_violation", errors.New("policy")))
|
||||
require.Equal(t, "policy_violation", reason)
|
||||
require.False(t, retryable)
|
||||
|
||||
reason, retryable = classifyOpenAIWSReconnectReason(wrapOpenAIWSFallback("read_event", errors.New("io")))
|
||||
require.Equal(t, "read_event", reason)
|
||||
require.True(t, retryable)
|
||||
}
|
||||
|
||||
func TestOpenAIWSErrorHTTPStatus(t *testing.T) {
|
||||
require.Equal(t, http.StatusBadRequest, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"invalid_request_error","code":"invalid_request","message":"invalid input"}}`)))
|
||||
require.Equal(t, http.StatusUnauthorized, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"authentication_error","code":"invalid_api_key","message":"auth failed"}}`)))
|
||||
require.Equal(t, http.StatusForbidden, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"permission_error","code":"forbidden","message":"forbidden"}}`)))
|
||||
require.Equal(t, http.StatusTooManyRequests, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"rate_limit_error","code":"rate_limit_exceeded","message":"rate limited"}}`)))
|
||||
require.Equal(t, http.StatusBadGateway, openAIWSErrorHTTPStatus([]byte(`{"type":"error","error":{"type":"server_error","code":"server_error","message":"server"}}`)))
|
||||
}
|
||||
|
||||
func TestResolveOpenAIWSFallbackErrorResponse(t *testing.T) {
|
||||
t.Run("previous_response_not_found", func(t *testing.T) {
|
||||
statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
|
||||
wrapOpenAIWSFallback("previous_response_not_found", errors.New("previous response not found")),
|
||||
)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusBadRequest, statusCode)
|
||||
require.Equal(t, "invalid_request_error", errType)
|
||||
require.Equal(t, "previous response not found", clientMessage)
|
||||
require.Equal(t, "previous response not found", upstreamMessage)
|
||||
})
|
||||
|
||||
t.Run("auth_failed_uses_dial_status", func(t *testing.T) {
|
||||
statusCode, errType, clientMessage, upstreamMessage, ok := resolveOpenAIWSFallbackErrorResponse(
|
||||
wrapOpenAIWSFallback("auth_failed", &openAIWSDialError{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Err: errors.New("forbidden"),
|
||||
}),
|
||||
)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, http.StatusForbidden, statusCode)
|
||||
require.Equal(t, "upstream_error", errType)
|
||||
require.Equal(t, "forbidden", clientMessage)
|
||||
require.Equal(t, "forbidden", upstreamMessage)
|
||||
})
|
||||
|
||||
t.Run("non_fallback_error_not_resolved", func(t *testing.T) {
|
||||
_, _, _, _, ok := resolveOpenAIWSFallbackErrorResponse(errors.New("plain error"))
|
||||
require.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestOpenAIWSFallbackCooling(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.FallbackCooldownSeconds = 1
|
||||
|
||||
require.False(t, svc.isOpenAIWSFallbackCooling(1))
|
||||
svc.markOpenAIWSFallbackCooling(1, "upgrade_required")
|
||||
require.True(t, svc.isOpenAIWSFallbackCooling(1))
|
||||
|
||||
svc.clearOpenAIWSFallbackCooling(1)
|
||||
require.False(t, svc.isOpenAIWSFallbackCooling(1))
|
||||
|
||||
svc.markOpenAIWSFallbackCooling(2, "x")
|
||||
time.Sleep(1200 * time.Millisecond)
|
||||
require.False(t, svc.isOpenAIWSFallbackCooling(2))
|
||||
}
|
||||
|
||||
func TestOpenAIWSRetryBackoff(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.RetryBackoffInitialMS = 100
|
||||
svc.cfg.Gateway.OpenAIWS.RetryBackoffMaxMS = 400
|
||||
svc.cfg.Gateway.OpenAIWS.RetryJitterRatio = 0
|
||||
|
||||
require.Equal(t, time.Duration(100)*time.Millisecond, svc.openAIWSRetryBackoff(1))
|
||||
require.Equal(t, time.Duration(200)*time.Millisecond, svc.openAIWSRetryBackoff(2))
|
||||
require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(3))
|
||||
require.Equal(t, time.Duration(400)*time.Millisecond, svc.openAIWSRetryBackoff(4))
|
||||
}
|
||||
|
||||
func TestOpenAIWSRetryTotalBudget(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 1200
|
||||
require.Equal(t, 1200*time.Millisecond, svc.openAIWSRetryTotalBudget())
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.RetryTotalBudgetMS = 0
|
||||
require.Equal(t, time.Duration(0), svc.openAIWSRetryTotalBudget())
|
||||
}
|
||||
|
||||
func TestClassifyOpenAIWSReadFallbackReason(t *testing.T) {
|
||||
require.Equal(t, "policy_violation", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusPolicyViolation}))
|
||||
require.Equal(t, "message_too_big", classifyOpenAIWSReadFallbackReason(coderws.CloseError{Code: coderws.StatusMessageTooBig}))
|
||||
require.Equal(t, "read_event", classifyOpenAIWSReadFallbackReason(errors.New("io")))
|
||||
}
|
||||
|
||||
func TestOpenAIWSStoreDisabledConnMode(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = true
|
||||
require.Equal(t, openAIWSStoreDisabledConnModeStrict, svc.openAIWSStoreDisabledConnMode())
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = "adaptive"
|
||||
require.Equal(t, openAIWSStoreDisabledConnModeAdaptive, svc.openAIWSStoreDisabledConnMode())
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledConnMode = ""
|
||||
svc.cfg.Gateway.OpenAIWS.StoreDisabledForceNewConn = false
|
||||
require.Equal(t, openAIWSStoreDisabledConnModeOff, svc.openAIWSStoreDisabledConnMode())
|
||||
}
|
||||
|
||||
func TestShouldForceNewConnOnStoreDisabled(t *testing.T) {
|
||||
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeStrict, ""))
|
||||
require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeOff, "policy_violation"))
|
||||
|
||||
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "policy_violation"))
|
||||
require.True(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "prewarm_message_too_big"))
|
||||
require.False(t, shouldForceNewConnOnStoreDisabled(openAIWSStoreDisabledConnModeAdaptive, "read_event"))
|
||||
}
|
||||
|
||||
func TestOpenAIWSRetryMetricsSnapshot(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
svc.recordOpenAIWSRetryAttempt(150 * time.Millisecond)
|
||||
svc.recordOpenAIWSRetryAttempt(0)
|
||||
svc.recordOpenAIWSRetryExhausted()
|
||||
svc.recordOpenAIWSNonRetryableFastFallback()
|
||||
|
||||
snapshot := svc.SnapshotOpenAIWSRetryMetrics()
|
||||
require.Equal(t, int64(2), snapshot.RetryAttemptsTotal)
|
||||
require.Equal(t, int64(150), snapshot.RetryBackoffMsTotal)
|
||||
require.Equal(t, int64(1), snapshot.RetryExhaustedTotal)
|
||||
require.Equal(t, int64(1), snapshot.NonRetryableFastFallbackTotal)
|
||||
}
|
||||
|
||||
func TestShouldLogOpenAIWSPayloadSchema(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 0
|
||||
require.True(t, svc.shouldLogOpenAIWSPayloadSchema(1), "首次尝试应始终记录 payload_schema")
|
||||
require.False(t, svc.shouldLogOpenAIWSPayloadSchema(2))
|
||||
|
||||
svc.cfg.Gateway.OpenAIWS.PayloadLogSampleRate = 1
|
||||
require.True(t, svc.shouldLogOpenAIWSPayloadSchema(2))
|
||||
}
|
||||
Reference in New Issue
Block a user