feat(openai): 极致优化 OAuth 链路并补齐性能守护

- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝
- 优化并发与 token 竞争路径并补齐运行指标
- 补充 OpenAI/Ops 相关单元测试与回归用例
- 新增灰度阈值守护与压测脚本,支撑发布验收
This commit is contained in:
yangjianbo
2026-02-12 09:41:37 +08:00
parent a88bb8684f
commit 61a2bf469a
16 changed files with 1519 additions and 135 deletions

View File

@@ -104,31 +104,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation. // wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。 // 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露 // 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() { func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil { if releaseFunc == nil {
return nil return nil
} }
var once sync.Once var once sync.Once
quit := make(chan struct{}) var stop func() bool
release := func() { release := func() {
once.Do(func() { once.Do(func() {
if stop != nil {
_ = stop()
}
releaseFunc() releaseFunc()
close(quit) // 通知监听 goroutine 退出
}) })
} }
go func() { stop = context.AfterFunc(ctx, release)
select {
case <-ctx.Done():
// Context 取消时释放资源
release()
case <-quit:
// 正常释放已完成goroutine 退出
return
}
}()
return release return release
} }
@@ -153,6 +146,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID) h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
} }
// TryAcquireUserSlot attempts a non-blocking grab of a user concurrency slot.
// Returns (releaseFunc, acquired, error); releaseFunc is non-nil only when
// acquired is true.
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
	result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
	switch {
	case err != nil:
		return nil, false, err
	case result.Acquired:
		return result.ReleaseFunc, true, nil
	default:
		return nil, false, nil
	}
}
// TryAcquireAccountSlot attempts a non-blocking grab of an account concurrency
// slot. Returns (releaseFunc, acquired, error); releaseFunc is non-nil only
// when acquired is true.
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
	result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
	switch {
	case err != nil:
		return nil, false, err
	case result.Acquired:
		return result.ReleaseFunc, true, nil
	default:
		return nil, false, nil
	}
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait. // For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun. // streamStarted is updated if streaming response has begun.
@@ -160,13 +179,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency) releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if result.Acquired { if acquired {
return result.ReleaseFunc, nil return releaseFunc, nil
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed
@@ -180,13 +199,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
ctx := c.Request.Context() ctx := c.Request.Context()
// Try to acquire immediately // Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency) releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if result.Acquired { if acquired {
return result.ReleaseFunc, nil return releaseFunc, nil
} }
// Need to wait - handle streaming ping if needed // Need to wait - handle streaming ping if needed

View File

@@ -0,0 +1,114 @@
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
// concurrencyCacheMock is a test double for the concurrency cache interface
// consumed by service.NewConcurrencyService. Acquire outcomes are injected via
// the *Fn fields; release invocations are counted with atomics so tests can
// assert on them.
type concurrencyCacheMock struct {
	acquireUserSlotFn    func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
	acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
	// releaseUserCalled counts ReleaseUserSlot invocations (accessed atomically).
	releaseUserCalled int32
	// releaseAccountCalled counts ReleaseAccountSlot invocations (accessed atomically).
	releaseAccountCalled int32
}

// AcquireAccountSlot delegates to acquireAccountSlotFn when set; otherwise it
// reports the slot as not acquired.
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
	if m.acquireAccountSlotFn != nil {
		return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
	}
	return false, nil
}

// ReleaseAccountSlot records the release and always succeeds.
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
	atomic.AddInt32(&m.releaseAccountCalled, 1)
	return nil
}

// GetAccountConcurrency reports zero in-flight requests for any account.
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}

// IncrementAccountWaitCount always admits the caller to the wait queue.
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	return true, nil
}

// DecrementAccountWaitCount is a no-op.
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
	return nil
}

// GetAccountWaitingCount reports an empty wait queue.
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	return 0, nil
}

// AcquireUserSlot delegates to acquireUserSlotFn when set; otherwise it
// reports the slot as not acquired.
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
	if m.acquireUserSlotFn != nil {
		return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
	}
	return false, nil
}

// ReleaseUserSlot records the release and always succeeds.
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
	atomic.AddInt32(&m.releaseUserCalled, 1)
	return nil
}

// GetUserConcurrency reports zero in-flight requests for any user.
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
	return 0, nil
}

// IncrementWaitCount always admits the caller to the user wait queue.
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
	return true, nil
}

// DecrementWaitCount is a no-op.
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
	return nil
}

// GetAccountsLoadBatch returns an empty load map for any account set.
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
	return map[int64]*service.AccountLoadInfo{}, nil
}

// GetUsersLoadBatch returns an empty load map for any user set.
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
	return map[int64]*service.UserLoadInfo{}, nil
}

// CleanupExpiredAccountSlots is a no-op.
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
	return nil
}
// TestConcurrencyHelper_TryAcquireUserSlot covers the fast-path acquire: when
// the cache grants the slot, the helper reports acquired=true and the returned
// release func propagates to ReleaseUserSlot exactly once.
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
	cache := &concurrencyCacheMock{
		acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
			return true, nil
		},
	}
	helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
	release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
	require.NoError(t, err)
	require.True(t, acquired)
	require.NotNil(t, release)
	release()
	// Exactly one release must reach the cache layer.
	require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
}
// TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired covers the contended
// path: when the cache declines the slot, the helper reports acquired=false
// with a nil release func and no release call reaches the cache.
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
	cache := &concurrencyCacheMock{
		acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
			return false, nil
		},
	}
	helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
	release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
	require.NoError(t, err)
	require.False(t, acquired)
	require.Nil(t, release)
	require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
}

View File

@@ -64,6 +64,8 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint // Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses // POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) { func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware) // Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c) apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok { if !ok {
@@ -141,6 +143,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() { if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
var reqBody map[string]any var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err == nil { if err := json.Unmarshal(body, &reqBody); err == nil {
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
if service.HasFunctionCallOutput(reqBody) { if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string) previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) { if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
@@ -171,18 +174,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get subscription info (may be nil) // Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c) subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
maxWait := service.CalculateMaxWait(subject.Concurrency) routingStart := time.Now()
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false // 0. 先尝试直接抢占用户槽位(快速路径)
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
if err != nil { if err != nil {
log.Printf("Increment wait count failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
// On error, allow request to proceed h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
waitCounted := false
if !userAcquired {
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if waitErr != nil {
log.Printf("Increment wait count failed: %v", waitErr)
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait { } else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later") h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return return
} }
if err == nil && canWait { if waitErr == nil && canWait {
waitCounted = true waitCounted = true
} }
defer func() { defer func() {
@@ -191,14 +206,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
}() }()
// 1. First acquire user concurrency slot userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil { if err != nil {
log.Printf("User concurrency acquire failed: %v", err) log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted) h.handleConcurrencyError(c, err, "user", streamStarted)
return return
} }
// User slot acquired: no longer waiting. }
// 用户槽位已获取:退出等待队列计数。
if waitCounted { if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID) h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false waitCounted = false
@@ -253,6 +269,24 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted) h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return return
} }
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
c.Request.Context(),
account.ID,
selection.WaitPlan.MaxConcurrency,
)
if err != nil {
log.Printf("Account concurrency quick acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
if fastAcquired {
accountReleaseFunc = fastReleaseFunc
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} else {
accountWaitCounted := false accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting) canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil { if err != nil {
@@ -292,14 +326,27 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
log.Printf("Bind sticky session failed: %v", err) log.Printf("Bind sticky session failed: %v", err)
} }
} }
}
// 账号槽位/等待计数需要在超时或断开时安全回收 // 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc) accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request // Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body) result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil { if accountReleaseFunc != nil {
accountReleaseFunc() accountReleaseFunc()
} }
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil { if err != nil {
var failoverErr *service.UpstreamFailoverError var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) { if errors.As(err, &failoverErr) {
@@ -343,6 +390,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
} }
} }
// getContextInt64 reads a value stored in the gin context under key and
// coerces the common numeric types (int64, int, int32, float64) to int64.
// The bool reports whether a usable value was found.
func getContextInt64(c *gin.Context, key string) (int64, bool) {
	if c == nil || key == "" {
		return 0, false
	}
	raw, exists := c.Get(key)
	if !exists {
		return 0, false
	}
	switch n := raw.(type) {
	case int64:
		return n, true
	case int:
		return int64(n), true
	case int32:
		return int64(n), true
	case float64:
		return int64(n), true
	}
	return 0, false
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response // handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) { func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",

View File

@@ -507,6 +507,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0, RetryCount: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
applyOpsLatencyFieldsFromContext(c, entry)
if apiKey != nil { if apiKey != nil {
entry.APIKeyID = &apiKey.ID entry.APIKeyID = &apiKey.ID
@@ -618,6 +619,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0, RetryCount: 0,
CreatedAt: time.Now(), CreatedAt: time.Now(),
} }
applyOpsLatencyFieldsFromContext(c, entry)
// Capture upstream error context set by gateway services (if present). // Capture upstream error context set by gateway services (if present).
// This does NOT affect the client response; it enriches Ops troubleshooting data. // This does NOT affect the client response; it enriches Ops troubleshooting data.
@@ -746,6 +748,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
return &s return &s
} }
// applyOpsLatencyFieldsFromContext copies the per-phase latency values stashed
// in the gin context onto the ops error-log entry. A key that is missing or
// holds a non-numeric/negative value leaves the corresponding field nil
// (see getContextLatencyMs).
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
	if c == nil || entry == nil {
		return
	}
	entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
	entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
	entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
	entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
	entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
}
// getContextLatencyMs reads a latency value stored in the gin context under
// key and returns it as *int64. It returns nil when the key is blank or
// absent, the stored value is not a supported numeric type, or the value is
// negative (invalid latency).
func getContextLatencyMs(c *gin.Context, key string) *int64 {
	if c == nil || strings.TrimSpace(key) == "" {
		return nil
	}
	stored, found := c.Get(key)
	if !found {
		return nil
	}
	var ms int64
	valid := true
	switch n := stored.(type) {
	case int64:
		ms = n
	case int:
		ms = int64(n)
	case int32:
		ms = int64(n)
	case float64:
		ms = int64(n)
	default:
		valid = false
	}
	if !valid || ms < 0 {
		return nil
	}
	return &ms
}
type parsedOpsError struct { type parsedOpsError struct {
ErrorType string ErrorType string
Message string Message string

View File

@@ -55,6 +55,10 @@ INSERT INTO ops_error_logs (
upstream_error_message, upstream_error_message,
upstream_error_detail, upstream_error_detail,
upstream_errors, upstream_errors,
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms, time_to_first_token_ms,
request_body, request_body,
request_body_truncated, request_body_truncated,
@@ -64,7 +68,7 @@ INSERT INTO ops_error_logs (
retry_count, retry_count,
created_at created_at
) VALUES ( ) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34 $1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id` ) RETURNING id`
var id int64 var id int64
@@ -97,6 +101,10 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage), opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail), opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON), opsNullString(input.UpstreamErrorsJSON),
opsNullInt64(input.AuthLatencyMs),
opsNullInt64(input.RoutingLatencyMs),
opsNullInt64(input.UpstreamLatencyMs),
opsNullInt64(input.ResponseLatencyMs),
opsNullInt64(input.TimeToFirstTokenMs), opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON), opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated, input.RequestBodyTruncated,

View File

@@ -12,7 +12,6 @@ import (
"io" "io"
"log" "log"
"net/http" "net/http"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -34,11 +33,10 @@ const (
// OpenAI Platform API for API Key accounts (fallback) // OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses" openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL openaiStickySessionTTL = time.Hour // 粘性会话TTL
)
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon. // OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
// Some upstream APIs return non-standard "data:" without space (should be "data: "). OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`) )
// OpenAI allowed headers whitelist (for non-OAuth accounts) // OpenAI allowed headers whitelist (for non-OAuth accounts)
var openaiAllowedHeaders = map[string]bool{ var openaiAllowedHeaders = map[string]bool{
@@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
startTime := time.Now() startTime := time.Now()
originalBody := body originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
// Parse request body once (avoid multiple parse/serialize cycles) isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
var reqBody map[string]any passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if err := json.Unmarshal(body, &reqBody); err != nil { if passthroughEnabled {
return nil, fmt.Errorf("parse request: %w", err) // 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
} }
// Extract model and stream from parsed body reqBody, err := getOpenAIRequestBodyMap(c, body)
reqModel, _ := reqBody["model"].(string) if err != nil {
reqStream, _ := reqBody["stream"].(bool) return nil, err
promptCacheKey := "" }
if v, ok := reqBody["model"].(string); ok {
reqModel = v
originalModel = reqModel
}
if v, ok := reqBody["stream"].(bool); ok {
reqStream = v
}
if promptCacheKey == "" {
if v, ok := reqBody["prompt_cache_key"].(string); ok { if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v) promptCacheKey = strings.TrimSpace(v)
} }
}
// Track if body needs re-serialization // Track if body needs re-serialization
bodyModified := false bodyModified := false
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if passthroughEnabled {
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// 对所有请求执行模型映射(包含 Codex CLI // 对所有请求执行模型映射(包含 Codex CLI
mappedModel := account.GetMappedModel(reqModel) mappedModel := account.GetMappedModel(reqModel)
@@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
} }
// Capture upstream request body for ops retry of this attempt. // Capture upstream request body for ops retry of this attempt.
if c != nil { setOpsUpstreamRequestBody(c, body)
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
// Send request // Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil { if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors). // Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error()) safeErr := sanitizeUpstreamErrorMessage(err.Error())
@@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough(
proxyURL = account.Proxy.URL() proxyURL = account.Proxy.URL()
} }
setOpsUpstreamRequestBody(c, body)
if c != nil { if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
c.Set("openai_passthrough", true) c.Set("openai_passthrough", true)
} }
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency) resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil { if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error()) safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "") setOpsUpstreamError(c, 0, safeErr, "")
@@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() line := scanner.Text()
if openaiSSEDataRe.MatchString(line) { if data, ok := extractOpenAISSEDataLine(line); ok {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if firstTokenMs == nil && strings.TrimSpace(data) != "" { if firstTokenMs == nil && strings.TrimSpace(data) != "" {
ms := int(time.Since(startTime).Milliseconds()) ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms firstTokenMs = &ms
@@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
lastDataAt = time.Now() lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats) // Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) { if data, ok := extractOpenAISSEDataLine(line); ok {
data := openaiSSEDataRe.ReplaceAllString(line, "")
// Replace model in response if needed // Replace model in response if needed
if needModelReplace { if needModelReplace {
@@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
} }
// extractOpenAISSEDataLine extracts the payload of an SSE `data:` line with
// low overhead, accepting both the standard "data: xxx" and the non-standard
// "data:xxx" forms.
//
// Fix: the original skip loop compared line[start] against ' ' twice, so tabs
// after the colon were never trimmed; it now trims both spaces and tabs,
// matching the `^data:\s*` regex this helper replaced for single-line input.
//
// Returns (payload, true) for a data line, ("", false) otherwise.
func extractOpenAISSEDataLine(line string) (string, bool) {
	if !strings.HasPrefix(line, "data:") {
		return "", false
	}
	return strings.TrimLeft(line[len("data:"):], " \t"), true
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string { func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !openaiSSEDataRe.MatchString(line) { data, ok := extractOpenAISSEDataLine(line)
if !ok {
return line return line
} }
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" { if data == "" || data == "[DONE]" {
return line return line
} }
@@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
} }
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) { func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format) if usage == nil || data == "" || data == "[DONE]" {
var event struct { return
Type string `json:"type"` }
Response struct { // 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
Usage struct { if !strings.Contains(data, `"response.completed"`) {
InputTokens int `json:"input_tokens"` return
OutputTokens int `json:"output_tokens"` }
InputTokenDetails struct { if gjson.Get(data, "type").String() != "response.completed" {
CachedTokens int `json:"cached_tokens"` return
} `json:"input_tokens_details"`
} `json:"usage"`
} `json:"response"`
} }
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" { usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
usage.InputTokens = event.Response.Usage.InputTokens usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
usage.OutputTokens = event.Response.Usage.OutputTokens usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
}
} }
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) { func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
@@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
func extractCodexFinalResponse(body string) ([]byte, bool) { func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n") lines := strings.Split(body, "\n")
for _, line := range lines { for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) { data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" { if data == "" || data == "[DONE]" {
continue continue
} }
@@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{} usage := &OpenAIUsage{}
lines := strings.Split(body, "\n") lines := strings.Split(body, "\n")
for _, line := range lines { for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) { data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue continue
} }
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" { if data == "" || data == "[DONE]" {
continue continue
} }
@@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string { func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
lines := strings.Split(body, "\n") lines := strings.Split(body, "\n")
for i, line := range lines { for i, line := range lines {
if !openaiSSEDataRe.MatchString(line) { if _, ok := extractOpenAISSEDataLine(line); !ok {
continue continue
} }
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel) lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
@@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
return normalizeOpenAIReasoningEffort(parts[len(parts)-1]) return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
} }
// extractOpenAIRequestMetaFromBody pulls model, stream and prompt_cache_key
// from the raw JSON body without a full unmarshal. String fields are
// whitespace-trimmed; an empty body yields zero values.
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
	if len(body) == 0 {
		return "", false, ""
	}
	fields := gjson.GetManyBytes(body, "model", "stream", "prompt_cache_key")
	model = strings.TrimSpace(fields[0].String())
	stream = fields[1].Bool()
	promptCacheKey = strings.TrimSpace(fields[2].String())
	return model, stream, promptCacheKey
}
// extractOpenAIReasoningEffortFromBody resolves the reasoning effort for a
// request: it prefers the JSON field "reasoning.effort", then the legacy
// "reasoning_effort", and finally derives it from the model-name suffix.
// Returns nil when no valid (normalizable) effort can be determined.
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
	effort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
	if effort == "" {
		effort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
	}
	var value string
	if effort != "" {
		value = normalizeOpenAIReasoningEffort(effort)
	} else {
		value = deriveOpenAIReasoningEffortFromModel(requestedModel)
	}
	if value == "" {
		return nil
	}
	return &value
}
// getOpenAIRequestBodyMap returns the request body parsed as a generic map,
// preferring the copy the handler already cached on the gin context under
// OpenAIParsedRequestBodyKey so the hot path parses the JSON at most once.
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
	if c != nil {
		cached, found := c.Get(OpenAIParsedRequestBodyKey)
		if found {
			if parsed, isMap := cached.(map[string]any); isMap && parsed != nil {
				return parsed, nil
			}
		}
	}
	var parsed map[string]any
	if err := json.Unmarshal(body, &parsed); err != nil {
		return nil, fmt.Errorf("parse request: %w", err)
	}
	return parsed, nil
}
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string { func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present { if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" { if value == "" {

View File

@@ -0,0 +1,125 @@
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestExtractOpenAIRequestMetaFromBody covers model/stream/prompt_cache_key
// extraction from a raw JSON body: whitespace trimming of the cache key,
// absent optional fields defaulting to zero values, and a nil body.
func TestExtractOpenAIRequestMetaFromBody(t *testing.T) {
	tests := []struct {
		name          string
		body          []byte
		wantModel     string
		wantStream    bool
		wantPromptKey string
	}{
		{
			name:          "完整字段",
			body:          []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`),
			wantModel:     "gpt-5",
			wantStream:    true,
			wantPromptKey: "ses-1",
		},
		{
			name:          "缺失可选字段",
			body:          []byte(`{"model":"gpt-4"}`),
			wantModel:     "gpt-4",
			wantStream:    false,
			wantPromptKey: "",
		},
		{
			name:          "空请求体",
			body:          nil,
			wantModel:     "",
			wantStream:    false,
			wantPromptKey: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body)
			require.Equal(t, tt.wantModel, model)
			require.Equal(t, tt.wantStream, stream)
			require.Equal(t, tt.wantPromptKey, promptKey)
		})
	}
}
// TestExtractOpenAIReasoningEffortFromBody pins the effort resolution order:
// "reasoning.effort" wins, then the legacy "reasoning_effort" field, then the
// model-name suffix; values normalizing to empty (e.g. "minimal") yield nil.
func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) {
	tests := []struct {
		name      string
		body      []byte
		model     string
		wantNil   bool
		wantValue string
	}{
		{
			name:      "优先读取 reasoning.effort",
			body:      []byte(`{"reasoning":{"effort":"medium"}}`),
			model:     "gpt-5-high",
			wantNil:   false,
			wantValue: "medium",
		},
		{
			name:      "兼容 reasoning_effort",
			body:      []byte(`{"reasoning_effort":"x-high"}`),
			model:     "",
			wantNil:   false,
			wantValue: "xhigh",
		},
		{
			name:    "minimal 归一化为空",
			body:    []byte(`{"reasoning":{"effort":"minimal"}}`),
			model:   "gpt-5-high",
			wantNil: true,
		},
		{
			name:      "缺失字段时从模型后缀推导",
			body:      []byte(`{"input":"hi"}`),
			model:     "gpt-5-high",
			wantNil:   false,
			wantValue: "high",
		},
		{
			name:    "未知后缀不返回",
			body:    []byte(`{"input":"hi"}`),
			model:   "gpt-5-unknown",
			wantNil: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model)
			if tt.wantNil {
				require.Nil(t, got)
				return
			}
			require.NotNil(t, got)
			require.Equal(t, tt.wantValue, *got)
		})
	}
}
// TestGetOpenAIRequestBodyMap_UsesContextCache verifies that a previously
// parsed body cached on the gin context short-circuits JSON parsing — the raw
// bytes passed in are intentionally invalid JSON and must never be parsed.
func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	cached := map[string]any{"model": "cached-model", "stream": true}
	c.Set(OpenAIParsedRequestBodyKey, cached)
	got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`))
	require.NoError(t, err)
	require.Equal(t, cached, got)
}
// TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache verifies that with no
// context (and hence no cache), invalid JSON surfaces a "parse request" error.
func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
	_, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`))
	require.Error(t, err)
	require.Contains(t, err.Error(), "parse request")
}

View File

@@ -1416,3 +1416,109 @@ func TestReplaceModelInResponseBody(t *testing.T) {
}) })
} }
} }
// TestExtractOpenAISSEDataLine pins the accepted SSE "data:" line shapes:
// with and without a space after the colon, an empty payload, and rejection
// of non-data lines.
func TestExtractOpenAISSEDataLine(t *testing.T) {
	tests := []struct {
		name     string
		line     string
		wantData string
		wantOK   bool
	}{
		{name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
		{name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
		{name: "纯空数据", line: `data: `, wantData: ``, wantOK: true},
		{name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, ok := extractOpenAISSEDataLine(tt.line)
			require.Equal(t, tt.wantOK, ok)
			require.Equal(t, tt.wantData, got)
		})
	}
}
// TestParseSSEUsage_SelectiveParsing verifies that usage is only extracted
// from "response.completed" events; other event types must leave existing
// usage counters untouched.
func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
	svc := &OpenAIGatewayService{}
	usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7}
	// Non-completed event: must not overwrite usage.
	svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage)
	require.Equal(t, 9, usage.InputTokens)
	require.Equal(t, 8, usage.OutputTokens)
	require.Equal(t, 7, usage.CacheReadInputTokens)
	// Completed event: usage must be extracted (cached_tokens maps to cache reads).
	svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage)
	require.Equal(t, 3, usage.InputTokens)
	require.Equal(t, 5, usage.OutputTokens)
	require.Equal(t, 2, usage.CacheReadInputTokens)
}
// TestExtractCodexFinalResponse_SampleReplay replays a captured SSE stream and
// checks the final response JSON (including usage) is extracted from the
// "response.completed" event, ignoring event lines and the [DONE] sentinel.
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
	body := strings.Join([]string{
		`event: message`,
		`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
		`data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`,
		`data: [DONE]`,
	}, "\n")
	finalResp, ok := extractCodexFinalResponse(body)
	require.True(t, ok)
	require.Contains(t, string(finalResp), `"id":"resp_1"`)
	require.Contains(t, string(finalResp), `"input_tokens":11`)
}
// TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON verifies that an SSE
// stream containing a completed event is converted to a plain JSON response
// body with usage extracted from the completed payload.
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
	}
	body := []byte(strings.Join([]string{
		`data: {"type":"response.in_progress","response":{"id":"resp_2"}}`,
		`data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`,
		`data: [DONE]`,
	}, "\n"))
	usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
	require.NoError(t, err)
	require.NotNil(t, usage)
	require.Equal(t, 7, usage.InputTokens)
	require.Equal(t, 9, usage.OutputTokens)
	require.Equal(t, 1, usage.CacheReadInputTokens)
	// The upstream Content-Type header may be passed through; what matters is
	// that the body has been converted to the final JSON response.
	require.NotContains(t, rec.Body.String(), "event:")
	require.Contains(t, rec.Body.String(), `"id":"resp_2"`)
	require.NotContains(t, rec.Body.String(), "data:")
}
// TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody verifies the fallback:
// when the stream has no completed event, the original SSE body and
// Content-Type are passed through and usage stays zero.
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	svc := &OpenAIGatewayService{cfg: &config.Config{}}
	resp := &http.Response{
		StatusCode: http.StatusOK,
		Header:     http.Header{"Content-Type": []string{"text/event-stream"}},
	}
	body := []byte(strings.Join([]string{
		`data: {"type":"response.in_progress","response":{"id":"resp_3"}}`,
		`data: [DONE]`,
	}, "\n"))
	usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
	require.NoError(t, err)
	require.NotNil(t, usage)
	require.Equal(t, 0, usage.InputTokens)
	require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
	require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
}

View File

@@ -4,16 +4,74 @@ import (
"context" "context"
"errors" "errors"
"log/slog" "log/slog"
"math/rand/v2"
"strings" "strings"
"sync/atomic"
"time" "time"
) )
const ( const (
openAITokenRefreshSkew = 3 * time.Minute openAITokenRefreshSkew = 3 * time.Minute
openAITokenCacheSkew = 5 * time.Minute openAITokenCacheSkew = 5 * time.Minute
openAILockWaitTime = 200 * time.Millisecond openAILockInitialWait = 20 * time.Millisecond
openAILockMaxWait = 120 * time.Millisecond
openAILockMaxAttempts = 5
openAILockJitterRatio = 0.2
openAILockWarnThresholdMs = 250
) )
// OpenAITokenRuntimeMetrics is a point-in-time snapshot of the OpenAI token
// refresh and lock-contention protection counters.
type OpenAITokenRuntimeMetrics struct {
	RefreshRequests    int64 // refresh attempts entered via the lock path
	RefreshSuccess     int64 // successful token refreshes
	RefreshFailure     int64 // failed refreshes (incl. missing OAuth service)
	LockAcquireFailure int64 // refresh-lock acquisition errors (e.g. Redis down)
	LockContention     int64 // times the refresh lock was held by another worker
	LockWaitSamples    int64 // individual wait intervals spent polling the cache
	LockWaitTotalMs    int64 // cumulative wait time across all samples, in ms
	LockWaitHit        int64 // waits that ended with a token-cache hit
	LockWaitMiss       int64 // waits that exhausted the attempt budget
	LastObservedUnixMs int64 // unix-millis timestamp of the most recent metric update
}
// openAITokenRuntimeMetricsStore is the mutable, goroutine-safe backing store
// for OpenAITokenRuntimeMetrics; every field is updated via atomics so the
// hot path never takes a lock.
type openAITokenRuntimeMetricsStore struct {
	refreshRequests    atomic.Int64
	refreshSuccess     atomic.Int64
	refreshFailure     atomic.Int64
	lockAcquireFailure atomic.Int64
	lockContention     atomic.Int64
	lockWaitSamples    atomic.Int64
	lockWaitTotalMs    atomic.Int64
	lockWaitHit        atomic.Int64
	lockWaitMiss       atomic.Int64
	lastObservedUnixMs atomic.Int64
}

// snapshot copies the current counter values into an exported metrics value.
// Each field is loaded independently, so the snapshot is not atomic across
// fields; it is nil-safe and returns zeros for a nil store.
func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics {
	if m == nil {
		return OpenAITokenRuntimeMetrics{}
	}
	return OpenAITokenRuntimeMetrics{
		RefreshRequests:    m.refreshRequests.Load(),
		RefreshSuccess:     m.refreshSuccess.Load(),
		RefreshFailure:     m.refreshFailure.Load(),
		LockAcquireFailure: m.lockAcquireFailure.Load(),
		LockContention:     m.lockContention.Load(),
		LockWaitSamples:    m.lockWaitSamples.Load(),
		LockWaitTotalMs:    m.lockWaitTotalMs.Load(),
		LockWaitHit:        m.lockWaitHit.Load(),
		LockWaitMiss:       m.lockWaitMiss.Load(),
		LastObservedUnixMs: m.lastObservedUnixMs.Load(),
	}
}

// touchNow records the current wall-clock time as the last-observed marker.
// Nil-safe so callers don't need to guard against an uninitialized store.
func (m *openAITokenRuntimeMetricsStore) touchNow() {
	if m == nil {
		return
	}
	m.lastObservedUnixMs.Store(time.Now().UnixMilli())
}
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义) // OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type OpenAITokenCache = GeminiTokenCache type OpenAITokenCache = GeminiTokenCache
@@ -22,6 +80,7 @@ type OpenAITokenProvider struct {
accountRepo AccountRepository accountRepo AccountRepository
tokenCache OpenAITokenCache tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService openAIOAuthService *OpenAIOAuthService
metrics *openAITokenRuntimeMetricsStore
} }
func NewOpenAITokenProvider( func NewOpenAITokenProvider(
@@ -33,11 +92,27 @@ func NewOpenAITokenProvider(
accountRepo: accountRepo, accountRepo: accountRepo,
tokenCache: tokenCache, tokenCache: tokenCache,
openAIOAuthService: openAIOAuthService, openAIOAuthService: openAIOAuthService,
metrics: &openAITokenRuntimeMetricsStore{},
}
}
// SnapshotRuntimeMetrics returns a copy of the provider's current token
// runtime metrics. It is safe to call on a nil provider and lazily
// initializes the metrics store for providers constructed before the
// metrics field existed.
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
	if p == nil {
		return OpenAITokenRuntimeMetrics{}
	}
	p.ensureMetrics()
	return p.metrics.snapshot()
}
func (p *OpenAITokenProvider) ensureMetrics() {
if p != nil && p.metrics == nil {
p.metrics = &openAITokenRuntimeMetricsStore{}
} }
} }
// GetAccessToken 获取有效的 access_token // GetAccessToken 获取有效的 access_token
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) { func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
p.ensureMetrics()
if account == nil { if account == nil {
return "", errors.New("account is nil") return "", errors.New("account is nil")
} }
@@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false refreshFailed := false
if needsRefresh && p.tokenCache != nil { if needsRefresh && p.tokenCache != nil {
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second) locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked { if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }() defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
@@ -82,14 +159,17 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败 refreshFailed = true // 无法刷新,标记失败
} else { } else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token // 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err) slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // 刷新失败,标记以使用短 TTL refreshFailed = true // 刷新失败,标记以使用短 TTL
} else { } else {
p.metrics.refreshSuccess.Add(1)
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials { for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists { if _, exists := newCredentials[k]; !exists {
@@ -106,6 +186,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
} else if lockErr != nil { } else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时) // Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
p.metrics.lockAcquireFailure.Add(1)
p.metrics.touchNow()
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr) slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消 // 检查 ctx 是否已取消
@@ -126,13 +208,16 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew { if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil { if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID) slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true refreshFailed = true
} else { } else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account) tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil { if err != nil {
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err) slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true refreshFailed = true
} else { } else {
p.metrics.refreshSuccess.Add(1)
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo) newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials { for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists { if _, exists := newCredentials[k]; !exists {
@@ -148,9 +233,14 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
} }
} }
} else { } else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存 // 锁被其他 worker 持有:使用短轮询+jitter降低固定等待导致的尾延迟台阶。
time.Sleep(openAILockWaitTime) p.metrics.lockContention.Add(1)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" { p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID) slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil return token, nil
} }
@@ -198,3 +288,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil return accessToken, nil
} }
// waitForTokenAfterLockRace polls the token cache with jittered exponential
// backoff after losing the refresh-lock race to another worker. It returns
// the token as soon as the lock winner publishes it, ("", ctx.Err()) if the
// context is canceled mid-wait, or ("", nil) once the attempt budget is
// exhausted (callers then fall back to their own refresh path).
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
	var elapsedMs int64
	backoff := openAILockInitialWait
	for attempt := 1; attempt <= openAILockMaxAttempts; attempt++ {
		sleep := jitterLockWait(backoff)
		timer := time.NewTimer(sleep)
		select {
		case <-ctx.Done():
			// Stop the timer and drain it if it already fired, so the
			// channel value is not left behind.
			if !timer.Stop() {
				select {
				case <-timer.C:
				default:
				}
			}
			return "", ctx.Err()
		case <-timer.C:
		}

		ms := sleep.Milliseconds()
		if ms < 0 {
			ms = 0
		}
		elapsedMs += ms
		p.metrics.lockWaitSamples.Add(1)
		p.metrics.lockWaitTotalMs.Add(ms)
		p.metrics.touchNow()

		// Cache errors are treated the same as a miss: keep polling.
		if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
			p.metrics.lockWaitHit.Add(1)
			if elapsedMs >= openAILockWarnThresholdMs {
				slog.Warn("openai_token_lock_wait_high", "wait_ms", elapsedMs, "attempts", attempt)
			}
			return token, nil
		}

		// Exponential backoff, capped at the configured maximum.
		if backoff < openAILockMaxWait {
			backoff = min(backoff*2, openAILockMaxWait)
		}
	}

	p.metrics.lockWaitMiss.Add(1)
	if elapsedMs >= openAILockWarnThresholdMs {
		slog.Warn("openai_token_lock_wait_high", "wait_ms", elapsedMs, "attempts", openAILockMaxAttempts)
	}
	return "", nil
}
// jitterLockWait applies uniform random jitter to a backoff interval,
// returning a duration in [base*(1-ratio), base*(1+ratio)] so that competing
// workers do not poll in lock-step. Non-positive bases collapse to zero.
func jitterLockWait(base time.Duration) time.Duration {
	if base <= 0 {
		return 0
	}
	spread := 2 * openAILockJitterRatio
	factor := (1 - openAILockJitterRatio) + rand.Float64()*spread
	return time.Duration(float64(base) * factor)
}

View File

@@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
require.Contains(t, err.Error(), "access_token not found") require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token) require.Empty(t, token)
} }
// TestOpenAITokenProvider_Real_LockRace_PollingHitsCache verifies that when
// the refresh lock is held by another worker, polling picks up the token that
// the lock winner publishes to the cache instead of falling back.
func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // simulate the refresh lock being held by another worker
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       207,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	// Publish the "winner" token shortly after polling starts.
	go func() {
		time.Sleep(5 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "winner-token", token)
}
// TestOpenAITokenProvider_Real_LockRace_ContextCanceled verifies that a
// canceled context aborts the lock-race polling promptly (well under the full
// backoff budget) and surfaces context.Canceled.
func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // simulate the refresh lock being held by another worker
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       208,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	start := time.Now()
	token, err := provider.GetAccessToken(ctx, account)
	require.Error(t, err)
	require.ErrorIs(t, err, context.Canceled)
	require.Empty(t, token)
	// Must return almost immediately rather than sleeping through the backoff.
	require.Less(t, time.Since(start), 50*time.Millisecond)
}
// TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot exercises a
// lock-race wait that ends in a cache hit and asserts the corresponding
// runtime metrics (contention, wait samples, hit counter, timestamps) are
// visible via SnapshotRuntimeMetrics.
func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockAcquired = false // lock held elsewhere -> wait path with metrics
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       209,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	cacheKey := OpenAITokenCacheKey(account)
	go func() {
		time.Sleep(10 * time.Millisecond)
		cache.mu.Lock()
		cache.tokens[cacheKey] = "winner-token"
		cache.mu.Unlock()
	}()
	provider := NewOpenAITokenProvider(nil, cache, nil)
	token, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	require.Equal(t, "winner-token", token)
	metrics := provider.SnapshotRuntimeMetrics()
	require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
	require.GreaterOrEqual(t, metrics.LockContention, int64(1))
	require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1))
	require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1))
	require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0))
	require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1))
}
// TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure verifies that a
// Redis lock error takes the degraded path without failing the request and
// bumps the LockAcquireFailure counter.
func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
	cache := newOpenAITokenCacheStub()
	cache.lockErr = errors.New("redis lock error")
	expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
	account := &Account{
		ID:       210,
		Platform: PlatformOpenAI,
		Type:     AccountTypeOAuth,
		Credentials: map[string]any{
			"access_token": "fallback-token",
			"expires_at":   expiresAt,
		},
	}
	provider := NewOpenAITokenProvider(nil, cache, nil)
	_, err := provider.GetAccessToken(context.Background(), account)
	require.NoError(t, err)
	metrics := provider.SnapshotRuntimeMetrics()
	require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
	require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
}

View File

@@ -98,6 +98,10 @@ type OpsInsertErrorLogInput struct {
// It is set by OpsService.RecordError before persisting. // It is set by OpsService.RecordError before persisting.
UpstreamErrorsJSON *string UpstreamErrorsJSON *string
AuthLatencyMs *int64
RoutingLatencyMs *int64
UpstreamLatencyMs *int64
ResponseLatencyMs *int64
TimeToFirstTokenMs *int64 TimeToFirstTokenMs *int64
RequestBodyJSON *string // sanitized json string (not raw bytes) RequestBodyJSON *string // sanitized json string (not raw bytes)

View File

@@ -20,8 +20,30 @@ const (
// retry the specific upstream attempt (not just the client request). // retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted. // This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body" OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
// Optional stage latencies (milliseconds) for troubleshooting and alerting.
OpsAuthLatencyMsKey = "ops_auth_latency_ms"
OpsRoutingLatencyMsKey = "ops_routing_latency_ms"
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
OpsResponseLatencyMsKey = "ops_response_latency_ms"
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
) )
// setOpsUpstreamRequestBody stashes the upstream request body on the gin
// context for later error reporting. The raw bytes are stored as-is; the
// string conversion is deferred until persistence to avoid an extra
// allocation on the hot path. No-op for a nil context or empty body.
func setOpsUpstreamRequestBody(c *gin.Context, body []byte) {
	if c == nil {
		return
	}
	if len(body) == 0 {
		return
	}
	c.Set(OpsUpstreamRequestBodyKey, body)
}
// SetOpsLatencyMs records a stage-latency value (milliseconds) under the
// given ops key on the gin context. Nil contexts, blank keys, and negative
// values are silently ignored.
func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
	invalid := c == nil || value < 0 || strings.TrimSpace(key) == ""
	if invalid {
		return
	}
	c.Set(key, value)
}
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) { func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
if c == nil { if c == nil {
return return
@@ -91,8 +113,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
// stored it on the context, attach it so ops can retry this specific attempt. // stored it on the context, attach it so ops can retry this specific attempt.
if ev.UpstreamRequestBody == "" { if ev.UpstreamRequestBody == "" {
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok { if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
if s, ok := v.(string); ok { switch raw := v.(type) {
ev.UpstreamRequestBody = strings.TrimSpace(s) case string:
ev.UpstreamRequestBody = strings.TrimSpace(raw)
case []byte:
ev.UpstreamRequestBody = strings.TrimSpace(string(raw))
} }
} }
} }

View File

@@ -0,0 +1,47 @@
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
// TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext verifies that a
// []byte request body stored on the context (the hot-path form) is attached
// to an appended upstream error event as a string.
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`))
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Kind:    "http_error",
		Message: "upstream failed",
	})
	v, ok := c.Get(OpsUpstreamErrorsKey)
	require.True(t, ok)
	events, ok := v.([]*OpsUpstreamErrorEvent)
	require.True(t, ok)
	require.Len(t, events, 1)
	require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody)
}
// TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext verifies
// backward compatibility: a string request body stored directly on the
// context is still attached to the appended upstream error event.
func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(rec)
	c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`)
	appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
		Kind:    "request_error",
		Message: "dial timeout",
	})
	v, ok := c.Get(OpsUpstreamErrorsKey)
	require.True(t, ok)
	events, ok := v.([]*OpsUpstreamErrorEvent)
	require.True(t, ok)
	require.Len(t, events, 1)
	require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody)
}

View File

@@ -0,0 +1,164 @@
#!/usr/bin/env python3
"""OpenAI OAuth 灰度发布演练脚本(本地模拟)。
该脚本会启动本地 mock Ops API调用 openai_oauth_gray_guard.py
验证以下场景:
1) A/B/C/D 四个灰度批次均通过
2) 注入异常场景触发阈值告警并返回退出码 2模拟自动回滚触发
"""
from __future__ import annotations
import json
import subprocess
import threading
from dataclasses import dataclass
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import Dict, Tuple
from urllib.parse import parse_qs, urlparse
# Repository root (tools/perf/<script> -> two levels up) and drill artifacts.
ROOT = Path(__file__).resolve().parents[2]
GUARD_SCRIPT = ROOT / "tools" / "perf" / "openai_oauth_gray_guard.py"
REPORT_PATH = ROOT / "docs" / "perf" / "openai-oauth-gray-drill-report.md"

# Thresholds served by the mock Ops API to the guard script.
THRESHOLDS = {
    "sla_percent_min": 99.5,
    "ttft_p99_ms_max": 900,
    "request_error_rate_percent_max": 2.0,
    "upstream_error_rate_percent_max": 2.0,
}

# Canned overview metrics per gray stage. A-D satisfy THRESHOLDS; the
# "rollback" stage intentionally violates them to exercise the abort path.
STAGE_SNAPSHOTS: Dict[str, Dict[str, float]] = {
    "A": {"sla": 99.78, "ttft": 780, "error_rate": 1.20, "upstream_error_rate": 1.05},
    "B": {"sla": 99.82, "ttft": 730, "error_rate": 1.05, "upstream_error_rate": 0.92},
    "C": {"sla": 99.86, "ttft": 680, "error_rate": 0.88, "upstream_error_rate": 0.80},
    "D": {"sla": 99.89, "ttft": 640, "error_rate": 0.72, "upstream_error_rate": 0.67},
    "rollback": {"sla": 97.10, "ttft": 1550, "error_rate": 6.30, "upstream_error_rate": 5.60},
}
class _MockHandler(BaseHTTPRequestHandler):
    """Minimal mock Ops API serving threshold config and per-stage overview data."""

    def _write_json(self, payload: dict) -> None:
        """Serialize *payload* and send it as a 200 application/json response."""
        raw = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(raw)))
        self.end_headers()
        self.wfile.write(raw)

    def log_message(self, format: str, *args):  # noqa: A003
        # Silence per-request stderr logging from BaseHTTPRequestHandler.
        return

    def do_GET(self):  # noqa: N802
        """Route the two endpoints the guard script calls; 404 everything else."""
        parsed = urlparse(self.path)
        if parsed.path.endswith("/api/v1/admin/ops/settings/metric-thresholds"):
            self._write_json({"code": 0, "message": "success", "data": THRESHOLDS})
            return
        if parsed.path.endswith("/api/v1/admin/ops/dashboard/overview"):
            q = parse_qs(parsed.query)
            # The drill abuses group_id to select which canned stage to serve.
            stage = (q.get("group_id") or ["A"])[0]
            snapshot = STAGE_SNAPSHOTS.get(stage, STAGE_SNAPSHOTS["A"])
            self._write_json(
                {
                    "code": 0,
                    "message": "success",
                    "data": {
                        "sla": snapshot["sla"],
                        "error_rate": snapshot["error_rate"],
                        "upstream_error_rate": snapshot["upstream_error_rate"],
                        "ttft": {"p99_ms": snapshot["ttft"]},
                    },
                }
            )
            return
        self.send_response(404)
        self.end_headers()
def run_guard(base_url: str, stage: str) -> Tuple[int, str]:
    """Invoke the gray-guard script for one stage.

    Returns (exit code, combined stdout+stderr stripped of surrounding
    whitespace). The stage name is passed through --group-id, which the mock
    server uses to pick a canned metrics snapshot.
    """
    cmd = [
        "python",
        str(GUARD_SCRIPT),
        "--base-url", base_url,
        "--platform", "openai",
        "--time-range", "30m",
        "--group-id", stage,
    ]
    proc = subprocess.run(cmd, cwd=str(ROOT), capture_output=True, text=True)
    combined = f"{proc.stdout}\n{proc.stderr}".strip()
    return proc.returncode, combined
def main() -> int:
    """Run the drill: gray batches A-D must pass, then the injected-failure
    stage must exit with code 2 (the rollback signal).

    Writes a markdown report to REPORT_PATH and returns 0 only when both
    conditions hold.
    """
    # Port 0 -> ephemeral port, so parallel runs never collide.
    server = HTTPServer(("127.0.0.1", 0), _MockHandler)
    host, port = server.server_address
    base_url = f"http://{host}:{port}"
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    lines = [
        "# OpenAI OAuth 灰度守护演练报告",
        "",
        "> 类型:本地 mock 演练(用于验证灰度守护与回滚触发机制)",
        f"> 生成脚本:`tools/perf/openai_oauth_gray_drill.py`",  # NOTE(review): f-prefix has no placeholders
        "",
        "## 1. 灰度批次结果6.1",
        "",
        "| 批次 | 流量比例 | 守护脚本退出码 | 结果 |",
        "|---|---:|---:|---|",
    ]
    batch_plan = [("A", "5%"), ("B", "20%"), ("C", "50%"), ("D", "100%")]
    all_pass = True
    for stage, ratio in batch_plan:
        # Exit code 0 from the guard means the stage's metrics passed all gates.
        code, _ = run_guard(base_url, stage)
        ok = code == 0
        all_pass = all_pass and ok
        lines.append(f"| {stage} | {ratio} | {code} | {'通过' if ok else '失败'} |")
    lines.extend([
        "",
        "## 2. 回滚触发演练6.2",
        "",
    ])
    # The "rollback" stage serves metrics beyond every threshold; the guard
    # must exit 2 (stop-and-roll-back signal) for the drill to pass.
    rollback_code, rollback_output = run_guard(base_url, "rollback")
    rollback_triggered = rollback_code == 2
    lines.append(f"- 注入异常场景退出码:`{rollback_code}`")
    lines.append(f"- 是否触发回滚条件:`{'' if rollback_triggered else ''}`")
    lines.append("- 关键信息摘录:")
    excerpt = "\n".join(rollback_output.splitlines()[:8])
    lines.append("```text")
    lines.append(excerpt)
    lines.append("```")
    lines.extend([
        "",
        "## 3. 验收结论6.3",
        "",
        f"- 批次灰度结果:`{'通过' if all_pass else '不通过'}`",
        f"- 回滚触发机制:`{'通过' if rollback_triggered else '不通过'}`",
        f"- 结论:`{'通过(可进入真实环境灰度)' if all_pass and rollback_triggered else '不通过(需修复后复测)'}`",
    ])
    REPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
    REPORT_PATH.write_text("\n".join(lines) + "\n", encoding="utf-8")
    server.shutdown()
    server.server_close()
    print(f"drill report generated: {REPORT_PATH}")
    return 0 if all_pass and rollback_triggered else 1
if __name__ == "__main__":
raise SystemExit(main())

View File

@@ -0,0 +1,213 @@
#!/usr/bin/env python3
"""OpenAI OAuth 灰度阈值守护脚本。
用途:
- 拉取 Ops 指标阈值配置与 Dashboard Overview 实时数据
- 对比 P99 TTFT / 错误率 / SLA
- 作为 6.2 灰度守护的自动化门禁(退出码可直接用于 CI/CD
退出码:
- 0: 指标通过
- 1: 请求失败/参数错误
- 2: 指标超阈值(建议停止扩量并回滚)
"""
from __future__ import annotations
import argparse
import json
import sys
import urllib.error
import urllib.parse
import urllib.request
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
@dataclass
class GuardThresholds:
    """Alert thresholds from the Ops metric-thresholds API; None means unset."""

    sla_percent_min: Optional[float]          # minimum acceptable SLA (%)
    ttft_p99_ms_max: Optional[float]          # maximum P99 time-to-first-token (ms)
    request_error_rate_percent_max: Optional[float]   # max request error rate (%)
    upstream_error_rate_percent_max: Optional[float]  # max upstream error rate (%)
@dataclass
class GuardSnapshot:
    """Live metric readings from the dashboard overview; None means missing."""

    sla: Optional[float]                       # observed SLA (%)
    ttft_p99_ms: Optional[float]               # observed P99 time-to-first-token (ms)
    request_error_rate_percent: Optional[float]   # observed request error rate (%)
    upstream_error_rate_percent: Optional[float]  # observed upstream error rate (%)
def build_headers(token: str) -> Dict[str, str]:
    """Build request headers, attaching a Bearer token when one is provided.

    Surrounding whitespace in *token* is ignored; a blank token produces
    Accept-only headers.
    """
    token = token.strip()
    headers: Dict[str, str] = {"Accept": "application/json"}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    return headers
def request_json(url: str, headers: Dict[str, str]) -> Dict[str, Any]:
    """GET *url* and decode the JSON payload.

    urllib errors are normalized to RuntimeError: HTTP errors include the
    status code and response body, transport errors include the cause.
    """
    request = urllib.request.Request(url=url, method="GET", headers=headers)
    try:
        with urllib.request.urlopen(request, timeout=15) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as exc:
        detail = exc.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {exc.code}: {detail}") from exc
    except urllib.error.URLError as exc:
        raise RuntimeError(f"request failed: {exc}") from exc
def parse_envelope_data(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the standard ``{code, message, data}`` API envelope.

    Raises RuntimeError when the payload is not a dict, the code is non-zero,
    or the data field is not itself a dict.
    """
    if not isinstance(payload, dict):
        raise RuntimeError("invalid response payload")
    code = payload.get("code")
    if code != 0:
        raise RuntimeError(f"api error: code={code} message={payload.get('message')}")
    data = payload.get("data")
    if isinstance(data, dict):
        return data
    raise RuntimeError("invalid response data")
def parse_thresholds(data: Dict[str, Any]) -> GuardThresholds:
    """Map the thresholds API payload onto GuardThresholds.

    Missing or unconvertible values become None (gate disabled).
    """
    def field(key: str) -> Optional[float]:
        return to_float_or_none(data.get(key))

    return GuardThresholds(
        sla_percent_min=field("sla_percent_min"),
        ttft_p99_ms_max=field("ttft_p99_ms_max"),
        request_error_rate_percent_max=field("request_error_rate_percent_max"),
        upstream_error_rate_percent_max=field("upstream_error_rate_percent_max"),
    )
def parse_snapshot(data: Dict[str, Any]) -> GuardSnapshot:
    """Map the dashboard overview payload onto GuardSnapshot.

    The nested "ttft" block is tolerated in any shape: anything that is not a
    dict is treated as absent. Missing/unconvertible values become None.
    """
    ttft_block = data.get("ttft")
    if not isinstance(ttft_block, dict):
        ttft_block = {}
    return GuardSnapshot(
        sla=to_float_or_none(data.get("sla")),
        ttft_p99_ms=to_float_or_none(ttft_block.get("p99_ms")),
        request_error_rate_percent=to_float_or_none(data.get("error_rate")),
        upstream_error_rate_percent=to_float_or_none(data.get("upstream_error_rate")),
    )
def to_float_or_none(v: Any) -> Optional[float]:
    """Coerce *v* to float; None and unconvertible values map to None."""
    try:
        return None if v is None else float(v)
    except (TypeError, ValueError):
        return None
def evaluate(snapshot: GuardSnapshot, thresholds: GuardThresholds) -> List[str]:
    """Compare a metrics snapshot against the configured thresholds.

    Each check is skipped when either side is None (threshold unset or metric
    missing). Returns human-readable violation descriptions; empty list means
    all guarded metrics pass.
    """
    found: List[str] = []

    # SLA is a floor: violation when the observed value drops below it.
    sla_min = thresholds.sla_percent_min
    if sla_min is not None and snapshot.sla is not None and snapshot.sla < sla_min:
        found.append(
            f"SLA 低于阈值: actual={snapshot.sla:.2f}% threshold={sla_min:.2f}%"
        )

    # The remaining metrics are ceilings: violation when observed exceeds max.
    ttft_max = thresholds.ttft_p99_ms_max
    if (
        ttft_max is not None
        and snapshot.ttft_p99_ms is not None
        and snapshot.ttft_p99_ms > ttft_max
    ):
        found.append(
            f"TTFT P99 超阈值: actual={snapshot.ttft_p99_ms:.2f}ms threshold={ttft_max:.2f}ms"
        )

    req_max = thresholds.request_error_rate_percent_max
    req_actual = snapshot.request_error_rate_percent
    if req_max is not None and req_actual is not None and req_actual > req_max:
        found.append(
            f"请求错误率超阈值: actual={req_actual:.2f}% threshold={req_max:.2f}%"
        )

    up_max = thresholds.upstream_error_rate_percent_max
    up_actual = snapshot.upstream_error_rate_percent
    if up_max is not None and up_actual is not None and up_actual > up_max:
        found.append(
            f"上游错误率超阈值: actual={up_actual:.2f}% threshold={up_max:.2f}%"
        )

    return found
def main() -> int:
    """CLI entry point: fetch thresholds and a live metrics snapshot, then gate.

    Exit codes:
        0 -- all guarded metrics within thresholds (continue rollout)
        1 -- execution failure (network/API/parse error)
        2 -- at least one threshold violation (stop rollout, roll back)
    """
    parser = argparse.ArgumentParser(description="OpenAI OAuth 灰度阈值守护")
    parser.add_argument("--base-url", required=True, help="服务地址,例如 http://127.0.0.1:5231")
    parser.add_argument("--admin-token", default="", help="Admin JWT可选按部署策略")
    parser.add_argument("--platform", default="openai", help="平台过滤,默认 openai")
    parser.add_argument("--time-range", default="30m", help="时间窗口: 5m/30m/1h/6h/24h/7d/30d")
    parser.add_argument("--group-id", default="", help="可选 group_id")
    args = parser.parse_args()
    # Normalize so the f-string path joins below cannot produce double slashes.
    base = args.base_url.rstrip("/")
    headers = build_headers(args.admin_token)
    try:
        # Step 1: load the configured guard thresholds from the admin settings API.
        threshold_url = f"{base}/api/v1/admin/ops/settings/metric-thresholds"
        thresholds_raw = request_json(threshold_url, headers)
        thresholds = parse_thresholds(parse_envelope_data(thresholds_raw))
        # Step 2: fetch the current metrics snapshot for the requested
        # platform / time window (group filter is optional).
        query = {"platform": args.platform, "time_range": args.time_range}
        if args.group_id.strip():
            query["group_id"] = args.group_id.strip()
        overview_url = (
            f"{base}/api/v1/admin/ops/dashboard/overview?"
            + urllib.parse.urlencode(query)
        )
        overview_raw = request_json(overview_url, headers)
        snapshot = parse_snapshot(parse_envelope_data(overview_raw))
        # Step 3: print both sides for operator visibility before evaluating.
        print("[OpenAI OAuth Gray Guard] 当前快照:")
        print(
            json.dumps(
                {
                    "sla": snapshot.sla,
                    "ttft_p99_ms": snapshot.ttft_p99_ms,
                    "request_error_rate_percent": snapshot.request_error_rate_percent,
                    "upstream_error_rate_percent": snapshot.upstream_error_rate_percent,
                },
                ensure_ascii=False,
                indent=2,
            )
        )
        print("[OpenAI OAuth Gray Guard] 阈值配置:")
        print(
            json.dumps(
                {
                    "sla_percent_min": thresholds.sla_percent_min,
                    "ttft_p99_ms_max": thresholds.ttft_p99_ms_max,
                    "request_error_rate_percent_max": thresholds.request_error_rate_percent_max,
                    "upstream_error_rate_percent_max": thresholds.upstream_error_rate_percent_max,
                },
                ensure_ascii=False,
                indent=2,
            )
        )
        # Step 4: evaluate and choose the exit code.
        violations = evaluate(snapshot, thresholds)
        if violations:
            print("[OpenAI OAuth Gray Guard] 检测到阈值违例:")
            for idx, line in enumerate(violations, start=1):
                print(f"  {idx}. {line}")
            print("[OpenAI OAuth Gray Guard] 建议:停止扩量并执行回滚。")
            return 2
        print("[OpenAI OAuth Gray Guard] 指标通过,可继续观察或按计划扩量。")
        return 0
    except Exception as exc:
        # Any failure along the HTTP/parse chain is reported as exit code 1 so
        # CI can distinguish "guard broke" (1) from "metrics failed" (2).
        print(f"[OpenAI OAuth Gray Guard] 执行失败: {exc}", file=sys.stderr)
        return 1

View File

@@ -0,0 +1,122 @@
import http from 'k6/http';
import { check } from 'k6';
import { Rate, Trend } from 'k6/metrics';
// Target service and auth (all values overridable via k6 environment variables).
const baseURL = __ENV.BASE_URL || 'http://127.0.0.1:5231';
const apiKey = __ENV.API_KEY || '';
// Request shaping: model under test and per-request timeout.
const model = __ENV.MODEL || 'gpt-5';
const timeout = __ENV.TIMEOUT || '180s';
// Arrival rates (requests per second) for the two scenarios, and run duration.
const nonStreamRPS = Number(__ENV.NON_STREAM_RPS || 8);
const streamRPS = Number(__ENV.STREAM_RPS || 4);
const duration = __ENV.DURATION || '3m';
// VU pool sizing shared by both constant-arrival-rate executors.
const preAllocatedVUs = Number(__ENV.PRE_ALLOCATED_VUS || 30);
const maxVUs = Number(__ENV.MAX_VUS || 200);
// Custom metrics: request latency, TTFT (time-to-first-byte proxy),
// non-2xx ratio, and stream-completion ([DONE] seen) ratio.
const reqDurationMs = new Trend('openai_oauth_req_duration_ms', true);
const ttftMs = new Trend('openai_oauth_ttft_ms', true);
const non2xxRate = new Rate('openai_oauth_non2xx_rate');
const streamDoneRate = new Rate('openai_oauth_stream_done_rate');
// k6 run configuration: two parallel constant-arrival-rate scenarios
// (non-streaming and streaming) plus pass/fail thresholds for the run.
export const options = {
  scenarios: {
    non_stream: {
      executor: 'constant-arrival-rate',
      rate: nonStreamRPS,
      timeUnit: '1s',
      duration,
      preAllocatedVUs,
      maxVUs,
      exec: 'runNonStream',
      tags: { request_type: 'non_stream' },
    },
    stream: {
      executor: 'constant-arrival-rate',
      rate: streamRPS,
      timeUnit: '1s',
      duration,
      preAllocatedVUs,
      maxVUs,
      exec: 'runStream',
      tags: { request_type: 'stream' },
    },
  },
  thresholds: {
    // <1% non-2xx responses over the whole run.
    openai_oauth_non2xx_rate: ['rate<0.01'],
    // Latency SLOs: p95 under 3s, p99 under 6s.
    openai_oauth_req_duration_ms: ['p(95)<3000', 'p(99)<6000'],
    // TTFT p99 under 1.2s.
    openai_oauth_ttft_ms: ['p(99)<1200'],
    // >99% of stream requests must reach [DONE].
    openai_oauth_stream_done_rate: ['rate>0.99'],
  },
};
// Assemble common request headers; Authorization is only attached when an
// API key is configured via the environment.
function buildHeaders() {
  const base = {
    'Content-Type': 'application/json',
    'User-Agent': 'codex_cli_rs/0.1.0',
  };
  return apiKey ? Object.assign(base, { Authorization: `Bearer ${apiKey}` }) : base;
}
// Build a minimal /v1/responses request payload; `stream` toggles SSE mode.
function buildBody(stream) {
  const userMessage = {
    role: 'user',
    content: [
      {
        type: 'input_text',
        text: '请返回一句极短的话pong',
      },
    ],
  };
  return JSON.stringify({
    model,
    stream,
    input: [userMessage],
    max_output_tokens: 32,
  });
}
// Fold one completed request into the custom metrics; stream requests also
// record whether the SSE terminator ([DONE]) was observed in the body.
function recordMetrics(res, stream) {
  const tags = { request_type: stream ? 'stream' : 'non_stream' };
  reqDurationMs.add(res.timings.duration, tags);
  ttftMs.add(res.timings.waiting, tags);
  non2xxRate.add(res.status < 200 || res.status >= 300, tags);
  if (stream) {
    const sawDone = !!res.body && res.body.indexOf('[DONE]') >= 0;
    streamDoneRate.add(sawDone, { request_type: 'stream' });
  }
}
// POST to /v1/responses, run the status check, record metrics, and return
// the raw k6 response for optional further inspection.
function postResponses(stream) {
  const requestType = stream ? 'stream' : 'non_stream';
  const res = http.post(`${baseURL}/v1/responses`, buildBody(stream), {
    headers: buildHeaders(),
    timeout,
    tags: { endpoint: '/v1/responses', request_type: requestType },
  });
  check(res, {
    'status is 2xx': (r) => r.status >= 200 && r.status < 300,
  });
  recordMetrics(res, stream);
  return res;
}
// Scenario entry point (non_stream executor): one non-streaming request.
export function runNonStream() {
  postResponses(false);
}
// Scenario entry point (stream executor): one streaming (SSE) request.
export function runStream() {
  postResponses(true);
}
// End-of-run hook: print a metrics dump to stdout and persist the complete
// summary JSON under docs/perf for later comparison.
export function handleSummary(data) {
  const metricsDump = JSON.stringify(data.metrics, null, 2);
  return {
    stdout: `\nOpenAI OAuth /v1/responses 基线完成\n${metricsDump}\n`,
    'docs/perf/openai-oauth-k6-summary.json': JSON.stringify(data, null, 2),
  };
}