1074 lines
37 KiB
Go
1074 lines
37 KiB
Go
package handler
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"net/http"
|
||
"runtime/debug"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||
|
||
coderws "github.com/coder/websocket"
|
||
"github.com/gin-gonic/gin"
|
||
"github.com/tidwall/gjson"
|
||
"go.uber.org/zap"
|
||
)
|
||
|
||
// OpenAIGatewayHandler handles OpenAI API gateway requests
type OpenAIGatewayHandler struct {
	gatewayService          *service.OpenAIGatewayService    // upstream forwarding, account scheduling, session hashing, usage recording
	billingCacheService     *service.BillingCacheService     // billing eligibility re-checks after slot waits
	apiKeyService           *service.APIKeyService           // handed to RecordUsage for key-level bookkeeping
	usageRecordWorkerPool   *service.UsageRecordWorkerPool   // bounded pool for async usage records; nil triggers sync fallback
	errorPassthroughService *service.ErrorPassthroughService // optional; bound per-request so the service layer can reuse passthrough rules
	concurrencyHelper       *ConcurrencyHelper               // user/account slot acquisition plus wait-queue counters
	maxAccountSwitches      int                              // failover budget per request (configured; defaults to 3)
}
|
||
|
||
// NewOpenAIGatewayHandler creates a new OpenAIGatewayHandler
|
||
func NewOpenAIGatewayHandler(
|
||
gatewayService *service.OpenAIGatewayService,
|
||
concurrencyService *service.ConcurrencyService,
|
||
billingCacheService *service.BillingCacheService,
|
||
apiKeyService *service.APIKeyService,
|
||
usageRecordWorkerPool *service.UsageRecordWorkerPool,
|
||
errorPassthroughService *service.ErrorPassthroughService,
|
||
cfg *config.Config,
|
||
) *OpenAIGatewayHandler {
|
||
pingInterval := time.Duration(0)
|
||
maxAccountSwitches := 3
|
||
if cfg != nil {
|
||
pingInterval = time.Duration(cfg.Concurrency.PingInterval) * time.Second
|
||
if cfg.Gateway.MaxAccountSwitches > 0 {
|
||
maxAccountSwitches = cfg.Gateway.MaxAccountSwitches
|
||
}
|
||
}
|
||
return &OpenAIGatewayHandler{
|
||
gatewayService: gatewayService,
|
||
billingCacheService: billingCacheService,
|
||
apiKeyService: apiKeyService,
|
||
usageRecordWorkerPool: usageRecordWorkerPool,
|
||
errorPassthroughService: errorPassthroughService,
|
||
concurrencyHelper: NewConcurrencyHelper(concurrencyService, SSEPingFormatComment, pingInterval),
|
||
maxAccountSwitches: maxAccountSwitches,
|
||
}
|
||
}
|
||
|
||
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
//
// Pipeline: auth context -> body validation (gjson, no full unmarshal) ->
// user concurrency slot -> billing re-check -> account selection/forward loop
// with bounded failover -> async usage recording.
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
	// Local safety net: no panic inside this handler may escape to the process level.
	streamStarted := false
	defer h.recoverResponsesPanic(c, &streamStarted)
	setOpenAIClientTransportHTTP(c)

	requestStart := time.Now()

	// Get apiKey and user from context (set by ApiKeyAuth middleware)
	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}

	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}
	reqLog := requestLogger(
		c,
		"handler.openai_gateway.responses",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
	)
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}

	// Read request body
	body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
	if err != nil {
		if maxErr, ok := extractMaxBytesError(err); ok {
			h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
			return
		}
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
		return
	}

	if len(body) == 0 {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
		return
	}

	// Record ops context early with unknown model/stream; refined again below once parsed.
	setOpsRequestContext(c, "", false, body)

	// Validate that the request body is well-formed JSON.
	if !gjson.ValidBytes(body) {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
		return
	}

	// Extract fields read-only with gjson to avoid a full Unmarshal.
	modelResult := gjson.GetBytes(body, "model")
	if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
		return
	}
	reqModel := modelResult.String()

	// "stream", when present, must be a JSON boolean.
	streamResult := gjson.GetBytes(body, "stream")
	if streamResult.Exists() && streamResult.Type != gjson.True && streamResult.Type != gjson.False {
		h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "invalid stream field type")
		return
	}
	reqStream := streamResult.Bool()
	reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
	previousResponseID := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String())
	if previousResponseID != "" {
		// Reject IDs that look like message IDs before paying for an upstream round trip.
		previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
		reqLog = reqLog.With(
			zap.Bool("has_previous_response_id", true),
			zap.String("previous_response_id_kind", previousResponseIDKind),
			zap.Int("previous_response_id_len", len(previousResponseID)),
		)
		if previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
			reqLog.Warn("openai.request_validation_failed",
				zap.String("reason", "previous_response_id_looks_like_message_id"),
			)
			h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "previous_response_id must be a response.id (resp_*), not a message id")
			return
		}
	}

	setOpsRequestContext(c, reqModel, reqStream, body)

	// Pre-validate that function_call_output items have linkable context, avoiding an upstream 400.
	if !h.validateFunctionCallOutputRequest(c, body, reqLog) {
		return
	}

	// Bind the error passthrough service so the service layer can reuse its rules
	// in non-failover error scenarios.
	if h.errorPassthroughService != nil {
		service.BindErrorPassthroughService(c, h.errorPassthroughService)
	}

	// Get subscription info (may be nil)
	subscription, _ := middleware2.GetSubscriptionFromContext(c)

	service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
	routingStart := time.Now()

	userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
	if !acquired {
		return
	}
	// Release the slot even on request cancellation, so a passively dropped
	// long connection cannot leak it.
	if userReleaseFunc != nil {
		defer userReleaseFunc()
	}

	// 2. Re-check billing eligibility after wait
	if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai.billing_eligibility_check_failed", zap.Error(err))
		status, code, message := billingErrorDetails(err)
		h.handleStreamingAwareError(c, status, code, message, streamStarted)
		return
	}

	// Generate session hash (header first; fallback to prompt_cache_key)
	sessionHash := h.gatewayService.GenerateSessionHash(c, body)

	maxAccountSwitches := h.maxAccountSwitches
	switchCount := 0
	failedAccountIDs := make(map[int64]struct{})
	var lastFailoverErr *service.UpstreamFailoverError

	// Account selection / forward loop: retries on UpstreamFailoverError up to
	// maxAccountSwitches times, excluding accounts that already failed.
	for {
		// Select account supporting the requested model
		reqLog.Debug("openai.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
		selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
			c.Request.Context(),
			apiKey.GroupID,
			previousResponseID,
			sessionHash,
			reqModel,
			failedAccountIDs,
			service.OpenAIUpstreamTransportAny,
		)
		if err != nil {
			reqLog.Warn("openai.account_select_failed",
				zap.Error(err),
				zap.Int("excluded_account_count", len(failedAccountIDs)),
			)
			if len(failedAccountIDs) == 0 {
				// First selection failed outright: nothing was ever tried.
				h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
				return
			}
			// All remaining candidates exhausted after at least one failover.
			if lastFailoverErr != nil {
				h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
			} else {
				h.handleFailoverExhaustedSimple(c, 502, streamStarted)
			}
			return
		}
		if selection == nil || selection.Account == nil {
			h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
			return
		}
		if previousResponseID != "" && selection != nil && selection.Account != nil {
			reqLog.Debug("openai.account_selected_with_previous_response_id", zap.Int64("account_id", selection.Account.ID))
		}
		reqLog.Debug("openai.account_schedule_decision",
			zap.String("layer", scheduleDecision.Layer),
			zap.Bool("sticky_previous_hit", scheduleDecision.StickyPreviousHit),
			zap.Bool("sticky_session_hit", scheduleDecision.StickySessionHit),
			zap.Int("candidate_count", scheduleDecision.CandidateCount),
			zap.Int("top_k", scheduleDecision.TopK),
			zap.Int64("latency_ms", scheduleDecision.LatencyMs),
			zap.Float64("load_skew", scheduleDecision.LoadSkew),
		)
		account := selection.Account
		reqLog.Debug("openai.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
		setOpsSelectedAccount(c, account.ID, account.Platform)

		accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
		if !acquired {
			return
		}

		// Forward request
		service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
		forwardStart := time.Now()
		result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
		forwardDurationMs := time.Since(forwardStart).Milliseconds()
		// Release the account slot immediately after Forward returns (not deferred),
		// so a failover retry can acquire a slot on the next account.
		if accountReleaseFunc != nil {
			accountReleaseFunc()
		}
		// Response latency = total forward time minus upstream time, when the
		// upstream latency was recorded and is smaller than the total.
		upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
		responseLatencyMs := forwardDurationMs
		if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
			responseLatencyMs = forwardDurationMs - upstreamLatencyMs
		}
		service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
		if err == nil && result != nil && result.FirstTokenMs != nil {
			service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
		}
		if err != nil {
			var failoverErr *service.UpstreamFailoverError
			if errors.As(err, &failoverErr) {
				// Failover-eligible upstream error: exclude this account and retry
				// until the switch budget is spent.
				h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
				h.gatewayService.RecordOpenAIAccountSwitch()
				failedAccountIDs[account.ID] = struct{}{}
				lastFailoverErr = failoverErr
				if switchCount >= maxAccountSwitches {
					h.handleFailoverExhausted(c, failoverErr, streamStarted)
					return
				}
				switchCount++
				reqLog.Warn("openai.upstream_failover_switching",
					zap.Int64("account_id", account.ID),
					zap.Int("upstream_status", failoverErr.StatusCode),
					zap.Int("switch_count", switchCount),
					zap.Int("max_switches", maxAccountSwitches),
				)
				continue
			}
			// Non-failover error: make sure the client got some error response,
			// then log at warn or error depending on whether a fallback was needed.
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
			wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
			fields := []zap.Field{
				zap.Int64("account_id", account.ID),
				zap.Bool("fallback_error_response_written", wroteFallback),
				zap.Error(err),
			}
			if shouldLogOpenAIForwardFailureAsWarn(c, wroteFallback) {
				reqLog.Warn("openai.forward_failed", fields...)
				return
			}
			reqLog.Error("openai.forward_failed", fields...)
			return
		}
		if result != nil {
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
		} else {
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
		}

		// Capture request info now (for async recording; avoids touching gin.Context from a goroutine).
		userAgent := c.GetHeader("User-Agent")
		clientIP := ip.GetClientIP(c)

		// Usage records go through a bounded worker pool so the request hot path
		// never spawns unbounded goroutines.
		h.submitUsageRecordTask(func(ctx context.Context) {
			if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
				Result:        result,
				APIKey:        apiKey,
				User:          apiKey.User,
				Account:       account,
				Subscription:  subscription,
				UserAgent:     userAgent,
				IPAddress:     clientIP,
				APIKeyService: h.apiKeyService,
			}); err != nil {
				logger.L().With(
					zap.String("component", "handler.openai_gateway.responses"),
					zap.Int64("user_id", subject.UserID),
					zap.Int64("api_key_id", apiKey.ID),
					zap.Any("group_id", apiKey.GroupID),
					zap.String("model", reqModel),
					zap.Int64("account_id", account.ID),
				).Error("openai.record_usage_failed", zap.Error(err))
			}
		})
		reqLog.Debug("openai.request_completed",
			zap.Int64("account_id", account.ID),
			zap.Int("switch_count", switchCount),
		)
		return
	}
}
|
||
|
||
func (h *OpenAIGatewayHandler) validateFunctionCallOutputRequest(c *gin.Context, body []byte, reqLog *zap.Logger) bool {
|
||
if !gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||
return true
|
||
}
|
||
|
||
var reqBody map[string]any
|
||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||
// 保持原有容错语义:解析失败时跳过预校验,沿用后续上游校验结果。
|
||
return true
|
||
}
|
||
|
||
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||
validation := service.ValidateFunctionCallOutputContext(reqBody)
|
||
if !validation.HasFunctionCallOutput {
|
||
return true
|
||
}
|
||
|
||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||
if strings.TrimSpace(previousResponseID) != "" || validation.HasToolCallContext {
|
||
return true
|
||
}
|
||
|
||
if validation.HasFunctionCallOutputMissingCallID {
|
||
reqLog.Warn("openai.request_validation_failed",
|
||
zap.String("reason", "function_call_output_missing_call_id"),
|
||
)
|
||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires call_id or previous_response_id; if relying on history, ensure store=true and reuse previous_response_id")
|
||
return false
|
||
}
|
||
if validation.HasItemReferenceForAllCallIDs {
|
||
return true
|
||
}
|
||
|
||
reqLog.Warn("openai.request_validation_failed",
|
||
zap.String("reason", "function_call_output_missing_item_reference"),
|
||
)
|
||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "function_call_output requires item_reference ids matching each call_id, or previous_response_id/tool_call context; if relying on history, ensure store=true and reuse previous_response_id")
|
||
return false
|
||
}
|
||
|
||
// acquireResponsesUserSlot acquires a per-user concurrency slot for the
// Responses endpoint. It first attempts a non-blocking acquire; on contention
// it joins a bounded wait queue and blocks (with SSE pings when streaming)
// until a slot frees up. Returns (release, true) on success — the release
// func is wrapped so request completion also releases — or (nil, false)
// after an error response has been written.
func (h *OpenAIGatewayHandler) acquireResponsesUserSlot(
	c *gin.Context,
	userID int64,
	userConcurrency int,
	reqStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), bool) {
	ctx := c.Request.Context()
	// Fast path: try to grab a slot without queueing.
	userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, userID, userConcurrency)
	if err != nil {
		reqLog.Warn("openai.user_slot_acquire_failed", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", *streamStarted)
		return nil, false
	}
	if userAcquired {
		return wrapReleaseOnDone(ctx, userReleaseFunc), true
	}

	// Slow path: register in the wait queue, bounded by maxWait.
	maxWait := service.CalculateMaxWait(userConcurrency)
	canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(ctx, userID, maxWait)
	if waitErr != nil {
		reqLog.Warn("openai.user_wait_counter_increment_failed", zap.Error(waitErr))
		// Keep the existing degradation semantics: when the wait counter
		// misbehaves, let the request proceed to the blocking acquire anyway.
	} else if !canWait {
		reqLog.Info("openai.user_wait_queue_full", zap.Int("max_wait", maxWait))
		h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
		return nil, false
	}

	// waitCounted tracks whether we still hold a wait-queue entry; the defer
	// is the safety net for the error-return path below.
	waitCounted := waitErr == nil && canWait
	defer func() {
		if waitCounted {
			h.concurrencyHelper.DecrementWaitCount(ctx, userID)
		}
	}()

	userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, userID, userConcurrency, reqStream, streamStarted)
	if err != nil {
		reqLog.Warn("openai.user_slot_acquire_failed_after_wait", zap.Error(err))
		h.handleConcurrencyError(c, err, "user", *streamStarted)
		return nil, false
	}

	// Slot acquired: leave the wait count immediately (and disarm the defer).
	if waitCounted {
		h.concurrencyHelper.DecrementWaitCount(ctx, userID)
		waitCounted = false
	}
	return wrapReleaseOnDone(ctx, userReleaseFunc), true
}
|
||
|
||
// acquireResponsesAccountSlot acquires a per-account concurrency slot for the
// selected account. The scheduler may have already acquired one
// (selection.Acquired); otherwise this tries a fast non-blocking acquire and
// then falls back to a bounded wait governed by selection.WaitPlan. On every
// successful acquisition the sticky session is (best-effort) bound to the
// account. Returns (release, true) on success, or (nil, false) after an error
// response has been written.
func (h *OpenAIGatewayHandler) acquireResponsesAccountSlot(
	c *gin.Context,
	groupID *int64,
	sessionHash string,
	selection *service.AccountSelectionResult,
	reqStream bool,
	streamStarted *bool,
	reqLog *zap.Logger,
) (func(), bool) {
	if selection == nil || selection.Account == nil {
		h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
		return nil, false
	}

	ctx := c.Request.Context()
	account := selection.Account
	// The scheduler may have acquired the slot already; just wrap its release.
	if selection.Acquired {
		return wrapReleaseOnDone(ctx, selection.ReleaseFunc), true
	}
	// Not acquired and no wait plan means the account cannot be waited on.
	if selection.WaitPlan == nil {
		h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", *streamStarted)
		return nil, false
	}

	// Fast path: non-blocking acquire before joining the wait queue.
	fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
		ctx,
		account.ID,
		selection.WaitPlan.MaxConcurrency,
	)
	if err != nil {
		reqLog.Warn("openai.account_slot_quick_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		h.handleConcurrencyError(c, err, "account", *streamStarted)
		return nil, false
	}
	if fastAcquired {
		// Sticky-session binding is best-effort: a failure is logged but never fatal.
		if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
			reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		}
		return wrapReleaseOnDone(ctx, fastReleaseFunc), true
	}

	// Slow path: join the account's bounded wait queue.
	canWait, waitErr := h.concurrencyHelper.IncrementAccountWaitCount(ctx, account.ID, selection.WaitPlan.MaxWaiting)
	if waitErr != nil {
		// Same degradation semantics as the user-slot path: counter errors
		// do not block the subsequent acquire attempt.
		reqLog.Warn("openai.account_wait_counter_increment_failed", zap.Int64("account_id", account.ID), zap.Error(waitErr))
	} else if !canWait {
		reqLog.Info("openai.account_wait_queue_full",
			zap.Int64("account_id", account.ID),
			zap.Int("max_waiting", selection.WaitPlan.MaxWaiting),
		)
		h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", *streamStarted)
		return nil, false
	}

	// releaseWait is idempotent; deferred as a safety net for error returns
	// and also called explicitly right after a successful acquire.
	accountWaitCounted := waitErr == nil && canWait
	releaseWait := func() {
		if accountWaitCounted {
			h.concurrencyHelper.DecrementAccountWaitCount(ctx, account.ID)
			accountWaitCounted = false
		}
	}
	defer releaseWait()

	accountReleaseFunc, err := h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
		c,
		account.ID,
		selection.WaitPlan.MaxConcurrency,
		selection.WaitPlan.Timeout,
		reqStream,
		streamStarted,
	)
	if err != nil {
		reqLog.Warn("openai.account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		h.handleConcurrencyError(c, err, "account", *streamStarted)
		return nil, false
	}

	// Slot acquired: no longer waiting in queue.
	releaseWait()
	if err := h.gatewayService.BindStickySession(ctx, groupID, sessionHash, account.ID); err != nil {
		reqLog.Warn("openai.bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
	}
	return wrapReleaseOnDone(ctx, accountReleaseFunc), true
}
|
||
|
||
// ResponsesWebSocket handles OpenAI Responses API WebSocket ingress endpoint
// GET /openai/v1/responses (Upgrade: websocket)
//
// Flow: upgrade check -> auth -> accept WS -> read & validate the first
// response.create message -> acquire user/account slots -> billing check ->
// account selection -> proxy the connection upstream, re-acquiring slots per
// turn via hooks so idle long connections do not hold concurrency.
func (h *OpenAIGatewayHandler) ResponsesWebSocket(c *gin.Context) {
	if !isOpenAIWSUpgradeRequest(c.Request) {
		h.errorResponse(c, http.StatusUpgradeRequired, "invalid_request_error", "WebSocket upgrade required (Upgrade: websocket)")
		return
	}
	setOpenAIClientTransportWS(c)

	apiKey, ok := middleware2.GetAPIKeyFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
		return
	}
	subject, ok := middleware2.GetAuthSubjectFromContext(c)
	if !ok {
		h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
		return
	}

	reqLog := requestLogger(
		c,
		"handler.openai_gateway.responses_ws",
		zap.Int64("user_id", subject.UserID),
		zap.Int64("api_key_id", apiKey.ID),
		zap.Any("group_id", apiKey.GroupID),
		zap.Bool("openai_ws_mode", true),
	)
	if !h.ensureResponsesDependencies(c, reqLog) {
		return
	}
	reqLog.Info("openai.websocket_ingress_started")
	clientIP := ip.GetClientIP(c)
	userAgent := strings.TrimSpace(c.GetHeader("User-Agent"))

	wsConn, err := coderws.Accept(c.Writer, c.Request, &coderws.AcceptOptions{
		CompressionMode: coderws.CompressionContextTakeover,
	})
	if err != nil {
		// Log the handshake headers to make failed upgrades diagnosable.
		reqLog.Warn("openai.websocket_accept_failed",
			zap.Error(err),
			zap.String("client_ip", clientIP),
			zap.String("request_user_agent", userAgent),
			zap.String("upgrade_header", strings.TrimSpace(c.GetHeader("Upgrade"))),
			zap.String("connection_header", strings.TrimSpace(c.GetHeader("Connection"))),
			zap.String("sec_websocket_version", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Version"))),
			zap.Bool("has_sec_websocket_key", strings.TrimSpace(c.GetHeader("Sec-WebSocket-Key")) != ""),
		)
		return
	}
	// CloseNow is the unconditional cleanup; graceful closes happen via
	// closeOpenAIClientWS before the returns below.
	defer func() {
		_ = wsConn.CloseNow()
	}()
	wsConn.SetReadLimit(16 * 1024 * 1024)

	ctx := c.Request.Context()
	// The first response.create message must arrive within 30s.
	readCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
	msgType, firstMessage, err := wsConn.Read(readCtx)
	cancel()
	if err != nil {
		closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
		reqLog.Warn("openai.websocket_read_first_message_failed",
			zap.Error(err),
			zap.String("client_ip", clientIP),
			zap.String("close_status", closeStatus),
			zap.String("close_reason", closeReason),
			zap.Duration("read_timeout", 30*time.Second),
		)
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "missing first response.create message")
		return
	}
	if msgType != coderws.MessageText && msgType != coderws.MessageBinary {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "unsupported websocket message type")
		return
	}
	if !gjson.ValidBytes(firstMessage) {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "invalid JSON payload")
		return
	}

	reqModel := strings.TrimSpace(gjson.GetBytes(firstMessage, "model").String())
	if reqModel == "" {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "model is required in first response.create payload")
		return
	}
	previousResponseID := strings.TrimSpace(gjson.GetBytes(firstMessage, "previous_response_id").String())
	previousResponseIDKind := service.ClassifyOpenAIPreviousResponseIDKind(previousResponseID)
	if previousResponseID != "" && previousResponseIDKind == service.OpenAIPreviousResponseIDKindMessageID {
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "previous_response_id must be a response.id (resp_*), not a message id")
		return
	}
	reqLog = reqLog.With(
		zap.Bool("ws_ingress", true),
		zap.String("model", reqModel),
		zap.Bool("has_previous_response_id", previousResponseID != ""),
		zap.String("previous_response_id_kind", previousResponseIDKind),
	)
	setOpsRequestContext(c, reqModel, true, firstMessage)

	// Per-turn slot bookkeeping shared with the proxy hooks below. The
	// release closures are nil'd after use so releaseTurnSlots is idempotent.
	var currentUserRelease func()
	var currentAccountRelease func()
	releaseTurnSlots := func() {
		if currentAccountRelease != nil {
			currentAccountRelease()
			currentAccountRelease = nil
		}
		if currentUserRelease != nil {
			currentUserRelease()
			currentUserRelease = nil
		}
	}
	// Must be registered as early as possible so any early return releases
	// concurrency slots that have already been acquired.
	defer releaseTurnSlots()

	userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
	if err != nil {
		reqLog.Warn("openai.websocket_user_slot_acquire_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire user concurrency slot")
		return
	}
	if !userAcquired {
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "too many concurrent requests, please retry later")
		return
	}
	currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)

	subscription, _ := middleware2.GetSubscriptionFromContext(c)
	if err := h.billingCacheService.CheckBillingEligibility(ctx, apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
		reqLog.Info("openai.websocket_billing_eligibility_check_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusPolicyViolation, "billing check failed")
		return
	}

	sessionHash := h.gatewayService.GenerateSessionHashWithFallback(
		c,
		firstMessage,
		openAIWSIngressFallbackSessionSeed(subject.UserID, apiKey.ID, apiKey.GroupID),
	)
	selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
		ctx,
		apiKey.GroupID,
		previousResponseID,
		sessionHash,
		reqModel,
		nil,
		service.OpenAIUpstreamTransportResponsesWebsocketV2,
	)
	if err != nil {
		reqLog.Warn("openai.websocket_account_select_failed", zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
		return
	}
	if selection == nil || selection.Account == nil {
		closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "no available account")
		return
	}

	account := selection.Account
	// Per-turn re-acquisition uses the wait plan's limit when available,
	// falling back to the account's own configured concurrency.
	accountMaxConcurrency := account.Concurrency
	if selection.WaitPlan != nil && selection.WaitPlan.MaxConcurrency > 0 {
		accountMaxConcurrency = selection.WaitPlan.MaxConcurrency
	}
	accountReleaseFunc := selection.ReleaseFunc
	if !selection.Acquired {
		if selection.WaitPlan == nil {
			closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
			return
		}
		// WS ingress only does the non-blocking acquire; there is no queued
		// wait like the HTTP path.
		fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
			ctx,
			account.ID,
			selection.WaitPlan.MaxConcurrency,
		)
		if err != nil {
			reqLog.Warn("openai.websocket_account_slot_acquire_failed", zap.Int64("account_id", account.ID), zap.Error(err))
			closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to acquire account concurrency slot")
			return
		}
		if !fastAcquired {
			closeOpenAIClientWS(wsConn, coderws.StatusTryAgainLater, "account is busy, please retry later")
			return
		}
		accountReleaseFunc = fastReleaseFunc
	}
	currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
	// Sticky-session binding is best-effort; a failure is logged but not fatal.
	if err := h.gatewayService.BindStickySession(ctx, apiKey.GroupID, sessionHash, account.ID); err != nil {
		reqLog.Warn("openai.websocket_bind_sticky_session_failed", zap.Int64("account_id", account.ID), zap.Error(err))
	}

	token, _, err := h.gatewayService.GetAccessToken(ctx, account)
	if err != nil {
		reqLog.Warn("openai.websocket_get_access_token_failed", zap.Int64("account_id", account.ID), zap.Error(err))
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "failed to get access token")
		return
	}

	reqLog.Debug("openai.websocket_account_selected",
		zap.Int64("account_id", account.ID),
		zap.String("account_name", account.Name),
		zap.String("schedule_layer", scheduleDecision.Layer),
		zap.Int("candidate_count", scheduleDecision.CandidateCount),
	)

	hooks := &service.OpenAIWSIngressHooks{
		// BeforeTurn re-acquires user and account slots for every turn after
		// the first (the first turn reuses the slots acquired above).
		BeforeTurn: func(turn int) error {
			if turn == 1 {
				return nil
			}
			// Defensive cleanup: prevent a stale slot from being overwritten
			// (and leaked) on abnormal paths.
			releaseTurnSlots()
			// Non-first turns must re-acquire concurrency slots so idle long
			// connections don't pin them between turns.
			userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(ctx, subject.UserID, subject.Concurrency)
			if err != nil {
				return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire user concurrency slot", err)
			}
			if !userAcquired {
				return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "too many concurrent requests, please retry later", nil)
			}
			accountReleaseFunc, accountAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(ctx, account.ID, accountMaxConcurrency)
			if err != nil {
				// Roll back the user slot before failing the turn.
				if userReleaseFunc != nil {
					userReleaseFunc()
				}
				return service.NewOpenAIWSClientCloseError(coderws.StatusInternalError, "failed to acquire account concurrency slot", err)
			}
			if !accountAcquired {
				if userReleaseFunc != nil {
					userReleaseFunc()
				}
				return service.NewOpenAIWSClientCloseError(coderws.StatusTryAgainLater, "account is busy, please retry later", nil)
			}
			currentUserRelease = wrapReleaseOnDone(ctx, userReleaseFunc)
			currentAccountRelease = wrapReleaseOnDone(ctx, accountReleaseFunc)
			return nil
		},
		// AfterTurn releases the turn's slots and, on success, reports the
		// schedule result and submits the usage record asynchronously.
		AfterTurn: func(turn int, result *service.OpenAIForwardResult, turnErr error) {
			releaseTurnSlots()
			if turnErr != nil || result == nil {
				return
			}
			h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
			h.submitUsageRecordTask(func(taskCtx context.Context) {
				if err := h.gatewayService.RecordUsage(taskCtx, &service.OpenAIRecordUsageInput{
					Result:        result,
					APIKey:        apiKey,
					User:          apiKey.User,
					Account:       account,
					Subscription:  subscription,
					UserAgent:     userAgent,
					IPAddress:     clientIP,
					APIKeyService: h.apiKeyService,
				}); err != nil {
					reqLog.Error("openai.websocket_record_usage_failed",
						zap.Int64("account_id", account.ID),
						zap.String("request_id", result.RequestID),
						zap.Error(err),
					)
				}
			})
		},
	}

	if err := h.gatewayService.ProxyResponsesWebSocketFromClient(ctx, c, wsConn, account, token, firstMessage, hooks); err != nil {
		h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
		closeStatus, closeReason := summarizeWSCloseErrorForLog(err)
		reqLog.Warn("openai.websocket_proxy_failed",
			zap.Int64("account_id", account.ID),
			zap.Error(err),
			zap.String("close_status", closeStatus),
			zap.String("close_reason", closeReason),
		)
		// Propagate a structured client close code when the service provided one.
		var closeErr *service.OpenAIWSClientCloseError
		if errors.As(err, &closeErr) {
			closeOpenAIClientWS(wsConn, closeErr.StatusCode(), closeErr.Reason())
			return
		}
		closeOpenAIClientWS(wsConn, coderws.StatusInternalError, "upstream websocket proxy failed")
		return
	}
	reqLog.Info("openai.websocket_ingress_closed", zap.Int64("account_id", account.ID))
}
|
||
|
||
func (h *OpenAIGatewayHandler) recoverResponsesPanic(c *gin.Context, streamStarted *bool) {
|
||
recovered := recover()
|
||
if recovered == nil {
|
||
return
|
||
}
|
||
|
||
started := false
|
||
if streamStarted != nil {
|
||
started = *streamStarted
|
||
}
|
||
wroteFallback := h.ensureForwardErrorResponse(c, started)
|
||
requestLogger(c, "handler.openai_gateway.responses").Error(
|
||
"openai.responses_panic_recovered",
|
||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||
zap.Any("panic", recovered),
|
||
zap.ByteString("stack", debug.Stack()),
|
||
)
|
||
}
|
||
|
||
func (h *OpenAIGatewayHandler) ensureResponsesDependencies(c *gin.Context, reqLog *zap.Logger) bool {
|
||
missing := h.missingResponsesDependencies()
|
||
if len(missing) == 0 {
|
||
return true
|
||
}
|
||
|
||
if reqLog == nil {
|
||
reqLog = requestLogger(c, "handler.openai_gateway.responses")
|
||
}
|
||
reqLog.Error("openai.handler_dependencies_missing", zap.Strings("missing_dependencies", missing))
|
||
|
||
if c != nil && c.Writer != nil && !c.Writer.Written() {
|
||
c.JSON(http.StatusServiceUnavailable, gin.H{
|
||
"error": gin.H{
|
||
"type": "api_error",
|
||
"message": "Service temporarily unavailable",
|
||
},
|
||
})
|
||
}
|
||
return false
|
||
}
|
||
|
||
func (h *OpenAIGatewayHandler) missingResponsesDependencies() []string {
|
||
missing := make([]string, 0, 5)
|
||
if h == nil {
|
||
return append(missing, "handler")
|
||
}
|
||
if h.gatewayService == nil {
|
||
missing = append(missing, "gatewayService")
|
||
}
|
||
if h.billingCacheService == nil {
|
||
missing = append(missing, "billingCacheService")
|
||
}
|
||
if h.apiKeyService == nil {
|
||
missing = append(missing, "apiKeyService")
|
||
}
|
||
if h.concurrencyHelper == nil || h.concurrencyHelper.concurrencyService == nil {
|
||
missing = append(missing, "concurrencyHelper")
|
||
}
|
||
return missing
|
||
}
|
||
|
||
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||
if c == nil || key == "" {
|
||
return 0, false
|
||
}
|
||
v, ok := c.Get(key)
|
||
if !ok {
|
||
return 0, false
|
||
}
|
||
switch t := v.(type) {
|
||
case int64:
|
||
return t, true
|
||
case int:
|
||
return int64(t), true
|
||
case int32:
|
||
return int64(t), true
|
||
case float64:
|
||
return int64(t), true
|
||
default:
|
||
return 0, false
|
||
}
|
||
}
|
||
|
||
func (h *OpenAIGatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
|
||
if task == nil {
|
||
return
|
||
}
|
||
if h.usageRecordWorkerPool != nil {
|
||
h.usageRecordWorkerPool.Submit(task)
|
||
return
|
||
}
|
||
// 回退路径:worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
|
||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||
defer cancel()
|
||
defer func() {
|
||
if recovered := recover(); recovered != nil {
|
||
logger.L().With(
|
||
zap.String("component", "handler.openai_gateway.responses"),
|
||
zap.Any("panic", recovered),
|
||
).Error("openai.usage_record_task_panic_recovered")
|
||
}
|
||
}()
|
||
task(ctx)
|
||
}
|
||
|
||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||
fmt.Sprintf("Concurrency limit exceeded for %s, please retry later", slotType), streamStarted)
|
||
}
|
||
|
||
func (h *OpenAIGatewayHandler) handleFailoverExhausted(c *gin.Context, failoverErr *service.UpstreamFailoverError, streamStarted bool) {
|
||
statusCode := failoverErr.StatusCode
|
||
responseBody := failoverErr.ResponseBody
|
||
|
||
// 先检查透传规则
|
||
if h.errorPassthroughService != nil && len(responseBody) > 0 {
|
||
if rule := h.errorPassthroughService.MatchRule("openai", statusCode, responseBody); rule != nil {
|
||
// 确定响应状态码
|
||
respCode := statusCode
|
||
if !rule.PassthroughCode && rule.ResponseCode != nil {
|
||
respCode = *rule.ResponseCode
|
||
}
|
||
|
||
// 确定响应消息
|
||
msg := service.ExtractUpstreamErrorMessage(responseBody)
|
||
if !rule.PassthroughBody && rule.CustomMessage != nil {
|
||
msg = *rule.CustomMessage
|
||
}
|
||
|
||
if rule.SkipMonitoring {
|
||
c.Set(service.OpsSkipPassthroughKey, true)
|
||
}
|
||
|
||
h.handleStreamingAwareError(c, respCode, "upstream_error", msg, streamStarted)
|
||
return
|
||
}
|
||
}
|
||
|
||
// 使用默认的错误映射
|
||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||
}
|
||
|
||
// handleFailoverExhaustedSimple 简化版本,用于没有响应体的情况
|
||
func (h *OpenAIGatewayHandler) handleFailoverExhaustedSimple(c *gin.Context, statusCode int, streamStarted bool) {
|
||
status, errType, errMsg := h.mapUpstreamError(statusCode)
|
||
h.handleStreamingAwareError(c, status, errType, errMsg, streamStarted)
|
||
}
|
||
|
||
func (h *OpenAIGatewayHandler) mapUpstreamError(statusCode int) (int, string, string) {
|
||
switch statusCode {
|
||
case 401:
|
||
return http.StatusBadGateway, "upstream_error", "Upstream authentication failed, please contact administrator"
|
||
case 403:
|
||
return http.StatusBadGateway, "upstream_error", "Upstream access forbidden, please contact administrator"
|
||
case 429:
|
||
return http.StatusTooManyRequests, "rate_limit_error", "Upstream rate limit exceeded, please retry later"
|
||
case 529:
|
||
return http.StatusServiceUnavailable, "upstream_error", "Upstream service overloaded, please retry later"
|
||
case 500, 502, 503, 504:
|
||
return http.StatusBadGateway, "upstream_error", "Upstream service temporarily unavailable"
|
||
default:
|
||
return http.StatusBadGateway, "upstream_error", "Upstream request failed"
|
||
}
|
||
}
|
||
|
||
// handleStreamingAwareError handles errors that may occur after streaming has started
|
||
func (h *OpenAIGatewayHandler) handleStreamingAwareError(c *gin.Context, status int, errType, message string, streamStarted bool) {
|
||
if streamStarted {
|
||
// Stream already started, send error as SSE event then close
|
||
flusher, ok := c.Writer.(http.Flusher)
|
||
if ok {
|
||
// SSE 错误事件固定 schema,使用 Quote 直拼可避免额外 Marshal 分配。
|
||
errorEvent := "event: error\ndata: " + `{"error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
|
||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||
_ = c.Error(err)
|
||
}
|
||
flusher.Flush()
|
||
}
|
||
return
|
||
}
|
||
|
||
// Normal case: return JSON response with proper status code
|
||
h.errorResponse(c, status, errType, message)
|
||
}
|
||
|
||
// ensureForwardErrorResponse 在 Forward 返回错误但尚未写响应时补写统一错误响应。
|
||
func (h *OpenAIGatewayHandler) ensureForwardErrorResponse(c *gin.Context, streamStarted bool) bool {
|
||
if c == nil || c.Writer == nil || c.Writer.Written() {
|
||
return false
|
||
}
|
||
h.handleStreamingAwareError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed", streamStarted)
|
||
return true
|
||
}
|
||
|
||
func shouldLogOpenAIForwardFailureAsWarn(c *gin.Context, wroteFallback bool) bool {
|
||
if wroteFallback {
|
||
return false
|
||
}
|
||
if c == nil || c.Writer == nil {
|
||
return false
|
||
}
|
||
return c.Writer.Written()
|
||
}
|
||
|
||
// errorResponse returns OpenAI API format error response
|
||
func (h *OpenAIGatewayHandler) errorResponse(c *gin.Context, status int, errType, message string) {
|
||
c.JSON(status, gin.H{
|
||
"error": gin.H{
|
||
"type": errType,
|
||
"message": message,
|
||
},
|
||
})
|
||
}
|
||
|
||
// setOpenAIClientTransportHTTP tags the request context as using the HTTP
// client transport (see service.SetOpenAIClientTransport).
func setOpenAIClientTransportHTTP(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportHTTP)
}
|
||
|
||
// setOpenAIClientTransportWS tags the request context as using the WebSocket
// client transport (see service.SetOpenAIClientTransport).
func setOpenAIClientTransportWS(c *gin.Context) {
	service.SetOpenAIClientTransport(c, service.OpenAIClientTransportWS)
}
|
||
|
||
// openAIWSIngressFallbackSessionSeed derives a deterministic fallback session
// key for a WebSocket ingress connection; a nil groupID collapses to group 0.
func openAIWSIngressFallbackSessionSeed(userID, apiKeyID int64, groupID *int64) string {
	var gid int64
	if groupID != nil {
		gid = *groupID
	}
	return "openai_ws_ingress:" + strconv.FormatInt(gid, 10) +
		":" + strconv.FormatInt(userID, 10) +
		":" + strconv.FormatInt(apiKeyID, 10)
}
|
||
|
||
func isOpenAIWSUpgradeRequest(r *http.Request) bool {
|
||
if r == nil {
|
||
return false
|
||
}
|
||
if !strings.EqualFold(strings.TrimSpace(r.Header.Get("Upgrade")), "websocket") {
|
||
return false
|
||
}
|
||
return strings.Contains(strings.ToLower(strings.TrimSpace(r.Header.Get("Connection"))), "upgrade")
|
||
}
|
||
|
||
func closeOpenAIClientWS(conn *coderws.Conn, status coderws.StatusCode, reason string) {
|
||
if conn == nil {
|
||
return
|
||
}
|
||
reason = strings.TrimSpace(reason)
|
||
if len(reason) > 120 {
|
||
reason = reason[:120]
|
||
}
|
||
_ = conn.Close(status, reason)
|
||
_ = conn.CloseNow()
|
||
}
|
||
|
||
func summarizeWSCloseErrorForLog(err error) (string, string) {
|
||
if err == nil {
|
||
return "-", "-"
|
||
}
|
||
statusCode := coderws.CloseStatus(err)
|
||
if statusCode == -1 {
|
||
return "-", "-"
|
||
}
|
||
closeStatus := fmt.Sprintf("%d(%s)", int(statusCode), statusCode.String())
|
||
closeReason := "-"
|
||
var closeErr coderws.CloseError
|
||
if errors.As(err, &closeErr) {
|
||
reason := strings.TrimSpace(closeErr.Reason)
|
||
if reason != "" {
|
||
closeReason = reason
|
||
}
|
||
}
|
||
return closeStatus, closeReason
|
||
}
|