feat(sync): full code sync from release
This commit is contained in:
@@ -5,17 +5,20 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
|
||||
coderws "github.com/coder/websocket"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
@@ -64,6 +67,11 @@ func NewOpenAIGatewayHandler(
|
||||
// Responses handles OpenAI Responses API endpoint
|
||||
// POST /openai/v1/responses
|
||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
// 局部兜底:确保该 handler 内部任何 panic 都不会击穿到进程级。
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
setOpenAIClientTransportHTTP(c)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||
@@ -85,9 +93,12 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
// Read request body
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
@@ -125,43 +136,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
reqStream := streamResult.Bool()
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
|
||||
if previousResponseID != "" {
|
||||
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||
reqLog = reqLog.With(
|
||||
zap.Bool("has_previous_response_id", true),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
zap.Int("previous_response_id_len", len(previousResponseID)),
|
||||
)
|
||||
if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "previous_response_id_looks_like_message_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
// 提前校验 function_call_output 是否具备可关联上下文,避免上游 400。
|
||||
// 要求 previous_response_id,或 input 内存在带 call_id 的 tool_call/function_call,
|
||||
// 或带 id 且与 call_id 匹配的 item_reference。
|
||||
// 此路径需要遍历 input 数组做 call_id 关联检查,保留 Unmarshal
|
||||
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err == nil {
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
if service.HasFunctionCallOutput(reqBody) {
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||
if service.HasFunctionCallOutputMissingCallID(reqBody) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
callIDs := service.FunctionCallOutputCallIDs(reqBody)
|
||||
if !service.HasItemReferenceForCallIDs(reqBody, callIDs) {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
// Track if we've started streaming (for error handling)
|
||||
streamStarted := false
|
||||
|
||||
// 绑定错误透传服务,允许 service 层在非 failover 错误场景复用规则。
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
@@ -173,51 +171,11 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
// 0. 先尝试直接抢占用户槽位(快速路径)
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
waitCounted := false
|
||||
if !userAcquired {
|
||||
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
|
||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return
|
||||
}
|
||||
if waitErr == nil && canWait {
|
||||
waitCounted = true
|
||||
}
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 用户槽位已获取:退出等待队列计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||
waitCounted = false
|
||||
}
|
||||
// 确保请求取消时也会释放槽位,避免长连接被动中断造成泄漏
|
||||
userReleaseFunc = wrapReleaseOnDone(c.Request.Context(), userReleaseFunc)
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
@@ -241,7 +199,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
for {
|
||||
// Select account supporting the requested model
|
||||
reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, err := h.gatewayService.SelectAccountWithLoadAwareness(c.Request.Context(), apiKey.GroupID, sessionHash, reqModel, failedAccountIDs)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_select_failed",
|
||||
zap.Error(err),
|
||||
@@ -258,80 +224,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
if previousResponseID != "" && selection != nil && selection.Account != nil {
|
||||
reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
|
||||
}
|
||||
reqLog.Debug("openai.account_schedule_decision",
|
||||
zap.String("layer", scheduleDecision.Layer),
|
||||
zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
|
||||
zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
zap.Int("top_k", scheduleDecision.TopK),
|
||||
zap.Int64("latency_ms", scheduleDecision.LatencyMs),
|
||||
zap.Float64("load_skew", scheduleDecision.LoadSkew),
|
||||
)
|
||||
account := selection.Account
|
||||
reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
// 3. Acquire account concurrency slot
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
c.Request.Context(),
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
if fastAcquired {
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
} else {
|
||||
accountWaitCounted := false
|
||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
|
||||
return
|
||||
}
|
||||
if err == nil && canWait {
|
||||
accountWaitCounted = true
|
||||
}
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
|
||||
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
&streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
releaseWait()
|
||||
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||
return
|
||||
}
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
}
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||
|
||||
// Forward request
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
@@ -353,6 +269,8 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
@@ -368,14 +286,25 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Error("openai.forward_failed",
|
||||
fields := []zap.Field{
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
|
||||
reqLog.Warn("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
reqLog.Error("openai.forward_failed", fields...)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
// 捕获请求信息(用于异步记录,避免在 goroutine 中访问 gin.Context)
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
@@ -411,6 +340,525 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||
return true
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
|
||||
return true
|
||||
}
|
||||
|
||||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||
validation := service.ValidateFunctionCallOutputContext(reqBody)
|
||||
if !validation.HasFunctionCallOutput {
|
||||
return true
|
||||
}
|
||||
|
||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
|
||||
return true
|
||||
}
|
||||
|
||||
if validation.HasFunctionCallOutputMissingCallID {
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_call_id"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return false
|
||||
}
|
||||
if validation.HasItemReferenceForAllCallIDs {
|
||||
return true
|
||||
}
|
||||
|
||||
reqLog.Warn("openai.request_validation_failed",
|
||||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||||
)
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
|
||||
c *gin.Context,
|
||||
userID int64,
|
||||
userConcurrency int,
|
||||
reqStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
ctx := c.Request.Context()
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
if userAcquired {
|
||||
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||
}
|
||||
|
||||
maxWait := service.CalculateMaxWait(userConcurrency)
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
|
||||
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
|
||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||
return nil, false
|
||||
}
|
||||
|
||||
waitCounted := waitErr == nil && canWait
|
||||
defer func() {
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||
}
|
||||
}()
|
||||
|
||||
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "user", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 槽位获取成功后,立刻退出等待计数。
|
||||
if waitCounted {
|
||||
h.concurrencyHelper.DecrementWaitCount(ctx, userID)
|
||||
waitCounted = false
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, userReleaseFunc), true
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
|
||||
c *gin.Context,
|
||||
groupID *int64,
|
||||
sessionHash string,
|
||||
selection *service.AccountSelectionResult,
|
||||
reqStream bool,
|
||||
streamStarted *bool,
|
||||
reqLog *zap.Logger,
|
||||
) (func(), bool) {
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
ctx := c.Request.Context()
|
||||
account := selection.Account
|
||||
if selection.Acquired {
|
||||
return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
|
||||
}
|
||||
if selection.WaitPlan == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
if fastAcquired {
|
||||
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, fastReleaseFunc), true
|
||||
}
|
||||
|
||||
canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
|
||||
if waitErr != nil {
|
||||
reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
|
||||
} else if !canWait {
|
||||
reqLog.Info("openai.account_wait_queue_full",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
|
||||
)
|
||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
accountWaitCounted := waitErr == nil && canWait
|
||||
releaseWait := func() {
|
||||
if accountWaitCounted {
|
||||
h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
|
||||
accountWaitCounted = false
|
||||
}
|
||||
}
|
||||
defer releaseWait()
|
||||
|
||||
accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
|
||||
c,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
selection.WaitPlan.Timeout,
|
||||
reqStream,
|
||||
streamStarted,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
h.handleConcurrencyError(c, err, "account", *streamStarted)
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Slot acquired: no longer waiting in queue.
|
||||
releaseWait()
|
||||
if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
return wrapReleaseOnDone(ctx, accountReleaseFunc), true
|
||||
}
|
||||
|
||||
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
|
||||
// GET /openai/v1/responses (Upgrade: websocket)
|
||||
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
|
||||
if !isOpenAIWSUpgradeRequest(c.Request) {
|
||||
h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
|
||||
return
|
||||
}
|
||||
setOpenAIClientTransportWS(c)
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.responses_ws",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.Bool("openai_ws_mode", true),
|
||||
)
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_started")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))
|
||||
|
||||
wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
|
||||
CompressionMode: coderws.CompressionContextTakeover,
|
||||
})
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_accept_failed",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("request_user_agent", userAgent),
|
||||
zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
|
||||
zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
|
||||
zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
|
||||
zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
|
||||
)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = wsConn.CloseNow()
|
||||
}()
|
||||
wsConn.SetReadLimit(16 * 1024 * 1024)
|
||||
|
||||
ctx := c.Request.Context()
|
||||
readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
msgType, firstMessage, err := wsConn.Read(readCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_read_first_message_failed",
|
||||
zap.Error(err),
|
||||
zap.String("client_ip", clientIP),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
zap.Duration("read_timeout", 30*time.Second),
|
||||
)
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
|
||||
return
|
||||
}
|
||||
if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
|
||||
return
|
||||
}
|
||||
if !gjson.ValidBytes(firstMessage) {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
|
||||
return
|
||||
}
|
||||
|
||||
reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
|
||||
if reqModel == "" {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
|
||||
return
|
||||
}
|
||||
previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
|
||||
previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
|
||||
if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
|
||||
return
|
||||
}
|
||||
reqLog = reqLog.With(
|
||||
zap.Bool("ws_ingress", true),
|
||||
zap.String("model", reqModel),
|
||||
zap.Bool("has_previous_response_id", previousResponseID != ""),
|
||||
zap.String("previous_response_id_kind", previousResponseIDKind),
|
||||
)
|
||||
setOpsRequestContext(c, reqModel, true, firstMessage)
|
||||
|
||||
var currentUserRelease func()
|
||||
var currentAccountRelease func()
|
||||
releaseTurnSlots := func() {
|
||||
if currentAccountRelease != nil {
|
||||
currentAccountRelease()
|
||||
currentAccountRelease = nil
|
||||
}
|
||||
if currentUserRelease != nil {
|
||||
currentUserRelease()
|
||||
currentUserRelease = nil
|
||||
}
|
||||
}
|
||||
// 必须尽早注册,确保任何 early return 都能释放已获取的并发槽位。
|
||||
defer releaseTurnSlots()
|
||||
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
|
||||
return
|
||||
}
|
||||
if !userAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
|
||||
return
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
|
||||
c,
|
||||
firstMessage,
|
||||
openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
|
||||
)
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
ctx,
|
||||
apiKey.GroupID,
|
||||
previousResponseID,
|
||||
sessionHash,
|
||||
reqModel,
|
||||
nil,
|
||||
service.OpenAIUpstreamTransportResponsesWebsocketV2,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
|
||||
return
|
||||
}
|
||||
|
||||
account := selection.Account
|
||||
accountMaxConcurrency := account.Concurrency
|
||||
if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
|
||||
accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
|
||||
}
|
||||
accountReleaseFunc := selection.ReleaseFunc
|
||||
if !selection.Acquired {
|
||||
if selection.WaitPlan == nil {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||
ctx,
|
||||
account.ID,
|
||||
selection.WaitPlan.MaxConcurrency,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
|
||||
return
|
||||
}
|
||||
if !fastAcquired {
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
|
||||
return
|
||||
}
|
||||
accountReleaseFunc = fastReleaseFunc
|
||||
}
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||
reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
}
|
||||
|
||||
token, _, err := h.gatewayService.GetAccessToken(ctx, account)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
|
||||
return
|
||||
}
|
||||
|
||||
reqLog.Debug("openai.websocket_account_selected",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("account_name", account.Name),
|
||||
zap.String("schedule_layer", scheduleDecision.Layer),
|
||||
zap.Int("candidate_count", scheduleDecision.CandidateCount),
|
||||
)
|
||||
|
||||
hooks := &service.OpenAIWSIngressHooks{
|
||||
BeforeTurn: func(turn int) error {
|
||||
if turn == 1 {
|
||||
return nil
|
||||
}
|
||||
// 防御式清理:避免异常路径下旧槽位覆盖导致泄漏。
|
||||
releaseTurnSlots()
|
||||
// 非首轮 turn 需要重新抢占并发槽位,避免长连接空闲占槽。
|
||||
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
|
||||
if err != nil {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
|
||||
}
|
||||
if !userAcquired {
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
|
||||
}
|
||||
accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
|
||||
if err != nil {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
|
||||
}
|
||||
if !accountAcquired {
|
||||
if userReleaseFunc != nil {
|
||||
userReleaseFunc()
|
||||
}
|
||||
return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
|
||||
}
|
||||
currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
|
||||
currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
|
||||
return nil
|
||||
},
|
||||
AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
|
||||
releaseTurnSlots()
|
||||
if turnErr != nil || result == nil {
|
||||
return
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
h.submitUsageRecordTask(func(taskCtx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
reqLog.Error("openai.websocket_record_usage_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.String("request_id", result.RequestID),
|
||||
zap.Error(err),
|
||||
)
|
||||
}
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
|
||||
reqLog.Warn("openai.websocket_proxy_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Error(err),
|
||||
zap.String("close_status", closeStatus),
|
||||
zap.String("close_reason", closeReason),
|
||||
)
|
||||
var closeErr *service.OpenAIWSClientCloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
|
||||
return
|
||||
}
|
||||
closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
|
||||
return
|
||||
}
|
||||
reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||||
recovered := recover()
|
||||
if recovered == nil {
|
||||
return
|
||||
}
|
||||
|
||||
started := false
|
||||
if streamStarted != nil {
|
||||
started = *streamStarted
|
||||
}
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, started)
|
||||
requestLogger(c, "handler.openai_gateway.responses").Error(
|
||||
"openai.responses_panic_recovered",
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Any("panic", recovered),
|
||||
zap.ByteString("stack", debug.Stack()),
|
||||
)
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||||
missing := h.missingResponsesDependencies()
|
||||
if len(missing) == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if reqLog == nil {
|
||||
reqLog = requestLogger(c, "handler.openai_gateway.responses")
|
||||
}
|
||||
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
|
||||
|
||||
if c != nil && c.Writer != nil && !c.Writer.Written() {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "api_error",
|
||||
"message": "Service temporarily unavailable",
|
||||
},
|
||||
})
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
|
||||
missing := make([]string, 0, 5)
|
||||
if h == nil {
|
||||
return append(missing, "handler")
|
||||
}
|
||||
if h.gatewayService == nil {
|
||||
missing = append(missing, "gatewayService")
|
||||
}
|
||||
if h.billingCacheService == nil {
|
||||
missing = append(missing, "billingCacheService")
|
||||
}
|
||||
if h.apiKeyService == nil {
|
||||
missing = append(missing, "apiKeyService")
|
||||
}
|
||||
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
|
||||
missing = append(missing, "concurrencyHelper")
|
||||
}
|
||||
return missing
|
||||
}
|
||||
|
||||
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||
if c == nil || key == "" {
|
||||
return 0, false
|
||||
@@ -444,6 +892,14 @@ func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTas
|
||||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
defer func() {
|
||||
if recovered := recover(); recovered != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.responses"),
|
||||
zap.Any("panic", recovered),
|
||||
).Error("openai.usage_record_task_panic_recovered")
|
||||
}
|
||||
}()
|
||||
task(ctx)
|
||||
}
|
||||
|
||||
@@ -515,19 +971,8 @@ func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status
|
||||
// Stream already started, send error as SSE event then close
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if ok {
|
||||
// Send error event in OpenAI SSE format with proper JSON marshaling
|
||||
errorData := map[string]any{
|
||||
"error": map[string]string{
|
||||
"type": errType,
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
jsonBytes, err := json.Marshal(errorData)
|
||||
if err != nil {
|
||||
_ = c.Error(err)
|
||||
return
|
||||
}
|
||||
errorEvent := fmt.Sprintf("event: error\ndata: %s\n\n", string(jsonBytes))
|
||||
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||||
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||
_ = c.Error(err)
|
||||
}
|
||||
@@ -549,6 +994,16 @@ func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, stream
|
||||
return true
|
||||
}
|
||||
|
||||
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
|
||||
if wroteFallback {
|
||||
return false
|
||||
}
|
||||
if c == nil || c.Writer == nil {
|
||||
return false
|
||||
}
|
||||
return c.Writer.Written()
|
||||
}
|
||||
|
||||
// errorResponse returns OpenAI API format error response
|
||||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||||
c.JSON(status, gin.H{
|
||||
@@ -558,3 +1013,61 @@ func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// setOpenAIClientTransportHTTP tags the request context so downstream
// services can tell the client connected over plain HTTP rather than
// WebSocket.
func setOpenAIClientTransportHTTP(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
}
|
||||
|
||||
// setOpenAIClientTransportWS tags the request context so downstream
// services can tell the client connected over WebSocket rather than
// plain HTTP.
func setOpenAIClientTransportWS(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
|
||||
|
||||
// openAIWSIngressFallbackSessionSeed derives a deterministic sticky-session
// seed for a WebSocket ingress connection from the group, user, and API-key
// identifiers. A nil group pointer is treated as group 0.
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
	var group int64
	if groupID != nil {
		group = *groupID
	}
	return fmt.Sprintf("openai_ws_ingress:%d:%d:%d", group, userID, apiKeyID)
}
|
||||
|
||||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||||
if r == nil {
|
||||
return false
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
|
||||
return false
|
||||
}
|
||||
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
|
||||
}
|
||||
|
||||
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
|
||||
if conn == nil {
|
||||
return
|
||||
}
|
||||
reason = strings.TrimSpace(reason)
|
||||
if len(reason) > 120 {
|
||||
reason = reason[:120]
|
||||
}
|
||||
_ = conn.Close(status, reason)
|
||||
_ = conn.CloseNow()
|
||||
}
|
||||
|
||||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||||
if err == nil {
|
||||
return "-", "-"
|
||||
}
|
||||
statusCode := coderws.CloseStatus(err)
|
||||
if statusCode == -1 {
|
||||
return "-", "-"
|
||||
}
|
||||
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
|
||||
closeReason := "-"
|
||||
var closeErr coderws.CloseError
|
||||
if errors.As(err, &closeErr) {
|
||||
reason := strings.TrimSpace(closeErr.Reason)
|
||||
if reason != "" {
|
||||
closeReason = reason
|
||||
}
|
||||
}
|
||||
return closeStatus, closeReason
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user