From 61a2bf469a29d733a7f2f52c290fb3ec55eeb5a9 Mon Sep 17 00:00:00 2001 From: yangjianbo Date: Thu, 12 Feb 2026 09:41:37 +0800 Subject: [PATCH] =?UTF-8?q?feat(openai):=20=E6=9E=81=E8=87=B4=E4=BC=98?= =?UTF-8?q?=E5=8C=96=20OAuth=20=E9=93=BE=E8=B7=AF=E5=B9=B6=E8=A1=A5?= =?UTF-8?q?=E9=BD=90=E6=80=A7=E8=83=BD=E5=AE=88=E6=8A=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 优化 /v1/responses 热路径,减少重复解析与不必要拷贝\n- 优化并发与 token 竞争路径并补齐运行指标\n- 补充 OpenAI/Ops 相关单元测试与回归用例\n- 新增灰度阈值守护与压测脚本,支撑发布验收 --- backend/internal/handler/gateway_helper.go | 57 +++-- .../handler/gateway_helper_fastpath_test.go | 114 ++++++++++ .../handler/openai_gateway_handler.go | 171 +++++++++----- backend/internal/handler/ops_error_logger.go | 40 ++++ backend/internal/repository/ops_repo.go | 10 +- .../service/openai_gateway_service.go | 173 +++++++++----- .../openai_gateway_service_hotpath_test.go | 125 ++++++++++ .../service/openai_gateway_service_test.go | 106 +++++++++ .../internal/service/openai_token_provider.go | 163 +++++++++++++- .../service/openai_token_provider_test.go | 116 ++++++++++ backend/internal/service/ops_port.go | 4 + .../internal/service/ops_upstream_context.go | 29 ++- .../service/ops_upstream_context_test.go | 47 ++++ tools/perf/openai_oauth_gray_drill.py | 164 ++++++++++++++ tools/perf/openai_oauth_gray_guard.py | 213 ++++++++++++++++++ tools/perf/openai_oauth_responses_k6.js | 122 ++++++++++ 16 files changed, 1519 insertions(+), 135 deletions(-) create mode 100644 backend/internal/handler/gateway_helper_fastpath_test.go create mode 100644 backend/internal/service/openai_gateway_service_hotpath_test.go create mode 100644 backend/internal/service/ops_upstream_context_test.go create mode 100755 tools/perf/openai_oauth_gray_drill.py create mode 100755 tools/perf/openai_oauth_gray_guard.py create mode 100644 tools/perf/openai_oauth_responses_k6.js diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 94698691..c4edf53b 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -104,31 +104,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // 用于避免客户端断开或上游超时导致的并发槽位泄漏。 -// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 +// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。 func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { if releaseFunc == nil { return nil } var once sync.Once - quit := make(chan struct{}) + var stop func() bool release := func() { once.Do(func() { + if stop != nil { + _ = stop() + } releaseFunc() - close(quit) // 通知监听 goroutine 退出 }) } - go func() { - select { - case <-ctx.Done(): - // Context 取消时释放资源 - release() - case <-quit: - // 正常释放已完成,goroutine 退出 - return - } - }() + stop = context.AfterFunc(ctx, release) return release } @@ -153,6 +146,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) } +// TryAcquireUserSlot 尝试立即获取用户并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + +// TryAcquireAccountSlot 尝试立即获取账号并发槽位。 +// 返回值: (releaseFunc, acquired, error) +func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) { + result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + if err != nil { + return nil, false, err + } + if !result.Acquired { + return nil, false, nil + } + return result.ReleaseFunc, true, nil +} + // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. @@ -160,13 +179,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64 ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed @@ -180,13 +199,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID ctx := c.Request.Context() // Try to acquire immediately - result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) + releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency) if err != nil { return nil, err } - if result.Acquired { - return result.ReleaseFunc, nil + if acquired { + return releaseFunc, nil } // Need to wait - handle streaming ping if needed diff --git a/backend/internal/handler/gateway_helper_fastpath_test.go b/backend/internal/handler/gateway_helper_fastpath_test.go new file mode 100644 index 00000000..3e6c376b --- /dev/null +++ b/backend/internal/handler/gateway_helper_fastpath_test.go @@ -0,0 +1,114 @@ +package handler + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/Wei-Shaw/sub2api/internal/service" + "github.com/stretchr/testify/require" +) + +type concurrencyCacheMock struct { + acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) + acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) + releaseUserCalled int32 + releaseAccountCalled int32 +} + +func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireAccountSlotFn != nil { + return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error { + atomic.AddInt32(&m.releaseAccountCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + if m.acquireUserSlotFn != nil { + return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID) + } + return false, nil +} + +func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error { + atomic.AddInt32(&m.releaseUserCalled, 1) + return nil +} + +func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) { + return 0, nil +} + +func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { + return true, nil +} + +func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error { + return nil +} + +func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) { + return map[int64]*service.AccountLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) { + return map[int64]*service.UserLoadInfo{}, nil +} + +func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error { + return nil +} + +func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) { + return true, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2) + require.NoError(t, err) + require.True(t, acquired) + require.NotNil(t, release) + + release() + require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled)) +} + +func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) { + cache := &concurrencyCacheMock{ + acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) { + return false, nil + }, + } + helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second) + + release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1) + require.NoError(t, err) + require.False(t, acquired) + require.Nil(t, release) + require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled)) +} diff --git a/backend/internal/handler/openai_gateway_handler.go b/backend/internal/handler/openai_gateway_handler.go index a4c25284..75c758da 100644 --- a/backend/internal/handler/openai_gateway_handler.go +++ b/backend/internal/handler/openai_gateway_handler.go @@ -64,6 +64,8 @@ func NewOpenAIGatewayHandler( // Responses handles OpenAI Responses API endpoint // POST /openai/v1/responses func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { + requestStart := time.Now() + // Get apiKey and user from context (set by ApiKeyAuth middleware) apiKey, ok := middleware2.GetAPIKeyFromContext(c) if !ok { @@ -141,6 +143,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { var reqBody map[string]any if err := json.Unmarshal(body, &reqBody); err == nil { + c.Set(service.OpenAIParsedRequestBodyKey, reqBody) if service.HasFunctionCallOutput(reqBody) { previousResponseID, _ := reqBody["previous_response_id"].(string) if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { @@ -171,34 +174,47 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { // Get subscription info (may be nil) subscription, _ := middleware2.GetSubscriptionFromContext(c) - // 0. Check if wait queue is full - maxWait := service.CalculateMaxWait(subject.Concurrency) - canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) - waitCounted := false - if err != nil { - log.Printf("Increment wait count failed: %v", err) - // On error, allow request to proceed - } else if !canWait { - h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") - return - } - if err == nil && canWait { - waitCounted = true - } - defer func() { - if waitCounted { - h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) - } - }() + service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds()) + routingStart := time.Now() - // 1. First acquire user concurrency slot - userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + // 0. 先尝试直接抢占用户槽位(快速路径) + userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency) if err != nil { log.Printf("User concurrency acquire failed: %v", err) h.handleConcurrencyError(c, err, "user", streamStarted) return } - // User slot acquired: no longer waiting. + + waitCounted := false + if !userAcquired { + // 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。 + maxWait := service.CalculateMaxWait(subject.Concurrency) + canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait) + if waitErr != nil { + log.Printf("Increment wait count failed: %v", waitErr) + // 按现有降级语义:等待计数异常时放行后续抢槽流程 + } else if !canWait { + h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") + return + } + if waitErr == nil && canWait { + waitCounted = true + } + defer func() { + if waitCounted { + h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) + } + }() + + userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted) + if err != nil { + log.Printf("User concurrency acquire failed: %v", err) + h.handleConcurrencyError(c, err, "user", streamStarted) + return + } + } + + // 用户槽位已获取:退出等待队列计数。 if waitCounted { h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) waitCounted = false @@ -253,53 +269,84 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) return } - accountWaitCounted := false - canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) - if err != nil { - log.Printf("Increment account wait count failed: %v", err) - } else if !canWait { - log.Printf("Account wait queue full: account=%d", account.ID) - h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) - return - } - if err == nil && canWait { - accountWaitCounted = true - } - releaseWait := func() { - if accountWaitCounted { - h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) - accountWaitCounted = false - } - } - accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( - c, + // 先快速尝试一次账号槽位,命中则跳过等待计数写入。 + fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot( + c.Request.Context(), account.ID, selection.WaitPlan.MaxConcurrency, - selection.WaitPlan.Timeout, - reqStream, - &streamStarted, ) if err != nil { - log.Printf("Account concurrency acquire failed: %v", err) - releaseWait() + log.Printf("Account concurrency quick acquire failed: %v", err) h.handleConcurrencyError(c, err, "account", streamStarted) return } - // Slot acquired: no longer waiting in queue. - releaseWait() - if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { - log.Printf("Bind sticky session failed: %v", err) + if fastAcquired { + accountReleaseFunc = fastReleaseFunc + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } + } else { + accountWaitCounted := false + canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) + if err != nil { + log.Printf("Increment account wait count failed: %v", err) + } else if !canWait { + log.Printf("Account wait queue full: account=%d", account.ID) + h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted) + return + } + if err == nil && canWait { + accountWaitCounted = true + } + releaseWait := func() { + if accountWaitCounted { + h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID) + accountWaitCounted = false + } + } + + accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout( + c, + account.ID, + selection.WaitPlan.MaxConcurrency, + selection.WaitPlan.Timeout, + reqStream, + &streamStarted, + ) + if err != nil { + log.Printf("Account concurrency acquire failed: %v", err) + releaseWait() + h.handleConcurrencyError(c, err, "account", streamStarted) + return + } + // Slot acquired: no longer waiting in queue. + releaseWait() + if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil { + log.Printf("Bind sticky session failed: %v", err) + } } } // 账号槽位/等待计数需要在超时或断开时安全回收 accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) // Forward request + service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds()) + forwardStart := time.Now() result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) + forwardDurationMs := time.Since(forwardStart).Milliseconds() if accountReleaseFunc != nil { accountReleaseFunc() } + upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey) + responseLatencyMs := forwardDurationMs + if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs { + responseLatencyMs = forwardDurationMs - upstreamLatencyMs + } + service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs) + if err == nil && result != nil && result.FirstTokenMs != nil { + service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs)) + } if err != nil { var failoverErr *service.UpstreamFailoverError if errors.As(err, &failoverErr) { @@ -343,6 +390,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { } } +func getContextInt64(c *gin.Context, key string) (int64, bool) { + if c == nil || key == "" { + return 0, false + } + v, ok := c.Get(key) + if !ok { + return 0, false + } + switch t := v.(type) { + case int64: + return t, true + case int: + return int64(t), true + case int32: + return int64(t), true + case float64: + return int64(t), true + default: + return 0, false + } +} + // handleConcurrencyError handles concurrency-related errors with proper 429 response func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", diff --git a/backend/internal/handler/ops_error_logger.go b/backend/internal/handler/ops_error_logger.go index 36ffde63..697078a1 100644 --- a/backend/internal/handler/ops_error_logger.go +++ b/backend/internal/handler/ops_error_logger.go @@ -507,6 +507,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) if apiKey != nil { entry.APIKeyID = &apiKey.ID @@ -618,6 +619,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc { RetryCount: 0, CreatedAt: time.Now(), } + applyOpsLatencyFieldsFromContext(c, entry) // Capture upstream error context set by gateway services (if present). // This does NOT affect the client response; it enriches Ops troubleshooting data. @@ -746,6 +748,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string { return &s } +func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) { + if c == nil || entry == nil { + return + } + entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey) + entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey) + entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey) + entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey) + entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey) +} + +func getContextLatencyMs(c *gin.Context, key string) *int64 { + if c == nil || strings.TrimSpace(key) == "" { + return nil + } + v, ok := c.Get(key) + if !ok { + return nil + } + var ms int64 + switch t := v.(type) { + case int: + ms = int64(t) + case int32: + ms = int64(t) + case int64: + ms = t + case float64: + ms = int64(t) + default: + return nil + } + if ms < 0 { + return nil + } + return &ms +} + type parsedOpsError struct { ErrorType string Message string diff --git a/backend/internal/repository/ops_repo.go b/backend/internal/repository/ops_repo.go index b04154b7..8f2c30c0 100644 --- a/backend/internal/repository/ops_repo.go +++ b/backend/internal/repository/ops_repo.go @@ -55,6 +55,10 @@ INSERT INTO ops_error_logs ( upstream_error_message, upstream_error_detail, upstream_errors, + auth_latency_ms, + routing_latency_ms, + upstream_latency_ms, + response_latency_ms, time_to_first_token_ms, request_body, request_body_truncated, @@ -64,7 +68,7 @@ INSERT INTO ops_error_logs ( retry_count, created_at ) VALUES ( - $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34 + $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38 ) RETURNING id` var id int64 @@ -97,6 +101,10 @@ INSERT INTO ops_error_logs ( opsNullString(input.UpstreamErrorMessage), opsNullString(input.UpstreamErrorDetail), opsNullString(input.UpstreamErrorsJSON), + opsNullInt64(input.AuthLatencyMs), + opsNullInt64(input.RoutingLatencyMs), + opsNullInt64(input.UpstreamLatencyMs), + opsNullInt64(input.ResponseLatencyMs), opsNullInt64(input.TimeToFirstTokenMs), opsNullString(input.RequestBodyJSON), input.RequestBodyTruncated, diff --git a/backend/internal/service/openai_gateway_service.go b/backend/internal/service/openai_gateway_service.go index 76746d2b..3ff20978 100644 --- a/backend/internal/service/openai_gateway_service.go +++ b/backend/internal/service/openai_gateway_service.go @@ -12,7 +12,6 @@ import ( "io" "log" "net/http" - "regexp" "sort" "strconv" "strings" @@ -34,11 +33,10 @@ const ( // OpenAI Platform API for API Key accounts (fallback) openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiStickySessionTTL = time.Hour // 粘性会话TTL -) -// openaiSSEDataRe matches SSE data lines with optional whitespace after colon. -// Some upstream APIs return non-standard "data:" without space (should be "data: "). -var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) + // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。 + OpenAIParsedRequestBodyKey = "openai_parsed_request_body" +) // OpenAI allowed headers whitelist (for non-OAuth accounts) var openaiAllowedHeaders = map[string]bool{ @@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco startTime := time.Now() originalBody := body + reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body) + originalModel := reqModel - // Parse request body once (avoid multiple parse/serialize cycles) - var reqBody map[string]any - if err := json.Unmarshal(body, &reqBody); err != nil { - return nil, fmt.Errorf("parse request: %w", err) + isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) + passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI + if passthroughEnabled { + // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。 + reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel) + return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) } - // Extract model and stream from parsed body - reqModel, _ := reqBody["model"].(string) - reqStream, _ := reqBody["stream"].(bool) - promptCacheKey := "" - if v, ok := reqBody["prompt_cache_key"].(string); ok { - promptCacheKey = strings.TrimSpace(v) + reqBody, err := getOpenAIRequestBodyMap(c, body) + if err != nil { + return nil, err + } + + if v, ok := reqBody["model"].(string); ok { + reqModel = v + originalModel = reqModel + } + if v, ok := reqBody["stream"].(bool); ok { + reqStream = v + } + if promptCacheKey == "" { + if v, ok := reqBody["prompt_cache_key"].(string); ok { + promptCacheKey = strings.TrimSpace(v) + } } // Track if body needs re-serialization bodyModified := false - originalModel := reqModel - - isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI) - - passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI - if passthroughEnabled { - reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel) - return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime) - } // 对所有请求执行模型映射(包含 Codex CLI)。 mappedModel := account.GetMappedModel(reqModel) @@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco } // Capture upstream request body for ops retry of this attempt. - if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) - } + setOpsUpstreamRequestBody(c, body) // Send request + upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) if err != nil { // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). safeErr := sanitizeUpstreamErrorMessage(err.Error()) @@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough( proxyURL = account.Proxy.URL() } + setOpsUpstreamRequestBody(c, body) if c != nil { - c.Set(OpsUpstreamRequestBodyKey, string(body)) c.Set("openai_passthrough", true) } + upstreamStart := time.Now() resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) + SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds()) if err != nil { safeErr := sanitizeUpstreamErrorMessage(err.Error()) setOpsUpstreamError(c, 0, safeErr, "") @@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough( for scanner.Scan() { line := scanner.Text() - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") + if data, ok := extractOpenAISSEDataLine(line); ok { if firstTokenMs == nil && strings.TrimSpace(data) != "" { ms := int(time.Since(startTime).Milliseconds()) firstTokenMs = &ms @@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp lastDataAt = time.Now() // Extract data from SSE line (supports both "data: " and "data:" formats) - if openaiSSEDataRe.MatchString(line) { - data := openaiSSEDataRe.ReplaceAllString(line, "") + if data, ok := extractOpenAISSEDataLine(line); ok { // Replace model in response if needed if needModelReplace { @@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp } +// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。 +// 兼容 `data: xxx` 与 `data:xxx` 两种格式。 +func extractOpenAISSEDataLine(line string) (string, bool) { + if !strings.HasPrefix(line, "data:") { + return "", false + } + start := len("data:") + for start < len(line) { + if line[start] != ' ' && line[start] != ' ' { + break + } + start++ + } + return line[start:], true +} + func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { return line } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { return line } @@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt } func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { - // Parse response.completed event for usage (OpenAI Responses format) - var event struct { - Type string `json:"type"` - Response struct { - Usage struct { - InputTokens int `json:"input_tokens"` - OutputTokens int `json:"output_tokens"` - InputTokenDetails struct { - CachedTokens int `json:"cached_tokens"` - } `json:"input_tokens_details"` - } `json:"usage"` - } `json:"response"` + if usage == nil || data == "" || data == "[DONE]" { + return + } + // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。 + if !strings.Contains(data, `"response.completed"`) { + return + } + if gjson.Get(data, "type").String() != "response.completed" { + return } - if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" { - usage.InputTokens = event.Response.Usage.InputTokens - usage.OutputTokens = event.Response.Usage.OutputTokens - usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens - } + usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int()) + usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int()) + usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int()) } func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { @@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin. func extractCodexFinalResponse(body string) ([]byte, bool) { lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } @@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { usage := &OpenAIUsage{} lines := strings.Split(body, "\n") for _, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + data, ok := extractOpenAISSEDataLine(line) + if !ok { continue } - data := openaiSSEDataRe.ReplaceAllString(line, "") if data == "" || data == "[DONE]" { continue } @@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage { func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { lines := strings.Split(body, "\n") for i, line := range lines { - if !openaiSSEDataRe.MatchString(line) { + if _, ok := extractOpenAISSEDataLine(line); !ok { continue } lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) @@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string { return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) } +func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) { + if len(body) == 0 { + return "", false, "" + } + + model = strings.TrimSpace(gjson.GetBytes(body, "model").String()) + stream = gjson.GetBytes(body, "stream").Bool() + promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String()) + return model, stream, promptCacheKey +} + +func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string { + reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String()) + if reasoningEffort == "" { + reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String()) + } + if reasoningEffort != "" { + normalized := normalizeOpenAIReasoningEffort(reasoningEffort) + if normalized == "" { + return nil + } + return &normalized + } + + value := deriveOpenAIReasoningEffortFromModel(requestedModel) + if value == "" { + return nil + } + return &value +} + +func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) { + if c != nil { + if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok { + if reqBody, ok := cached.(map[string]any); ok && reqBody != nil { + return reqBody, nil + } + } + } + + var reqBody map[string]any + if err := json.Unmarshal(body, &reqBody); err != nil { + return nil, fmt.Errorf("parse request: %w", err) + } + return reqBody, nil +} + func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value == "" { diff --git a/backend/internal/service/openai_gateway_service_hotpath_test.go b/backend/internal/service/openai_gateway_service_hotpath_test.go new file mode 100644 index 00000000..6b11831f --- /dev/null +++ b/backend/internal/service/openai_gateway_service_hotpath_test.go @@ -0,0 +1,125 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestExtractOpenAIRequestMetaFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + wantModel string + wantStream bool + wantPromptKey string + }{ + { + name: "完整字段", + body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`), + wantModel: "gpt-5", + wantStream: true, + wantPromptKey: "ses-1", + }, + { + name: "缺失可选字段", + body: []byte(`{"model":"gpt-4"}`), + wantModel: "gpt-4", + wantStream: false, + wantPromptKey: "", + }, + { + name: "空请求体", + body: nil, + wantModel: "", + wantStream: false, + wantPromptKey: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body) + require.Equal(t, tt.wantModel, model) + require.Equal(t, tt.wantStream, stream) + require.Equal(t, tt.wantPromptKey, promptKey) + }) + } +} + +func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) { + tests := []struct { + name string + body []byte + model string + wantNil bool + wantValue string + }{ + { + name: "优先读取 reasoning.effort", + body: []byte(`{"reasoning":{"effort":"medium"}}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "medium", + }, + { + name: "兼容 reasoning_effort", + body: []byte(`{"reasoning_effort":"x-high"}`), + model: "", + wantNil: false, + wantValue: "xhigh", + }, + { + name: "minimal 归一化为空", + body: []byte(`{"reasoning":{"effort":"minimal"}}`), + model: "gpt-5-high", + wantNil: true, + }, + { + name: "缺失字段时从模型后缀推导", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-high", + wantNil: false, + wantValue: "high", + }, + { + name: "未知后缀不返回", + body: []byte(`{"input":"hi"}`), + model: "gpt-5-unknown", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model) + if tt.wantNil { + require.Nil(t, got) + return + } + require.NotNil(t, got) + require.Equal(t, tt.wantValue, *got) + }) + } +} + +func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + cached := map[string]any{"model": "cached-model", "stream": true} + c.Set(OpenAIParsedRequestBodyKey, cached) + + got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`)) + require.NoError(t, err) + require.Equal(t, cached, got) +} + +func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) { + _, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`)) + require.Error(t, err) + require.Contains(t, err.Error(), "parse request") +} diff --git a/backend/internal/service/openai_gateway_service_test.go b/backend/internal/service/openai_gateway_service_test.go index 165c235c..226648e4 100644 --- a/backend/internal/service/openai_gateway_service_test.go +++ b/backend/internal/service/openai_gateway_service_test.go @@ -1416,3 +1416,109 @@ func TestReplaceModelInResponseBody(t *testing.T) { }) } } + +func TestExtractOpenAISSEDataLine(t *testing.T) { + tests := []struct { + name string + line string + wantData string + wantOK bool + }{ + {name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true}, + {name: "纯空数据", line: `data: `, wantData: ``, wantOK: true}, + {name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := extractOpenAISSEDataLine(tt.line) + require.Equal(t, tt.wantOK, ok) + require.Equal(t, tt.wantData, got) + }) + } +} + +func TestParseSSEUsage_SelectiveParsing(t *testing.T) { + svc := &OpenAIGatewayService{} + usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7} + + // 非 completed 事件,不应覆盖 usage + svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage) + require.Equal(t, 9, usage.InputTokens) + require.Equal(t, 8, usage.OutputTokens) + require.Equal(t, 7, usage.CacheReadInputTokens) + + // completed 事件,应提取 usage + svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage) + require.Equal(t, 3, usage.InputTokens) + require.Equal(t, 5, usage.OutputTokens) + require.Equal(t, 2, usage.CacheReadInputTokens) +} + +func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) { + body := strings.Join([]string{ + `event: message`, + `data: {"type":"response.in_progress","response":{"id":"resp_1"}}`, + `data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`, + `data: [DONE]`, + }, "\n") + + finalResp, ok := extractCodexFinalResponse(body) + require.True(t, ok) + require.Contains(t, string(finalResp), `"id":"resp_1"`) + require.Contains(t, string(finalResp), `"input_tokens":11`) +} + +func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_2"}}`, + `data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 7, usage.InputTokens) + require.Equal(t, 9, usage.OutputTokens) + require.Equal(t, 1, usage.CacheReadInputTokens) + // Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。 + require.NotContains(t, rec.Body.String(), "event:") + require.Contains(t, rec.Body.String(), `"id":"resp_2"`) + require.NotContains(t, rec.Body.String(), "data:") +} + +func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + c.Request = httptest.NewRequest(http.MethodPost, "/", nil) + + svc := &OpenAIGatewayService{cfg: &config.Config{}} + resp := &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"text/event-stream"}}, + } + body := []byte(strings.Join([]string{ + `data: {"type":"response.in_progress","response":{"id":"resp_3"}}`, + `data: [DONE]`, + }, "\n")) + + usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o") + require.NoError(t, err) + require.NotNil(t, usage) + require.Equal(t, 0, usage.InputTokens) + require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream") + require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`) +} diff --git a/backend/internal/service/openai_token_provider.go b/backend/internal/service/openai_token_provider.go index 026d9061..3842f0a4 100644 --- a/backend/internal/service/openai_token_provider.go +++ b/backend/internal/service/openai_token_provider.go @@ -4,16 +4,74 @@ import ( "context" "errors" "log/slog" + "math/rand/v2" "strings" + "sync/atomic" "time" ) const ( - openAITokenRefreshSkew = 3 * time.Minute - openAITokenCacheSkew = 5 * time.Minute - openAILockWaitTime = 200 * time.Millisecond + openAITokenRefreshSkew = 3 * time.Minute + openAITokenCacheSkew = 5 * time.Minute + openAILockInitialWait = 20 * time.Millisecond + openAILockMaxWait = 120 * time.Millisecond + openAILockMaxAttempts = 5 + openAILockJitterRatio = 0.2 + openAILockWarnThresholdMs = 250 ) +// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。 +type OpenAITokenRuntimeMetrics struct { + RefreshRequests int64 + RefreshSuccess int64 + RefreshFailure int64 + LockAcquireFailure int64 + LockContention int64 + LockWaitSamples int64 + LockWaitTotalMs int64 + LockWaitHit int64 + LockWaitMiss int64 + LastObservedUnixMs int64 +} + +type openAITokenRuntimeMetricsStore struct { + refreshRequests atomic.Int64 + refreshSuccess atomic.Int64 + refreshFailure atomic.Int64 + lockAcquireFailure atomic.Int64 + lockContention atomic.Int64 + lockWaitSamples atomic.Int64 + lockWaitTotalMs atomic.Int64 + lockWaitHit atomic.Int64 + lockWaitMiss atomic.Int64 + lastObservedUnixMs atomic.Int64 +} + +func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics { + if m == nil { + return OpenAITokenRuntimeMetrics{} + } + return OpenAITokenRuntimeMetrics{ + RefreshRequests: m.refreshRequests.Load(), + RefreshSuccess: m.refreshSuccess.Load(), + RefreshFailure: m.refreshFailure.Load(), + LockAcquireFailure: m.lockAcquireFailure.Load(), + LockContention: m.lockContention.Load(), + LockWaitSamples: m.lockWaitSamples.Load(), + LockWaitTotalMs: m.lockWaitTotalMs.Load(), + LockWaitHit: m.lockWaitHit.Load(), + LockWaitMiss: m.lockWaitMiss.Load(), + LastObservedUnixMs: m.lastObservedUnixMs.Load(), + } +} + +func (m *openAITokenRuntimeMetricsStore) touchNow() { + if m == nil { + return + } + m.lastObservedUnixMs.Store(time.Now().UnixMilli()) +} + // OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) type OpenAITokenCache = GeminiTokenCache @@ -22,6 +80,7 @@ type OpenAITokenProvider struct { accountRepo AccountRepository tokenCache OpenAITokenCache openAIOAuthService *OpenAIOAuthService + metrics *openAITokenRuntimeMetricsStore } func NewOpenAITokenProvider( @@ -33,11 +92,27 @@ func NewOpenAITokenProvider( accountRepo: accountRepo, tokenCache: tokenCache, openAIOAuthService: openAIOAuthService, + metrics: &openAITokenRuntimeMetricsStore{}, + } +} + +func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics { + if p == nil { + return OpenAITokenRuntimeMetrics{} + } + p.ensureMetrics() + return p.metrics.snapshot() +} + +func (p *OpenAITokenProvider) ensureMetrics() { + if p != nil && p.metrics == nil { + p.metrics = &openAITokenRuntimeMetricsStore{} } } // GetAccessToken 获取有效的 access_token func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { + p.ensureMetrics() if account == nil { return "", errors.New("account is nil") } @@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew refreshFailed := false if needsRefresh && p.tokenCache != nil { + p.metrics.refreshRequests.Add(1) + p.metrics.touchNow() locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) if lockErr == nil && locked { defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() @@ -82,14 +159,17 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 无法刷新,标记失败 } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true // 刷新失败,标记以使用短 TTL } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -106,6 +186,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } else if lockErr != nil { // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) + p.metrics.lockAcquireFailure.Add(1) + p.metrics.touchNow() slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) // 检查 ctx 是否已取消 @@ -126,13 +208,16 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if p.openAIOAuthService == nil { slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) if err != nil { slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) + p.metrics.refreshFailure.Add(1) refreshFailed = true } else { + p.metrics.refreshSuccess.Add(1) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) for k, v := range account.Credentials { if _, exists := newCredentials[k]; !exists { @@ -148,9 +233,14 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou } } } else { - // 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 - time.Sleep(openAILockWaitTime) - if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { + // 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。 + p.metrics.lockContention.Add(1) + p.metrics.touchNow() + token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey) + if waitErr != nil { + return "", waitErr + } + if strings.TrimSpace(token) != "" { slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) return token, nil } @@ -198,3 +288,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou return accessToken, nil } + +func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) { + wait := openAILockInitialWait + totalWaitMs := int64(0) + for i := 0; i < openAILockMaxAttempts; i++ { + actualWait := jitterLockWait(wait) + timer := time.NewTimer(actualWait) + select { + case <-ctx.Done(): + if !timer.Stop() { + select { + case <-timer.C: + default: + } + } + return "", ctx.Err() + case <-timer.C: + } + + waitMs := actualWait.Milliseconds() + if waitMs < 0 { + waitMs = 0 + } + totalWaitMs += waitMs + p.metrics.lockWaitSamples.Add(1) + p.metrics.lockWaitTotalMs.Add(waitMs) + p.metrics.touchNow() + + token, err := p.tokenCache.GetAccessToken(ctx, cacheKey) + if err == nil && strings.TrimSpace(token) != "" { + p.metrics.lockWaitHit.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1) + } + return token, nil + } + + if wait < openAILockMaxWait { + wait *= 2 + if wait > openAILockMaxWait { + wait = openAILockMaxWait + } + } + } + + p.metrics.lockWaitMiss.Add(1) + if totalWaitMs >= openAILockWarnThresholdMs { + slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts) + } + return "", nil +} + +func jitterLockWait(base time.Duration) time.Duration { + if base <= 0 { + return 0 + } + minFactor := 1 - openAILockJitterRatio + maxFactor := 1 + openAILockJitterRatio + factor := minFactor + rand.Float64()*(maxFactor-minFactor) + return time.Duration(float64(base) * factor) +} diff --git a/backend/internal/service/openai_token_provider_test.go b/backend/internal/service/openai_token_provider_test.go index 3c649a7e..1cd92367 100644 --- a/backend/internal/service/openai_token_provider_test.go +++ b/backend/internal/service/openai_token_provider_test.go @@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) { require.Contains(t, err.Error(), "access_token not found") require.Empty(t, token) } + +func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 207, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(5 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) +} + +func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false // 模拟锁被其他 worker 持有 + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 208, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + provider := NewOpenAITokenProvider(nil, cache, nil) + start := time.Now() + token, err := provider.GetAccessToken(ctx, account) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) + require.Empty(t, token) + require.Less(t, time.Since(start), 50*time.Millisecond) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockAcquired = false + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 209, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + cacheKey := OpenAITokenCacheKey(account) + go func() { + time.Sleep(10 * time.Millisecond) + cache.mu.Lock() + cache.tokens[cacheKey] = "winner-token" + cache.mu.Unlock() + }() + + provider := NewOpenAITokenProvider(nil, cache, nil) + token, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + require.Equal(t, "winner-token", token) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) + require.GreaterOrEqual(t, metrics.LockContention, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1)) + require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0)) + require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1)) +} + +func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) { + cache := newOpenAITokenCacheStub() + cache.lockErr = errors.New("redis lock error") + + expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339) + account := &Account{ + ID: 210, + Platform: PlatformOpenAI, + Type: AccountTypeOAuth, + Credentials: map[string]any{ + "access_token": "fallback-token", + "expires_at": expiresAt, + }, + } + + provider := NewOpenAITokenProvider(nil, cache, nil) + _, err := provider.GetAccessToken(context.Background(), account) + require.NoError(t, err) + + metrics := provider.SnapshotRuntimeMetrics() + require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1)) + require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1)) +} diff --git a/backend/internal/service/ops_port.go b/backend/internal/service/ops_port.go index 347b06b5..bbef4ceb 100644 --- a/backend/internal/service/ops_port.go +++ b/backend/internal/service/ops_port.go @@ -98,6 +98,10 @@ type OpsInsertErrorLogInput struct { // It is set by OpsService.RecordError before persisting. UpstreamErrorsJSON *string + AuthLatencyMs *int64 + RoutingLatencyMs *int64 + UpstreamLatencyMs *int64 + ResponseLatencyMs *int64 TimeToFirstTokenMs *int64 RequestBodyJSON *string // sanitized json string (not raw bytes) diff --git a/backend/internal/service/ops_upstream_context.go b/backend/internal/service/ops_upstream_context.go index 979b57cd..d33730b7 100644 --- a/backend/internal/service/ops_upstream_context.go +++ b/backend/internal/service/ops_upstream_context.go @@ -20,8 +20,30 @@ const ( // retry the specific upstream attempt (not just the client request). // This value is sanitized+trimmed before being persisted. OpsUpstreamRequestBodyKey = "ops_upstream_request_body" + + // Optional stage latencies (milliseconds) for troubleshooting and alerting. + OpsAuthLatencyMsKey = "ops_auth_latency_ms" + OpsRoutingLatencyMsKey = "ops_routing_latency_ms" + OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms" + OpsResponseLatencyMsKey = "ops_response_latency_ms" + OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms" ) +func setOpsUpstreamRequestBody(c *gin.Context, body []byte) { + if c == nil || len(body) == 0 { + return + } + // 热路径避免 string(body) 额外分配,按需在落库前再转换。 + c.Set(OpsUpstreamRequestBodyKey, body) +} + +func SetOpsLatencyMs(c *gin.Context, key string, value int64) { + if c == nil || strings.TrimSpace(key) == "" || value < 0 { + return + } + c.Set(key, value) +} + func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { if c == nil { return @@ -91,8 +113,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) { // stored it on the context, attach it so ops can retry this specific attempt. if ev.UpstreamRequestBody == "" { if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok { - if s, ok := v.(string); ok { - ev.UpstreamRequestBody = strings.TrimSpace(s) + switch raw := v.(type) { + case string: + ev.UpstreamRequestBody = strings.TrimSpace(raw) + case []byte: + ev.UpstreamRequestBody = strings.TrimSpace(string(raw)) } } } diff --git a/backend/internal/service/ops_upstream_context_test.go b/backend/internal/service/ops_upstream_context_test.go new file mode 100644 index 00000000..50ceaa0e --- /dev/null +++ b/backend/internal/service/ops_upstream_context_test.go @@ -0,0 +1,47 @@ +package service + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "github.com/stretchr/testify/require" +) + +func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`)) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "http_error", + Message: "upstream failed", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody) +} + +func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) { + gin.SetMode(gin.TestMode) + rec := httptest.NewRecorder() + c, _ := gin.CreateTestContext(rec) + + c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`) + appendOpsUpstreamError(c, OpsUpstreamErrorEvent{ + Kind: "request_error", + Message: "dial timeout", + }) + + v, ok := c.Get(OpsUpstreamErrorsKey) + require.True(t, ok) + events, ok := v.([]*OpsUpstreamErrorEvent) + require.True(t, ok) + require.Len(t, events, 1) + require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody) +} diff --git a/tools/perf/openai_oauth_gray_drill.py b/tools/perf/openai_oauth_gray_drill.py new file mode 100755 index 00000000..0daa3f08 --- /dev/null +++ b/tools/perf/openai_oauth_gray_drill.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +"""OpenAI OAuth 灰度发布演练脚本(本地模拟)。 + +该脚本会启动本地 mock Ops API,调用 openai_oauth_gray_guard.py, +验证以下场景: +1) A/B/C/D 四个灰度批次均通过 +2) 注入异常场景触发阈值告警并返回退出码 2(模拟自动回滚触发) +""" + +from __future__ import annotations + +import json +import subprocess +import threading +from dataclasses import dataclass +from http.server import BaseHTTPRequestHandler, HTTPServer +from pathlib import Path +from typing import Dict, Tuple +from urllib.parse import parse_qs, urlparse + +ROOT = Path(__file__).resolve().parents[2] +GUARD_SCRIPT = ROOT / "tools" / "perf" / "openai_oauth_gray_guard.py" +REPORT_PATH = ROOT / "docs" / "perf" / "openai-oauth-gray-drill-report.md" + + +THRESHOLDS = { + "sla_percent_min": 99.5, + "ttft_p99_ms_max": 900, + "request_error_rate_percent_max": 2.0, + "upstream_error_rate_percent_max": 2.0, +} + +STAGE_SNAPSHOTS: Dict[str, Dict[str, float]] = { + "A": {"sla": 99.78, "ttft": 780, "error_rate": 1.20, "upstream_error_rate": 1.05}, + "B": {"sla": 99.82, "ttft": 730, "error_rate": 1.05, "upstream_error_rate": 0.92}, + "C": {"sla": 99.86, "ttft": 680, "error_rate": 0.88, "upstream_error_rate": 0.80}, + "D": {"sla": 99.89, "ttft": 640, "error_rate": 0.72, "upstream_error_rate": 0.67}, + "rollback": {"sla": 97.10, "ttft": 1550, "error_rate": 6.30, "upstream_error_rate": 5.60}, +} + + +class _MockHandler(BaseHTTPRequestHandler): + def _write_json(self, payload: dict) -> None: + raw = json.dumps(payload, ensure_ascii=False).encode("utf-8") + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(raw))) + self.end_headers() + self.wfile.write(raw) + + def log_message(self, format: str, *args): # noqa: A003 + return + + def do_GET(self): # noqa: N802 + parsed = urlparse(self.path) + if parsed.path.endswith("/api/v1/admin/ops/settings/metric-thresholds"): + self._write_json({"code": 0, "message": "success", "data": THRESHOLDS}) + return + + if parsed.path.endswith("/api/v1/admin/ops/dashboard/overview"): + q = parse_qs(parsed.query) + stage = (q.get("group_id") or ["A"])[0] + snapshot = STAGE_SNAPSHOTS.get(stage, STAGE_SNAPSHOTS["A"]) + self._write_json( + { + "code": 0, + "message": "success", + "data": { + "sla": snapshot["sla"], + "error_rate": snapshot["error_rate"], + "upstream_error_rate": snapshot["upstream_error_rate"], + "ttft": {"p99_ms": snapshot["ttft"]}, + }, + } + ) + return + + self.send_response(404) + self.end_headers() + + +def run_guard(base_url: str, stage: str) -> Tuple[int, str]: + cmd = [ + "python", + str(GUARD_SCRIPT), + "--base-url", + base_url, + "--platform", + "openai", + "--time-range", + "30m", + "--group-id", + stage, + ] + proc = subprocess.run(cmd, cwd=str(ROOT), capture_output=True, text=True) + output = (proc.stdout + "\n" + proc.stderr).strip() + return proc.returncode, output + + +def main() -> int: + server = HTTPServer(("127.0.0.1", 0), _MockHandler) + host, port = server.server_address + base_url = f"http://{host}:{port}" + + thread = threading.Thread(target=server.serve_forever, daemon=True) + thread.start() + + lines = [ + "# OpenAI OAuth 灰度守护演练报告", + "", + "> 类型:本地 mock 演练(用于验证灰度守护与回滚触发机制)", + f"> 生成脚本:`tools/perf/openai_oauth_gray_drill.py`", + "", + "## 1. 灰度批次结果(6.1)", + "", + "| 批次 | 流量比例 | 守护脚本退出码 | 结果 |", + "|---|---:|---:|---|", + ] + + batch_plan = [("A", "5%"), ("B", "20%"), ("C", "50%"), ("D", "100%")] + all_pass = True + for stage, ratio in batch_plan: + code, _ = run_guard(base_url, stage) + ok = code == 0 + all_pass = all_pass and ok + lines.append(f"| {stage} | {ratio} | {code} | {'通过' if ok else '失败'} |") + + lines.extend([ + "", + "## 2. 回滚触发演练(6.2)", + "", + ]) + + rollback_code, rollback_output = run_guard(base_url, "rollback") + rollback_triggered = rollback_code == 2 + lines.append(f"- 注入异常场景退出码:`{rollback_code}`") + lines.append(f"- 是否触发回滚条件:`{'是' if rollback_triggered else '否'}`") + lines.append("- 关键信息摘录:") + excerpt = "\n".join(rollback_output.splitlines()[:8]) + lines.append("```text") + lines.append(excerpt) + lines.append("```") + + lines.extend([ + "", + "## 3. 验收结论(6.3)", + "", + f"- 批次灰度结果:`{'通过' if all_pass else '不通过'}`", + f"- 回滚触发机制:`{'通过' if rollback_triggered else '不通过'}`", + f"- 结论:`{'通过(可进入真实环境灰度)' if all_pass and rollback_triggered else '不通过(需修复后复测)'}`", + ]) + + REPORT_PATH.parent.mkdir(parents=True, exist_ok=True) + REPORT_PATH.write_text("\n".join(lines) + "\n", encoding="utf-8") + + server.shutdown() + server.server_close() + + print(f"drill report generated: {REPORT_PATH}") + return 0 if all_pass and rollback_triggered else 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/perf/openai_oauth_gray_guard.py b/tools/perf/openai_oauth_gray_guard.py new file mode 100755 index 00000000..a71a9ad2 --- /dev/null +++ b/tools/perf/openai_oauth_gray_guard.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +"""OpenAI OAuth 灰度阈值守护脚本。 + +用途: +- 拉取 Ops 指标阈值配置与 Dashboard Overview 实时数据 +- 对比 P99 TTFT / 错误率 / SLA +- 作为 6.2 灰度守护的自动化门禁(退出码可直接用于 CI/CD) + +退出码: +- 0: 指标通过 +- 1: 请求失败/参数错误 +- 2: 指标超阈值(建议停止扩量并回滚) +""" + +from __future__ import annotations + +import argparse +import json +import sys +import urllib.error +import urllib.parse +import urllib.request +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + + +@dataclass +class GuardThresholds: + sla_percent_min: Optional[float] + ttft_p99_ms_max: Optional[float] + request_error_rate_percent_max: Optional[float] + upstream_error_rate_percent_max: Optional[float] + + +@dataclass +class GuardSnapshot: + sla: Optional[float] + ttft_p99_ms: Optional[float] + request_error_rate_percent: Optional[float] + upstream_error_rate_percent: Optional[float] + + +def build_headers(token: str) -> Dict[str, str]: + headers = {"Accept": "application/json"} + if token.strip(): + headers["Authorization"] = f"Bearer {token.strip()}" + return headers + + +def request_json(url: str, headers: Dict[str, str]) -> Dict[str, Any]: + req = urllib.request.Request(url=url, method="GET", headers=headers) + try: + with urllib.request.urlopen(req, timeout=15) as resp: + raw = resp.read().decode("utf-8") + return json.loads(raw) + except urllib.error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + raise RuntimeError(f"HTTP {e.code}: {body}") from e + except urllib.error.URLError as e: + raise RuntimeError(f"request failed: {e}") from e + + +def parse_envelope_data(payload: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(payload, dict): + raise RuntimeError("invalid response payload") + if payload.get("code") != 0: + raise RuntimeError(f"api error: code={payload.get('code')} message={payload.get('message')}") + data = payload.get("data") + if not isinstance(data, dict): + raise RuntimeError("invalid response data") + return data + + +def parse_thresholds(data: Dict[str, Any]) -> GuardThresholds: + return GuardThresholds( + sla_percent_min=to_float_or_none(data.get("sla_percent_min")), + ttft_p99_ms_max=to_float_or_none(data.get("ttft_p99_ms_max")), + request_error_rate_percent_max=to_float_or_none(data.get("request_error_rate_percent_max")), + upstream_error_rate_percent_max=to_float_or_none(data.get("upstream_error_rate_percent_max")), + ) + + +def parse_snapshot(data: Dict[str, Any]) -> GuardSnapshot: + ttft = data.get("ttft") if isinstance(data.get("ttft"), dict) else {} + return GuardSnapshot( + sla=to_float_or_none(data.get("sla")), + ttft_p99_ms=to_float_or_none(ttft.get("p99_ms")), + request_error_rate_percent=to_float_or_none(data.get("error_rate")), + upstream_error_rate_percent=to_float_or_none(data.get("upstream_error_rate")), + ) + + +def to_float_or_none(v: Any) -> Optional[float]: + if v is None: + return None + try: + return float(v) + except (TypeError, ValueError): + return None + + +def evaluate(snapshot: GuardSnapshot, thresholds: GuardThresholds) -> List[str]: + violations: List[str] = [] + + if thresholds.sla_percent_min is not None and snapshot.sla is not None: + if snapshot.sla < thresholds.sla_percent_min: + violations.append( + f"SLA 低于阈值: actual={snapshot.sla:.2f}% threshold={thresholds.sla_percent_min:.2f}%" + ) + + if thresholds.ttft_p99_ms_max is not None and snapshot.ttft_p99_ms is not None: + if snapshot.ttft_p99_ms > thresholds.ttft_p99_ms_max: + violations.append( + f"TTFT P99 超阈值: actual={snapshot.ttft_p99_ms:.2f}ms threshold={thresholds.ttft_p99_ms_max:.2f}ms" + ) + + if ( + thresholds.request_error_rate_percent_max is not None + and snapshot.request_error_rate_percent is not None + and snapshot.request_error_rate_percent > thresholds.request_error_rate_percent_max + ): + violations.append( + "请求错误率超阈值: " + f"actual={snapshot.request_error_rate_percent:.2f}% " + f"threshold={thresholds.request_error_rate_percent_max:.2f}%" + ) + + if ( + thresholds.upstream_error_rate_percent_max is not None + and snapshot.upstream_error_rate_percent is not None + and snapshot.upstream_error_rate_percent > thresholds.upstream_error_rate_percent_max + ): + violations.append( + "上游错误率超阈值: " + f"actual={snapshot.upstream_error_rate_percent:.2f}% " + f"threshold={thresholds.upstream_error_rate_percent_max:.2f}%" + ) + + return violations + + +def main() -> int: + parser = argparse.ArgumentParser(description="OpenAI OAuth 灰度阈值守护") + parser.add_argument("--base-url", required=True, help="服务地址,例如 http://127.0.0.1:5231") + parser.add_argument("--admin-token", default="", help="Admin JWT(可选,按部署策略)") + parser.add_argument("--platform", default="openai", help="平台过滤,默认 openai") + parser.add_argument("--time-range", default="30m", help="时间窗口: 5m/30m/1h/6h/24h/7d/30d") + parser.add_argument("--group-id", default="", help="可选 group_id") + args = parser.parse_args() + + base = args.base_url.rstrip("/") + headers = build_headers(args.admin_token) + + try: + threshold_url = f"{base}/api/v1/admin/ops/settings/metric-thresholds" + thresholds_raw = request_json(threshold_url, headers) + thresholds = parse_thresholds(parse_envelope_data(thresholds_raw)) + + query = {"platform": args.platform, "time_range": args.time_range} + if args.group_id.strip(): + query["group_id"] = args.group_id.strip() + overview_url = ( + f"{base}/api/v1/admin/ops/dashboard/overview?" + + urllib.parse.urlencode(query) + ) + overview_raw = request_json(overview_url, headers) + snapshot = parse_snapshot(parse_envelope_data(overview_raw)) + + print("[OpenAI OAuth Gray Guard] 当前快照:") + print( + json.dumps( + { + "sla": snapshot.sla, + "ttft_p99_ms": snapshot.ttft_p99_ms, + "request_error_rate_percent": snapshot.request_error_rate_percent, + "upstream_error_rate_percent": snapshot.upstream_error_rate_percent, + }, + ensure_ascii=False, + indent=2, + ) + ) + print("[OpenAI OAuth Gray Guard] 阈值配置:") + print( + json.dumps( + { + "sla_percent_min": thresholds.sla_percent_min, + "ttft_p99_ms_max": thresholds.ttft_p99_ms_max, + "request_error_rate_percent_max": thresholds.request_error_rate_percent_max, + "upstream_error_rate_percent_max": thresholds.upstream_error_rate_percent_max, + }, + ensure_ascii=False, + indent=2, + ) + ) + + violations = evaluate(snapshot, thresholds) + if violations: + print("[OpenAI OAuth Gray Guard] 检测到阈值违例:") + for idx, line in enumerate(violations, start=1): + print(f" {idx}. {line}") + print("[OpenAI OAuth Gray Guard] 建议:停止扩量并执行回滚。") + return 2 + + print("[OpenAI OAuth Gray Guard] 指标通过,可继续观察或按计划扩量。") + return 0 + + except Exception as exc: + print(f"[OpenAI OAuth Gray Guard] 执行失败: {exc}", file=sys.stderr) + return 1 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tools/perf/openai_oauth_responses_k6.js b/tools/perf/openai_oauth_responses_k6.js new file mode 100644 index 00000000..30e8ac04 --- /dev/null +++ b/tools/perf/openai_oauth_responses_k6.js @@ -0,0 +1,122 @@ +import http from 'k6/http'; +import { check } from 'k6'; +import { Rate, Trend } from 'k6/metrics'; + +const baseURL = __ENV.BASE_URL || 'http://127.0.0.1:5231'; +const apiKey = __ENV.API_KEY || ''; +const model = __ENV.MODEL || 'gpt-5'; +const timeout = __ENV.TIMEOUT || '180s'; + +const nonStreamRPS = Number(__ENV.NON_STREAM_RPS || 8); +const streamRPS = Number(__ENV.STREAM_RPS || 4); +const duration = __ENV.DURATION || '3m'; +const preAllocatedVUs = Number(__ENV.PRE_ALLOCATED_VUS || 30); +const maxVUs = Number(__ENV.MAX_VUS || 200); + +const reqDurationMs = new Trend('openai_oauth_req_duration_ms', true); +const ttftMs = new Trend('openai_oauth_ttft_ms', true); +const non2xxRate = new Rate('openai_oauth_non2xx_rate'); +const streamDoneRate = new Rate('openai_oauth_stream_done_rate'); + +export const options = { + scenarios: { + non_stream: { + executor: 'constant-arrival-rate', + rate: nonStreamRPS, + timeUnit: '1s', + duration, + preAllocatedVUs, + maxVUs, + exec: 'runNonStream', + tags: { request_type: 'non_stream' }, + }, + stream: { + executor: 'constant-arrival-rate', + rate: streamRPS, + timeUnit: '1s', + duration, + preAllocatedVUs, + maxVUs, + exec: 'runStream', + tags: { request_type: 'stream' }, + }, + }, + thresholds: { + openai_oauth_non2xx_rate: ['rate<0.01'], + openai_oauth_req_duration_ms: ['p(95)<3000', 'p(99)<6000'], + openai_oauth_ttft_ms: ['p(99)<1200'], + openai_oauth_stream_done_rate: ['rate>0.99'], + }, +}; + +function buildHeaders() { + const headers = { + 'Content-Type': 'application/json', + 'User-Agent': 'codex_cli_rs/0.1.0', + }; + if (apiKey) { + headers.Authorization = `Bearer ${apiKey}`; + } + return headers; +} + +function buildBody(stream) { + return JSON.stringify({ + model, + stream, + input: [ + { + role: 'user', + content: [ + { + type: 'input_text', + text: '请返回一句极短的话:pong', + }, + ], + }, + ], + max_output_tokens: 32, + }); +} + +function recordMetrics(res, stream) { + reqDurationMs.add(res.timings.duration, { request_type: stream ? 'stream' : 'non_stream' }); + ttftMs.add(res.timings.waiting, { request_type: stream ? 'stream' : 'non_stream' }); + non2xxRate.add(res.status < 200 || res.status >= 300, { request_type: stream ? 'stream' : 'non_stream' }); + + if (stream) { + const done = !!res.body && res.body.indexOf('[DONE]') >= 0; + streamDoneRate.add(done, { request_type: 'stream' }); + } +} + +function postResponses(stream) { + const url = `${baseURL}/v1/responses`; + const res = http.post(url, buildBody(stream), { + headers: buildHeaders(), + timeout, + tags: { endpoint: '/v1/responses', request_type: stream ? 'stream' : 'non_stream' }, + }); + + check(res, { + 'status is 2xx': (r) => r.status >= 200 && r.status < 300, + }); + + recordMetrics(res, stream); + return res; +} + +export function runNonStream() { + postResponses(false); +} + +export function runStream() { + postResponses(true); +} + +export function handleSummary(data) { + return { + stdout: `\nOpenAI OAuth /v1/responses 基线完成\n${JSON.stringify(data.metrics, null, 2)}\n`, + 'docs/perf/openai-oauth-k6-summary.json': JSON.stringify(data, null, 2), + }; +}