feat(openai): 极致优化 OAuth 链路并补齐性能守护

- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝\n- 优化并发与 token 竞争路径并补齐运行指标\n- 补充 OpenAI/Ops 相关单元测试与回归用例\n- 新增灰度阈值守护与压测脚本,支撑发布验收
This commit is contained in:
yangjianbo
2026-02-12 09:41:37 +08:00
parent a88bb8684f
commit 61a2bf469a
16 changed files with 1519 additions and 135 deletions

View File

@@ -104,31 +104,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
if releaseFunc == nil {
return nil
}
var once sync.Once
quit := make(chan struct{})
var stop func() bool
release := func() {
once.Do(func() {
if stop != nil {
_ = stop()
}
releaseFunc()
close(quit) // 通知监听 goroutine 退出
})
}
go func() {
select {
case <-ctx.Done():
// Context 取消时释放资源
release()
case <-quit:
// 正常释放已完成goroutine 退出
return
}
}()
stop = context.AfterFunc(ctx, release)
return release
}
@@ -153,6 +146,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
}
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
// 返回值: (releaseFunc, acquired, error)
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, false, err
}
if !result.Acquired {
return nil, false, nil
}
return result.ReleaseFunc, true, nil
}
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
// For streaming requests, sends ping events during the wait.
// streamStarted is updated if streaming response has begun.
@@ -160,13 +179,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed
@@ -180,13 +199,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
ctx := c.Request.Context()
// Try to acquire immediately
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
if err != nil {
return nil, err
}
if result.Acquired {
return result.ReleaseFunc, nil
if acquired {
return releaseFunc, nil
}
// Need to wait - handle streaming ping if needed

View File

@@ -0,0 +1,114 @@
package handler
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/stretchr/testify/require"
)
type concurrencyCacheMock struct {
acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
releaseUserCalled int32
releaseAccountCalled int32
}
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
if m.acquireAccountSlotFn != nil {
return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
}
return false, nil
}
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
atomic.AddInt32(&m.releaseAccountCalled, 1)
return nil
}
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
return nil
}
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
if m.acquireUserSlotFn != nil {
return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
}
return false, nil
}
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
atomic.AddInt32(&m.releaseUserCalled, 1)
return nil
}
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
return true, nil
}
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
return nil
}
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
return map[int64]*service.AccountLoadInfo{}, nil
}
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
return map[int64]*service.UserLoadInfo{}, nil
}
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
return nil
}
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
cache := &concurrencyCacheMock{
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
return true, nil
},
}
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
require.NoError(t, err)
require.True(t, acquired)
require.NotNil(t, release)
release()
require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
}
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
cache := &concurrencyCacheMock{
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
return false, nil
},
}
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
require.NoError(t, err)
require.False(t, acquired)
require.Nil(t, release)
require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
}

View File

@@ -64,6 +64,8 @@ func NewOpenAIGatewayHandler(
// Responses handles OpenAI Responses API endpoint
// POST /openai/v1/responses
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
requestStart := time.Now()
// Get apiKey and user from context (set by ApiKeyAuth middleware)
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
if !ok {
@@ -141,6 +143,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err == nil {
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
if service.HasFunctionCallOutput(reqBody) {
previousResponseID, _ := reqBody["previous_response_id"].(string)
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
@@ -171,34 +174,47 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
// Get subscription info (may be nil)
subscription, _ := middleware2.GetSubscriptionFromContext(c)
// 0. Check if wait queue is full
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
waitCounted := false
if err != nil {
log.Printf("Increment wait count failed: %v", err)
// On error, allow request to proceed
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if err == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
routingStart := time.Now()
// 1. First acquire user concurrency slot
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
// 0. 先尝试直接抢占用户槽位(快速路径)
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
// User slot acquired: no longer waiting.
waitCounted := false
if !userAcquired {
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
maxWait := service.CalculateMaxWait(subject.Concurrency)
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
if waitErr != nil {
log.Printf("Increment wait count failed: %v", waitErr)
// 按现有降级语义:等待计数异常时放行后续抢槽流程
} else if !canWait {
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
return
}
if waitErr == nil && canWait {
waitCounted = true
}
defer func() {
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
}
}()
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
if err != nil {
log.Printf("User concurrency acquire failed: %v", err)
h.handleConcurrencyError(c, err, "user", streamStarted)
return
}
}
// 用户槽位已获取:退出等待队列计数。
if waitCounted {
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
waitCounted = false
@@ -253,53 +269,84 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
return
}
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
c.Request.Context(),
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
log.Printf("Account concurrency quick acquire failed: %v", err)
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
if fastAcquired {
accountReleaseFunc = fastReleaseFunc
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
} else {
accountWaitCounted := false
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
if err != nil {
log.Printf("Increment account wait count failed: %v", err)
} else if !canWait {
log.Printf("Account wait queue full: account=%d", account.ID)
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later", streamStarted)
return
}
if err == nil && canWait {
accountWaitCounted = true
}
releaseWait := func() {
if accountWaitCounted {
h.concurrencyHelper.DecrementAccountWaitCount(c.Request.Context(), account.ID)
accountWaitCounted = false
}
}
accountReleaseFunc, err = h.concurrencyHelper.AcquireAccountSlotWithWaitTimeout(
c,
account.ID,
selection.WaitPlan.MaxConcurrency,
selection.WaitPlan.Timeout,
reqStream,
&streamStarted,
)
if err != nil {
log.Printf("Account concurrency acquire failed: %v", err)
releaseWait()
h.handleConcurrencyError(c, err, "account", streamStarted)
return
}
// Slot acquired: no longer waiting in queue.
releaseWait()
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
log.Printf("Bind sticky session failed: %v", err)
}
}
}
// 账号槽位/等待计数需要在超时或断开时安全回收
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
// Forward request
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
forwardStart := time.Now()
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
forwardDurationMs := time.Since(forwardStart).Milliseconds()
if accountReleaseFunc != nil {
accountReleaseFunc()
}
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
responseLatencyMs := forwardDurationMs
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
}
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
if err == nil && result != nil && result.FirstTokenMs != nil {
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
}
if err != nil {
var failoverErr *service.UpstreamFailoverError
if errors.As(err, &failoverErr) {
@@ -343,6 +390,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
}
}
func getContextInt64(c *gin.Context, key string) (int64, bool) {
if c == nil || key == "" {
return 0, false
}
v, ok := c.Get(key)
if !ok {
return 0, false
}
switch t := v.(type) {
case int64:
return t, true
case int:
return int64(t), true
case int32:
return int64(t), true
case float64:
return int64(t), true
default:
return 0, false
}
}
// handleConcurrencyError handles concurrency-related errors with proper 429 response
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",

View File

@@ -507,6 +507,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0,
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
if apiKey != nil {
entry.APIKeyID = &apiKey.ID
@@ -618,6 +619,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
RetryCount: 0,
CreatedAt: time.Now(),
}
applyOpsLatencyFieldsFromContext(c, entry)
// Capture upstream error context set by gateway services (if present).
// This does NOT affect the client response; it enriches Ops troubleshooting data.
@@ -746,6 +748,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
return &s
}
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
if c == nil || entry == nil {
return
}
entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
}
func getContextLatencyMs(c *gin.Context, key string) *int64 {
if c == nil || strings.TrimSpace(key) == "" {
return nil
}
v, ok := c.Get(key)
if !ok {
return nil
}
var ms int64
switch t := v.(type) {
case int:
ms = int64(t)
case int32:
ms = int64(t)
case int64:
ms = t
case float64:
ms = int64(t)
default:
return nil
}
if ms < 0 {
return nil
}
return &ms
}
type parsedOpsError struct {
ErrorType string
Message string

View File

@@ -55,6 +55,10 @@ INSERT INTO ops_error_logs (
upstream_error_message,
upstream_error_detail,
upstream_errors,
auth_latency_ms,
routing_latency_ms,
upstream_latency_ms,
response_latency_ms,
time_to_first_token_ms,
request_body,
request_body_truncated,
@@ -64,7 +68,7 @@ INSERT INTO ops_error_logs (
retry_count,
created_at
) VALUES (
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
) RETURNING id`
var id int64
@@ -97,6 +101,10 @@ INSERT INTO ops_error_logs (
opsNullString(input.UpstreamErrorMessage),
opsNullString(input.UpstreamErrorDetail),
opsNullString(input.UpstreamErrorsJSON),
opsNullInt64(input.AuthLatencyMs),
opsNullInt64(input.RoutingLatencyMs),
opsNullInt64(input.UpstreamLatencyMs),
opsNullInt64(input.ResponseLatencyMs),
opsNullInt64(input.TimeToFirstTokenMs),
opsNullString(input.RequestBodyJSON),
input.RequestBodyTruncated,

View File

@@ -12,7 +12,6 @@ import (
"io"
"log"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
@@ -34,11 +33,10 @@ const (
// OpenAI Platform API for API Key accounts (fallback)
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
openaiStickySessionTTL = time.Hour // 粘性会话TTL
)
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
)
// OpenAI allowed headers whitelist (for non-OAuth accounts)
var openaiAllowedHeaders = map[string]bool{
@@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
startTime := time.Now()
originalBody := body
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
originalModel := reqModel
// Parse request body once (avoid multiple parse/serialize cycles)
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if passthroughEnabled {
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// Extract model and stream from parsed body
reqModel, _ := reqBody["model"].(string)
reqStream, _ := reqBody["stream"].(bool)
promptCacheKey := ""
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
reqBody, err := getOpenAIRequestBodyMap(c, body)
if err != nil {
return nil, err
}
if v, ok := reqBody["model"].(string); ok {
reqModel = v
originalModel = reqModel
}
if v, ok := reqBody["stream"].(bool); ok {
reqStream = v
}
if promptCacheKey == "" {
if v, ok := reqBody["prompt_cache_key"].(string); ok {
promptCacheKey = strings.TrimSpace(v)
}
}
// Track if body needs re-serialization
bodyModified := false
originalModel := reqModel
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
if passthroughEnabled {
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
}
// 对所有请求执行模型映射(包含 Codex CLI
mappedModel := account.GetMappedModel(reqModel)
@@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
}
// Capture upstream request body for ops retry of this attempt.
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
}
setOpsUpstreamRequestBody(c, body)
// Send request
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
safeErr := sanitizeUpstreamErrorMessage(err.Error())
@@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough(
proxyURL = account.Proxy.URL()
}
setOpsUpstreamRequestBody(c, body)
if c != nil {
c.Set(OpsUpstreamRequestBodyKey, string(body))
c.Set("openai_passthrough", true)
}
upstreamStart := time.Now()
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
if err != nil {
safeErr := sanitizeUpstreamErrorMessage(err.Error())
setOpsUpstreamError(c, 0, safeErr, "")
@@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
for scanner.Scan() {
line := scanner.Text()
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data, ok := extractOpenAISSEDataLine(line); ok {
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
@@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
lastDataAt = time.Now()
// Extract data from SSE line (supports both "data: " and "data:" formats)
if openaiSSEDataRe.MatchString(line) {
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data, ok := extractOpenAISSEDataLine(line); ok {
// Replace model in response if needed
if needModelReplace {
@@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
}
// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。
// 兼容 `data: xxx` 与 `data:xxx` 两种格式。
func extractOpenAISSEDataLine(line string) (string, bool) {
if !strings.HasPrefix(line, "data:") {
return "", false
}
start := len("data:")
for start < len(line) {
if line[start] != ' ' && line[start] != ' ' {
break
}
start++
}
return line[start:], true
}
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
return line
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
return line
}
@@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
}
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
// Parse response.completed event for usage (OpenAI Responses format)
var event struct {
Type string `json:"type"`
Response struct {
Usage struct {
InputTokens int `json:"input_tokens"`
OutputTokens int `json:"output_tokens"`
InputTokenDetails struct {
CachedTokens int `json:"cached_tokens"`
} `json:"input_tokens_details"`
} `json:"usage"`
} `json:"response"`
if usage == nil || data == "" || data == "[DONE]" {
return
}
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
if !strings.Contains(data, `"response.completed"`) {
return
}
if gjson.Get(data, "type").String() != "response.completed" {
return
}
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
usage.InputTokens = event.Response.Usage.InputTokens
usage.OutputTokens = event.Response.Usage.OutputTokens
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
}
usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
@@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
func extractCodexFinalResponse(body string) ([]byte, bool) {
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
@@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
usage := &OpenAIUsage{}
lines := strings.Split(body, "\n")
for _, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
data, ok := extractOpenAISSEDataLine(line)
if !ok {
continue
}
data := openaiSSEDataRe.ReplaceAllString(line, "")
if data == "" || data == "[DONE]" {
continue
}
@@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
lines := strings.Split(body, "\n")
for i, line := range lines {
if !openaiSSEDataRe.MatchString(line) {
if _, ok := extractOpenAISSEDataLine(line); !ok {
continue
}
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
@@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
}
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
if len(body) == 0 {
return "", false, ""
}
model = strings.TrimSpace(gjson.GetBytes(body, "model").String())
stream = gjson.GetBytes(body, "stream").Bool()
promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
return model, stream, promptCacheKey
}
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if reasoningEffort == "" {
reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
}
if reasoningEffort != "" {
normalized := normalizeOpenAIReasoningEffort(reasoningEffort)
if normalized == "" {
return nil
}
return &normalized
}
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
if value == "" {
return nil
}
return &value
}
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
if c != nil {
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil {
return reqBody, nil
}
}
}
var reqBody map[string]any
if err := json.Unmarshal(body, &reqBody); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
}
return reqBody, nil
}
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
if value == "" {

View File

@@ -0,0 +1,125 @@
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestExtractOpenAIRequestMetaFromBody(t *testing.T) {
tests := []struct {
name string
body []byte
wantModel string
wantStream bool
wantPromptKey string
}{
{
name: "完整字段",
body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`),
wantModel: "gpt-5",
wantStream: true,
wantPromptKey: "ses-1",
},
{
name: "缺失可选字段",
body: []byte(`{"model":"gpt-4"}`),
wantModel: "gpt-4",
wantStream: false,
wantPromptKey: "",
},
{
name: "空请求体",
body: nil,
wantModel: "",
wantStream: false,
wantPromptKey: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body)
require.Equal(t, tt.wantModel, model)
require.Equal(t, tt.wantStream, stream)
require.Equal(t, tt.wantPromptKey, promptKey)
})
}
}
func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) {
tests := []struct {
name string
body []byte
model string
wantNil bool
wantValue string
}{
{
name: "优先读取 reasoning.effort",
body: []byte(`{"reasoning":{"effort":"medium"}}`),
model: "gpt-5-high",
wantNil: false,
wantValue: "medium",
},
{
name: "兼容 reasoning_effort",
body: []byte(`{"reasoning_effort":"x-high"}`),
model: "",
wantNil: false,
wantValue: "xhigh",
},
{
name: "minimal 归一化为空",
body: []byte(`{"reasoning":{"effort":"minimal"}}`),
model: "gpt-5-high",
wantNil: true,
},
{
name: "缺失字段时从模型后缀推导",
body: []byte(`{"input":"hi"}`),
model: "gpt-5-high",
wantNil: false,
wantValue: "high",
},
{
name: "未知后缀不返回",
body: []byte(`{"input":"hi"}`),
model: "gpt-5-unknown",
wantNil: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model)
if tt.wantNil {
require.Nil(t, got)
return
}
require.NotNil(t, got)
require.Equal(t, tt.wantValue, *got)
})
}
}
func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
cached := map[string]any{"model": "cached-model", "stream": true}
c.Set(OpenAIParsedRequestBodyKey, cached)
got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`))
require.NoError(t, err)
require.Equal(t, cached, got)
}
func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
_, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`))
require.Error(t, err)
require.Contains(t, err.Error(), "parse request")
}

View File

@@ -1416,3 +1416,109 @@ func TestReplaceModelInResponseBody(t *testing.T) {
})
}
}
func TestExtractOpenAISSEDataLine(t *testing.T) {
tests := []struct {
name string
line string
wantData string
wantOK bool
}{
{name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
{name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
{name: "纯空数据", line: `data: `, wantData: ``, wantOK: true},
{name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := extractOpenAISSEDataLine(tt.line)
require.Equal(t, tt.wantOK, ok)
require.Equal(t, tt.wantData, got)
})
}
}
func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
svc := &OpenAIGatewayService{}
usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7}
// 非 completed 事件,不应覆盖 usage
svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage)
require.Equal(t, 9, usage.InputTokens)
require.Equal(t, 8, usage.OutputTokens)
require.Equal(t, 7, usage.CacheReadInputTokens)
// completed 事件,应提取 usage
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage)
require.Equal(t, 3, usage.InputTokens)
require.Equal(t, 5, usage.OutputTokens)
require.Equal(t, 2, usage.CacheReadInputTokens)
}
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
body := strings.Join([]string{
`event: message`,
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
`data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`,
`data: [DONE]`,
}, "\n")
finalResp, ok := extractCodexFinalResponse(body)
require.True(t, ok)
require.Contains(t, string(finalResp), `"id":"resp_1"`)
require.Contains(t, string(finalResp), `"input_tokens":11`)
}
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
svc := &OpenAIGatewayService{cfg: &config.Config{}}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
}
body := []byte(strings.Join([]string{
`data: {"type":"response.in_progress","response":{"id":"resp_2"}}`,
`data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`,
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 7, usage.InputTokens)
require.Equal(t, 9, usage.OutputTokens)
require.Equal(t, 1, usage.CacheReadInputTokens)
// Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。
require.NotContains(t, rec.Body.String(), "event:")
require.Contains(t, rec.Body.String(), `"id":"resp_2"`)
require.NotContains(t, rec.Body.String(), "data:")
}
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
svc := &OpenAIGatewayService{cfg: &config.Config{}}
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
}
body := []byte(strings.Join([]string{
`data: {"type":"response.in_progress","response":{"id":"resp_3"}}`,
`data: [DONE]`,
}, "\n"))
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
require.NoError(t, err)
require.NotNil(t, usage)
require.Equal(t, 0, usage.InputTokens)
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
}

View File

@@ -4,16 +4,74 @@ import (
"context"
"errors"
"log/slog"
"math/rand/v2"
"strings"
"sync/atomic"
"time"
)
const (
openAITokenRefreshSkew = 3 * time.Minute
openAITokenCacheSkew = 5 * time.Minute
openAILockWaitTime = 200 * time.Millisecond
openAITokenRefreshSkew = 3 * time.Minute
openAITokenCacheSkew = 5 * time.Minute
openAILockInitialWait = 20 * time.Millisecond
openAILockMaxWait = 120 * time.Millisecond
openAILockMaxAttempts = 5
openAILockJitterRatio = 0.2
openAILockWarnThresholdMs = 250
)
// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。
type OpenAITokenRuntimeMetrics struct {
RefreshRequests int64
RefreshSuccess int64
RefreshFailure int64
LockAcquireFailure int64
LockContention int64
LockWaitSamples int64
LockWaitTotalMs int64
LockWaitHit int64
LockWaitMiss int64
LastObservedUnixMs int64
}
type openAITokenRuntimeMetricsStore struct {
refreshRequests atomic.Int64
refreshSuccess atomic.Int64
refreshFailure atomic.Int64
lockAcquireFailure atomic.Int64
lockContention atomic.Int64
lockWaitSamples atomic.Int64
lockWaitTotalMs atomic.Int64
lockWaitHit atomic.Int64
lockWaitMiss atomic.Int64
lastObservedUnixMs atomic.Int64
}
func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics {
if m == nil {
return OpenAITokenRuntimeMetrics{}
}
return OpenAITokenRuntimeMetrics{
RefreshRequests: m.refreshRequests.Load(),
RefreshSuccess: m.refreshSuccess.Load(),
RefreshFailure: m.refreshFailure.Load(),
LockAcquireFailure: m.lockAcquireFailure.Load(),
LockContention: m.lockContention.Load(),
LockWaitSamples: m.lockWaitSamples.Load(),
LockWaitTotalMs: m.lockWaitTotalMs.Load(),
LockWaitHit: m.lockWaitHit.Load(),
LockWaitMiss: m.lockWaitMiss.Load(),
LastObservedUnixMs: m.lastObservedUnixMs.Load(),
}
}
func (m *openAITokenRuntimeMetricsStore) touchNow() {
if m == nil {
return
}
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
}
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
type OpenAITokenCache = GeminiTokenCache
@@ -22,6 +80,7 @@ type OpenAITokenProvider struct {
accountRepo AccountRepository
tokenCache OpenAITokenCache
openAIOAuthService *OpenAIOAuthService
metrics *openAITokenRuntimeMetricsStore
}
func NewOpenAITokenProvider(
@@ -33,11 +92,27 @@ func NewOpenAITokenProvider(
accountRepo: accountRepo,
tokenCache: tokenCache,
openAIOAuthService: openAIOAuthService,
metrics: &openAITokenRuntimeMetricsStore{},
}
}
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
if p == nil {
return OpenAITokenRuntimeMetrics{}
}
p.ensureMetrics()
return p.metrics.snapshot()
}
func (p *OpenAITokenProvider) ensureMetrics() {
if p != nil && p.metrics == nil {
p.metrics = &openAITokenRuntimeMetricsStore{}
}
}
// GetAccessToken 获取有效的 access_token
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
p.ensureMetrics()
if account == nil {
return "", errors.New("account is nil")
}
@@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
p.metrics.refreshRequests.Add(1)
p.metrics.touchNow()
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
@@ -82,14 +159,17 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
p.metrics.refreshSuccess.Add(1)
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
@@ -106,6 +186,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
} else if lockErr != nil {
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
p.metrics.lockAcquireFailure.Add(1)
p.metrics.touchNow()
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
// 检查 ctx 是否已取消
@@ -126,13 +208,16 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
if p.openAIOAuthService == nil {
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else {
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
p.metrics.refreshFailure.Add(1)
refreshFailed = true
} else {
p.metrics.refreshSuccess.Add(1)
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
for k, v := range account.Credentials {
if _, exists := newCredentials[k]; !exists {
@@ -148,9 +233,14 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
} else {
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
time.Sleep(openAILockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
// 锁被其他 worker 持有:使用短轮询+jitter降低固定等待导致的尾延迟台阶。
p.metrics.lockContention.Add(1)
p.metrics.touchNow()
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
if waitErr != nil {
return "", waitErr
}
if strings.TrimSpace(token) != "" {
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
@@ -198,3 +288,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
wait := openAILockInitialWait
totalWaitMs := int64(0)
for i := 0; i < openAILockMaxAttempts; i++ {
actualWait := jitterLockWait(wait)
timer := time.NewTimer(actualWait)
select {
case <-ctx.Done():
if !timer.Stop() {
select {
case <-timer.C:
default:
}
}
return "", ctx.Err()
case <-timer.C:
}
waitMs := actualWait.Milliseconds()
if waitMs < 0 {
waitMs = 0
}
totalWaitMs += waitMs
p.metrics.lockWaitSamples.Add(1)
p.metrics.lockWaitTotalMs.Add(waitMs)
p.metrics.touchNow()
token, err := p.tokenCache.GetAccessToken(ctx, cacheKey)
if err == nil && strings.TrimSpace(token) != "" {
p.metrics.lockWaitHit.Add(1)
if totalWaitMs >= openAILockWarnThresholdMs {
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1)
}
return token, nil
}
if wait < openAILockMaxWait {
wait *= 2
if wait > openAILockMaxWait {
wait = openAILockMaxWait
}
}
}
p.metrics.lockWaitMiss.Add(1)
if totalWaitMs >= openAILockWarnThresholdMs {
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts)
}
return "", nil
}
func jitterLockWait(base time.Duration) time.Duration {
if base <= 0 {
return 0
}
minFactor := 1 - openAILockJitterRatio
maxFactor := 1 + openAILockJitterRatio
factor := minFactor + rand.Float64()*(maxFactor-minFactor)
return time.Duration(float64(base) * factor)
}

View File

@@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // 模拟锁被其他 worker 持有
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 207,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
cacheKey := OpenAITokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "winner-token", token)
}
func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false // 模拟锁被其他 worker 持有
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 208,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
provider := NewOpenAITokenProvider(nil, cache, nil)
start := time.Now()
token, err := provider.GetAccessToken(ctx, account)
require.Error(t, err)
require.ErrorIs(t, err, context.Canceled)
require.Empty(t, token)
require.Less(t, time.Since(start), 50*time.Millisecond)
}
func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockAcquired = false
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 209,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
cacheKey := OpenAITokenCacheKey(account)
go func() {
time.Sleep(10 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewOpenAITokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "winner-token", token)
metrics := provider.SnapshotRuntimeMetrics()
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
require.GreaterOrEqual(t, metrics.LockContention, int64(1))
require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1))
require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1))
require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0))
require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1))
}
func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
cache := newOpenAITokenCacheStub()
cache.lockErr = errors.New("redis lock error")
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 210,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewOpenAITokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
metrics := provider.SnapshotRuntimeMetrics()
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
}

View File

@@ -98,6 +98,10 @@ type OpsInsertErrorLogInput struct {
// It is set by OpsService.RecordError before persisting.
UpstreamErrorsJSON *string
AuthLatencyMs *int64
RoutingLatencyMs *int64
UpstreamLatencyMs *int64
ResponseLatencyMs *int64
TimeToFirstTokenMs *int64
RequestBodyJSON *string // sanitized json string (not raw bytes)

View File

@@ -20,8 +20,30 @@ const (
// retry the specific upstream attempt (not just the client request).
// This value is sanitized+trimmed before being persisted.
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
// Optional stage latencies (milliseconds) for troubleshooting and alerting.
OpsAuthLatencyMsKey = "ops_auth_latency_ms"
OpsRoutingLatencyMsKey = "ops_routing_latency_ms"
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
OpsResponseLatencyMsKey = "ops_response_latency_ms"
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
)
func setOpsUpstreamRequestBody(c *gin.Context, body []byte) {
if c == nil || len(body) == 0 {
return
}
// 热路径避免 string(body) 额外分配,按需在落库前再转换。
c.Set(OpsUpstreamRequestBodyKey, body)
}
func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
if c == nil || strings.TrimSpace(key) == "" || value < 0 {
return
}
c.Set(key, value)
}
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
if c == nil {
return
@@ -91,8 +113,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
// stored it on the context, attach it so ops can retry this specific attempt.
if ev.UpstreamRequestBody == "" {
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
if s, ok := v.(string); ok {
ev.UpstreamRequestBody = strings.TrimSpace(s)
switch raw := v.(type) {
case string:
ev.UpstreamRequestBody = strings.TrimSpace(raw)
case []byte:
ev.UpstreamRequestBody = strings.TrimSpace(string(raw))
}
}
}

View File

@@ -0,0 +1,47 @@
package service
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`))
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Kind: "http_error",
Message: "upstream failed",
})
v, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := v.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody)
}
func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`)
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
Kind: "request_error",
Message: "dial timeout",
})
v, ok := c.Get(OpsUpstreamErrorsKey)
require.True(t, ok)
events, ok := v.([]*OpsUpstreamErrorEvent)
require.True(t, ok)
require.Len(t, events, 1)
require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody)
}