feat(openai): 极致优化 OAuth 链路并补齐性能守护
- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝\n- 优化并发与 token 竞争路径并补齐运行指标\n- 补充 OpenAI/Ops 相关单元测试与回归用例\n- 新增灰度阈值守护与压测脚本,支撑发布验收
This commit is contained in:
@@ -104,31 +104,24 @@ func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFo
|
|||||||
|
|
||||||
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
|
// wrapReleaseOnDone ensures release runs at most once and still triggers on context cancellation.
|
||||||
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
|
// 用于避免客户端断开或上游超时导致的并发槽位泄漏。
|
||||||
// 修复:添加 quit channel 确保 goroutine 及时退出,避免泄露
|
// 优化:基于 context.AfterFunc 注册回调,避免每请求额外守护 goroutine。
|
||||||
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
|
func wrapReleaseOnDone(ctx context.Context, releaseFunc func()) func() {
|
||||||
if releaseFunc == nil {
|
if releaseFunc == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
var once sync.Once
|
var once sync.Once
|
||||||
quit := make(chan struct{})
|
var stop func() bool
|
||||||
|
|
||||||
release := func() {
|
release := func() {
|
||||||
once.Do(func() {
|
once.Do(func() {
|
||||||
|
if stop != nil {
|
||||||
|
_ = stop()
|
||||||
|
}
|
||||||
releaseFunc()
|
releaseFunc()
|
||||||
close(quit) // 通知监听 goroutine 退出
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
stop = context.AfterFunc(ctx, release)
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
// Context 取消时释放资源
|
|
||||||
release()
|
|
||||||
case <-quit:
|
|
||||||
// 正常释放已完成,goroutine 退出
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
return release
|
return release
|
||||||
}
|
}
|
||||||
@@ -153,6 +146,32 @@ func (h *ConcurrencyHelper) DecrementAccountWaitCount(ctx context.Context, accou
|
|||||||
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
h.concurrencyService.DecrementAccountWaitCount(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TryAcquireUserSlot 尝试立即获取用户并发槽位。
|
||||||
|
// 返回值: (releaseFunc, acquired, error)
|
||||||
|
func (h *ConcurrencyHelper) TryAcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (func(), bool, error) {
|
||||||
|
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if !result.Acquired {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return result.ReleaseFunc, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryAcquireAccountSlot 尝试立即获取账号并发槽位。
|
||||||
|
// 返回值: (releaseFunc, acquired, error)
|
||||||
|
func (h *ConcurrencyHelper) TryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (func(), bool, error) {
|
||||||
|
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
if !result.Acquired {
|
||||||
|
return nil, false, nil
|
||||||
|
}
|
||||||
|
return result.ReleaseFunc, true, nil
|
||||||
|
}
|
||||||
|
|
||||||
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
||||||
// For streaming requests, sends ping events during the wait.
|
// For streaming requests, sends ping events during the wait.
|
||||||
// streamStarted is updated if streaming response has begun.
|
// streamStarted is updated if streaming response has begun.
|
||||||
@@ -160,13 +179,13 @@ func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, userID int64
|
|||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Try to acquire immediately
|
// Try to acquire immediately
|
||||||
result, err := h.concurrencyService.AcquireUserSlot(ctx, userID, maxConcurrency)
|
releaseFunc, acquired, err := h.TryAcquireUserSlot(ctx, userID, maxConcurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Acquired {
|
if acquired {
|
||||||
return result.ReleaseFunc, nil
|
return releaseFunc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need to wait - handle streaming ping if needed
|
// Need to wait - handle streaming ping if needed
|
||||||
@@ -180,13 +199,13 @@ func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, accountID
|
|||||||
ctx := c.Request.Context()
|
ctx := c.Request.Context()
|
||||||
|
|
||||||
// Try to acquire immediately
|
// Try to acquire immediately
|
||||||
result, err := h.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
|
releaseFunc, acquired, err := h.TryAcquireAccountSlot(ctx, accountID, maxConcurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if result.Acquired {
|
if acquired {
|
||||||
return result.ReleaseFunc, nil
|
return releaseFunc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Need to wait - handle streaming ping if needed
|
// Need to wait - handle streaming ping if needed
|
||||||
|
|||||||
114
backend/internal/handler/gateway_helper_fastpath_test.go
Normal file
114
backend/internal/handler/gateway_helper_fastpath_test.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
type concurrencyCacheMock struct {
|
||||||
|
acquireUserSlotFn func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||||
|
acquireAccountSlotFn func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||||
|
releaseUserCalled int32
|
||||||
|
releaseAccountCalled int32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
if m.acquireAccountSlotFn != nil {
|
||||||
|
return m.acquireAccountSlotFn(ctx, accountID, maxConcurrency, requestID)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error {
|
||||||
|
atomic.AddInt32(&m.releaseAccountCalled, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) GetAccountConcurrency(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) DecrementAccountWaitCount(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
if m.acquireUserSlotFn != nil {
|
||||||
|
return m.acquireUserSlotFn(ctx, userID, maxConcurrency, requestID)
|
||||||
|
}
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error {
|
||||||
|
atomic.AddInt32(&m.releaseUserCalled, 1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) GetUserConcurrency(ctx context.Context, userID int64) (int, error) {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) DecrementWaitCount(ctx context.Context, userID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) GetAccountsLoadBatch(ctx context.Context, accounts []service.AccountWithConcurrency) (map[int64]*service.AccountLoadInfo, error) {
|
||||||
|
return map[int64]*service.AccountLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) GetUsersLoadBatch(ctx context.Context, users []service.UserWithConcurrency) (map[int64]*service.UserLoadInfo, error) {
|
||||||
|
return map[int64]*service.UserLoadInfo{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *concurrencyCacheMock) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrencyHelper_TryAcquireUserSlot(t *testing.T) {
|
||||||
|
cache := &concurrencyCacheMock{
|
||||||
|
acquireUserSlotFn: func(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return true, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
|
||||||
|
|
||||||
|
release, acquired, err := helper.TryAcquireUserSlot(context.Background(), 101, 2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, acquired)
|
||||||
|
require.NotNil(t, release)
|
||||||
|
|
||||||
|
release()
|
||||||
|
require.Equal(t, int32(1), atomic.LoadInt32(&cache.releaseUserCalled))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestConcurrencyHelper_TryAcquireAccountSlot_NotAcquired(t *testing.T) {
|
||||||
|
cache := &concurrencyCacheMock{
|
||||||
|
acquireAccountSlotFn: func(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error) {
|
||||||
|
return false, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
helper := NewConcurrencyHelper(service.NewConcurrencyService(cache), SSEPingFormatNone, time.Second)
|
||||||
|
|
||||||
|
release, acquired, err := helper.TryAcquireAccountSlot(context.Background(), 201, 1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, acquired)
|
||||||
|
require.Nil(t, release)
|
||||||
|
require.Equal(t, int32(0), atomic.LoadInt32(&cache.releaseAccountCalled))
|
||||||
|
}
|
||||||
@@ -64,6 +64,8 @@ func NewOpenAIGatewayHandler(
|
|||||||
// Responses handles OpenAI Responses API endpoint
|
// Responses handles OpenAI Responses API endpoint
|
||||||
// POST /openai/v1/responses
|
// POST /openai/v1/responses
|
||||||
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
||||||
|
requestStart := time.Now()
|
||||||
|
|
||||||
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
// Get apiKey and user from context (set by ApiKeyAuth middleware)
|
||||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -141,6 +143,7 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
if gjson.GetBytes(body, `input.#(type=="function_call_output")`).Exists() {
|
||||||
var reqBody map[string]any
|
var reqBody map[string]any
|
||||||
if err := json.Unmarshal(body, &reqBody); err == nil {
|
if err := json.Unmarshal(body, &reqBody); err == nil {
|
||||||
|
c.Set(service.OpenAIParsedRequestBodyKey, reqBody)
|
||||||
if service.HasFunctionCallOutput(reqBody) {
|
if service.HasFunctionCallOutput(reqBody) {
|
||||||
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
previousResponseID, _ := reqBody["previous_response_id"].(string)
|
||||||
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
if strings.TrimSpace(previousResponseID) == "" && !service.HasToolCallContext(reqBody) {
|
||||||
@@ -171,18 +174,30 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
// Get subscription info (may be nil)
|
// Get subscription info (may be nil)
|
||||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||||
|
|
||||||
// 0. Check if wait queue is full
|
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||||
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
routingStart := time.Now()
|
||||||
canWait, err := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
|
||||||
waitCounted := false
|
// 0. 先尝试直接抢占用户槽位(快速路径)
|
||||||
|
userReleaseFunc, userAcquired, err := h.concurrencyHelper.TryAcquireUserSlot(c.Request.Context(), subject.UserID, subject.Concurrency)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("Increment wait count failed: %v", err)
|
log.Printf("User concurrency acquire failed: %v", err)
|
||||||
// On error, allow request to proceed
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
waitCounted := false
|
||||||
|
if !userAcquired {
|
||||||
|
// 仅在抢槽失败时才进入等待队列,减少常态请求 Redis 写入。
|
||||||
|
maxWait := service.CalculateMaxWait(subject.Concurrency)
|
||||||
|
canWait, waitErr := h.concurrencyHelper.IncrementWaitCount(c.Request.Context(), subject.UserID, maxWait)
|
||||||
|
if waitErr != nil {
|
||||||
|
log.Printf("Increment wait count failed: %v", waitErr)
|
||||||
|
// 按现有降级语义:等待计数异常时放行后续抢槽流程
|
||||||
} else if !canWait {
|
} else if !canWait {
|
||||||
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
h.errorResponse(c, http.StatusTooManyRequests, "rate_limit_error", "Too many pending requests, please retry later")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err == nil && canWait {
|
if waitErr == nil && canWait {
|
||||||
waitCounted = true
|
waitCounted = true
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -191,14 +206,15 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// 1. First acquire user concurrency slot
|
userReleaseFunc, err = h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
||||||
userReleaseFunc, err := h.concurrencyHelper.AcquireUserSlotWithWait(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Printf("User concurrency acquire failed: %v", err)
|
log.Printf("User concurrency acquire failed: %v", err)
|
||||||
h.handleConcurrencyError(c, err, "user", streamStarted)
|
h.handleConcurrencyError(c, err, "user", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// User slot acquired: no longer waiting.
|
}
|
||||||
|
|
||||||
|
// 用户槽位已获取:退出等待队列计数。
|
||||||
if waitCounted {
|
if waitCounted {
|
||||||
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
h.concurrencyHelper.DecrementWaitCount(c.Request.Context(), subject.UserID)
|
||||||
waitCounted = false
|
waitCounted = false
|
||||||
@@ -253,6 +269,24 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 先快速尝试一次账号槽位,命中则跳过等待计数写入。
|
||||||
|
fastReleaseFunc, fastAcquired, err := h.concurrencyHelper.TryAcquireAccountSlot(
|
||||||
|
c.Request.Context(),
|
||||||
|
account.ID,
|
||||||
|
selection.WaitPlan.MaxConcurrency,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Account concurrency quick acquire failed: %v", err)
|
||||||
|
h.handleConcurrencyError(c, err, "account", streamStarted)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if fastAcquired {
|
||||||
|
accountReleaseFunc = fastReleaseFunc
|
||||||
|
if err := h.gatewayService.BindStickySession(c.Request.Context(), apiKey.GroupID, sessionHash, account.ID); err != nil {
|
||||||
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
accountWaitCounted := false
|
accountWaitCounted := false
|
||||||
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
canWait, err := h.concurrencyHelper.IncrementAccountWaitCount(c.Request.Context(), account.ID, selection.WaitPlan.MaxWaiting)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -292,14 +326,27 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
log.Printf("Bind sticky session failed: %v", err)
|
log.Printf("Bind sticky session failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
// 账号槽位/等待计数需要在超时或断开时安全回收
|
// 账号槽位/等待计数需要在超时或断开时安全回收
|
||||||
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
accountReleaseFunc = wrapReleaseOnDone(c.Request.Context(), accountReleaseFunc)
|
||||||
|
|
||||||
// Forward request
|
// Forward request
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||||
|
forwardStart := time.Now()
|
||||||
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
result, err := h.gatewayService.Forward(c.Request.Context(), c, account, body)
|
||||||
|
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||||
if accountReleaseFunc != nil {
|
if accountReleaseFunc != nil {
|
||||||
accountReleaseFunc()
|
accountReleaseFunc()
|
||||||
}
|
}
|
||||||
|
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||||
|
responseLatencyMs := forwardDurationMs
|
||||||
|
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||||
|
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||||
|
}
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||||
|
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||||
|
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var failoverErr *service.UpstreamFailoverError
|
var failoverErr *service.UpstreamFailoverError
|
||||||
if errors.As(err, &failoverErr) {
|
if errors.As(err, &failoverErr) {
|
||||||
@@ -343,6 +390,28 @@ func (h *OpenAIGatewayHandler) Responses(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getContextInt64(c *gin.Context, key string) (int64, bool) {
|
||||||
|
if c == nil || key == "" {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
v, ok := c.Get(key)
|
||||||
|
if !ok {
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int64:
|
||||||
|
return t, true
|
||||||
|
case int:
|
||||||
|
return int64(t), true
|
||||||
|
case int32:
|
||||||
|
return int64(t), true
|
||||||
|
case float64:
|
||||||
|
return int64(t), true
|
||||||
|
default:
|
||||||
|
return 0, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
// handleConcurrencyError handles concurrency-related errors with proper 429 response
|
||||||
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
func (h *OpenAIGatewayHandler) handleConcurrencyError(c *gin.Context, err error, slotType string, streamStarted bool) {
|
||||||
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
h.handleStreamingAwareError(c, http.StatusTooManyRequests, "rate_limit_error",
|
||||||
|
|||||||
@@ -507,6 +507,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
RetryCount: 0,
|
RetryCount: 0,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
applyOpsLatencyFieldsFromContext(c, entry)
|
||||||
|
|
||||||
if apiKey != nil {
|
if apiKey != nil {
|
||||||
entry.APIKeyID = &apiKey.ID
|
entry.APIKeyID = &apiKey.ID
|
||||||
@@ -618,6 +619,7 @@ func OpsErrorLoggerMiddleware(ops *service.OpsService) gin.HandlerFunc {
|
|||||||
RetryCount: 0,
|
RetryCount: 0,
|
||||||
CreatedAt: time.Now(),
|
CreatedAt: time.Now(),
|
||||||
}
|
}
|
||||||
|
applyOpsLatencyFieldsFromContext(c, entry)
|
||||||
|
|
||||||
// Capture upstream error context set by gateway services (if present).
|
// Capture upstream error context set by gateway services (if present).
|
||||||
// This does NOT affect the client response; it enriches Ops troubleshooting data.
|
// This does NOT affect the client response; it enriches Ops troubleshooting data.
|
||||||
@@ -746,6 +748,44 @@ func extractOpsRetryRequestHeaders(c *gin.Context) *string {
|
|||||||
return &s
|
return &s
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func applyOpsLatencyFieldsFromContext(c *gin.Context, entry *service.OpsInsertErrorLogInput) {
|
||||||
|
if c == nil || entry == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry.AuthLatencyMs = getContextLatencyMs(c, service.OpsAuthLatencyMsKey)
|
||||||
|
entry.RoutingLatencyMs = getContextLatencyMs(c, service.OpsRoutingLatencyMsKey)
|
||||||
|
entry.UpstreamLatencyMs = getContextLatencyMs(c, service.OpsUpstreamLatencyMsKey)
|
||||||
|
entry.ResponseLatencyMs = getContextLatencyMs(c, service.OpsResponseLatencyMsKey)
|
||||||
|
entry.TimeToFirstTokenMs = getContextLatencyMs(c, service.OpsTimeToFirstTokenMsKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getContextLatencyMs(c *gin.Context, key string) *int64 {
|
||||||
|
if c == nil || strings.TrimSpace(key) == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
v, ok := c.Get(key)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
var ms int64
|
||||||
|
switch t := v.(type) {
|
||||||
|
case int:
|
||||||
|
ms = int64(t)
|
||||||
|
case int32:
|
||||||
|
ms = int64(t)
|
||||||
|
case int64:
|
||||||
|
ms = t
|
||||||
|
case float64:
|
||||||
|
ms = int64(t)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if ms < 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &ms
|
||||||
|
}
|
||||||
|
|
||||||
type parsedOpsError struct {
|
type parsedOpsError struct {
|
||||||
ErrorType string
|
ErrorType string
|
||||||
Message string
|
Message string
|
||||||
|
|||||||
@@ -55,6 +55,10 @@ INSERT INTO ops_error_logs (
|
|||||||
upstream_error_message,
|
upstream_error_message,
|
||||||
upstream_error_detail,
|
upstream_error_detail,
|
||||||
upstream_errors,
|
upstream_errors,
|
||||||
|
auth_latency_ms,
|
||||||
|
routing_latency_ms,
|
||||||
|
upstream_latency_ms,
|
||||||
|
response_latency_ms,
|
||||||
time_to_first_token_ms,
|
time_to_first_token_ms,
|
||||||
request_body,
|
request_body,
|
||||||
request_body_truncated,
|
request_body_truncated,
|
||||||
@@ -64,7 +68,7 @@ INSERT INTO ops_error_logs (
|
|||||||
retry_count,
|
retry_count,
|
||||||
created_at
|
created_at
|
||||||
) VALUES (
|
) VALUES (
|
||||||
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34
|
$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29,$30,$31,$32,$33,$34,$35,$36,$37,$38
|
||||||
) RETURNING id`
|
) RETURNING id`
|
||||||
|
|
||||||
var id int64
|
var id int64
|
||||||
@@ -97,6 +101,10 @@ INSERT INTO ops_error_logs (
|
|||||||
opsNullString(input.UpstreamErrorMessage),
|
opsNullString(input.UpstreamErrorMessage),
|
||||||
opsNullString(input.UpstreamErrorDetail),
|
opsNullString(input.UpstreamErrorDetail),
|
||||||
opsNullString(input.UpstreamErrorsJSON),
|
opsNullString(input.UpstreamErrorsJSON),
|
||||||
|
opsNullInt64(input.AuthLatencyMs),
|
||||||
|
opsNullInt64(input.RoutingLatencyMs),
|
||||||
|
opsNullInt64(input.UpstreamLatencyMs),
|
||||||
|
opsNullInt64(input.ResponseLatencyMs),
|
||||||
opsNullInt64(input.TimeToFirstTokenMs),
|
opsNullInt64(input.TimeToFirstTokenMs),
|
||||||
opsNullString(input.RequestBodyJSON),
|
opsNullString(input.RequestBodyJSON),
|
||||||
input.RequestBodyTruncated,
|
input.RequestBodyTruncated,
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"log"
|
"log"
|
||||||
"net/http"
|
"net/http"
|
||||||
"regexp"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -34,11 +33,10 @@ const (
|
|||||||
// OpenAI Platform API for API Key accounts (fallback)
|
// OpenAI Platform API for API Key accounts (fallback)
|
||||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||||
)
|
|
||||||
|
|
||||||
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
|
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
|
||||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
|
||||||
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
)
|
||||||
|
|
||||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||||
var openaiAllowedHeaders = map[string]bool{
|
var openaiAllowedHeaders = map[string]bool{
|
||||||
@@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
startTime := time.Now()
|
startTime := time.Now()
|
||||||
|
|
||||||
originalBody := body
|
originalBody := body
|
||||||
|
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||||
|
originalModel := reqModel
|
||||||
|
|
||||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||||
var reqBody map[string]any
|
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
if passthroughEnabled {
|
||||||
return nil, fmt.Errorf("parse request: %w", err)
|
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
|
||||||
|
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
|
||||||
|
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract model and stream from parsed body
|
reqBody, err := getOpenAIRequestBodyMap(c, body)
|
||||||
reqModel, _ := reqBody["model"].(string)
|
if err != nil {
|
||||||
reqStream, _ := reqBody["stream"].(bool)
|
return nil, err
|
||||||
promptCacheKey := ""
|
}
|
||||||
|
|
||||||
|
if v, ok := reqBody["model"].(string); ok {
|
||||||
|
reqModel = v
|
||||||
|
originalModel = reqModel
|
||||||
|
}
|
||||||
|
if v, ok := reqBody["stream"].(bool); ok {
|
||||||
|
reqStream = v
|
||||||
|
}
|
||||||
|
if promptCacheKey == "" {
|
||||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||||
promptCacheKey = strings.TrimSpace(v)
|
promptCacheKey = strings.TrimSpace(v)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Track if body needs re-serialization
|
// Track if body needs re-serialization
|
||||||
bodyModified := false
|
bodyModified := false
|
||||||
originalModel := reqModel
|
|
||||||
|
|
||||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
|
||||||
|
|
||||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
|
||||||
if passthroughEnabled {
|
|
||||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
|
|
||||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
|
||||||
}
|
|
||||||
|
|
||||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||||
mappedModel := account.GetMappedModel(reqModel)
|
mappedModel := account.GetMappedModel(reqModel)
|
||||||
@@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Capture upstream request body for ops retry of this attempt.
|
// Capture upstream request body for ops retry of this attempt.
|
||||||
if c != nil {
|
setOpsUpstreamRequestBody(c, body)
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send request
|
// Send request
|
||||||
|
upstreamStart := time.Now()
|
||||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
@@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough(
|
|||||||
proxyURL = account.Proxy.URL()
|
proxyURL = account.Proxy.URL()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
setOpsUpstreamRequestBody(c, body)
|
||||||
if c != nil {
|
if c != nil {
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
||||||
c.Set("openai_passthrough", true)
|
c.Set("openai_passthrough", true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
upstreamStart := time.Now()
|
||||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||||
|
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||||
setOpsUpstreamError(c, 0, safeErr, "")
|
setOpsUpstreamError(c, 0, safeErr, "")
|
||||||
@@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
|||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
line := scanner.Text()
|
||||||
if openaiSSEDataRe.MatchString(line) {
|
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
||||||
ms := int(time.Since(startTime).Milliseconds())
|
ms := int(time.Since(startTime).Milliseconds())
|
||||||
firstTokenMs = &ms
|
firstTokenMs = &ms
|
||||||
@@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
lastDataAt = time.Now()
|
lastDataAt = time.Now()
|
||||||
|
|
||||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||||
if openaiSSEDataRe.MatchString(line) {
|
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
|
|
||||||
// Replace model in response if needed
|
// Replace model in response if needed
|
||||||
if needModelReplace {
|
if needModelReplace {
|
||||||
@@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// extractOpenAISSEDataLine 低开销提取 SSE `data:` 行内容。
|
||||||
|
// 兼容 `data: xxx` 与 `data:xxx` 两种格式。
|
||||||
|
func extractOpenAISSEDataLine(line string) (string, bool) {
|
||||||
|
if !strings.HasPrefix(line, "data:") {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
start := len("data:")
|
||||||
|
for start < len(line) {
|
||||||
|
if line[start] != ' ' && line[start] != ' ' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
start++
|
||||||
|
}
|
||||||
|
return line[start:], true
|
||||||
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||||
if !openaiSSEDataRe.MatchString(line) {
|
data, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
if data == "" || data == "[DONE]" {
|
if data == "" || data == "[DONE]" {
|
||||||
return line
|
return line
|
||||||
}
|
}
|
||||||
@@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||||
// Parse response.completed event for usage (OpenAI Responses format)
|
if usage == nil || data == "" || data == "[DONE]" {
|
||||||
var event struct {
|
return
|
||||||
Type string `json:"type"`
|
}
|
||||||
Response struct {
|
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||||
Usage struct {
|
if !strings.Contains(data, `"response.completed"`) {
|
||||||
InputTokens int `json:"input_tokens"`
|
return
|
||||||
OutputTokens int `json:"output_tokens"`
|
}
|
||||||
InputTokenDetails struct {
|
if gjson.Get(data, "type").String() != "response.completed" {
|
||||||
CachedTokens int `json:"cached_tokens"`
|
return
|
||||||
} `json:"input_tokens_details"`
|
|
||||||
} `json:"usage"`
|
|
||||||
} `json:"response"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
|
||||||
usage.InputTokens = event.Response.Usage.InputTokens
|
usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
|
||||||
usage.OutputTokens = event.Response.Usage.OutputTokens
|
usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
||||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||||
@@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
|||||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||||
lines := strings.Split(body, "\n")
|
lines := strings.Split(body, "\n")
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if !openaiSSEDataRe.MatchString(line) {
|
data, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
if data == "" || data == "[DONE]" {
|
if data == "" || data == "[DONE]" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
|||||||
usage := &OpenAIUsage{}
|
usage := &OpenAIUsage{}
|
||||||
lines := strings.Split(body, "\n")
|
lines := strings.Split(body, "\n")
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if !openaiSSEDataRe.MatchString(line) {
|
data, ok := extractOpenAISSEDataLine(line)
|
||||||
|
if !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
|
||||||
if data == "" || data == "[DONE]" {
|
if data == "" || data == "[DONE]" {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
|||||||
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
||||||
lines := strings.Split(body, "\n")
|
lines := strings.Split(body, "\n")
|
||||||
for i, line := range lines {
|
for i, line := range lines {
|
||||||
if !openaiSSEDataRe.MatchString(line) {
|
if _, ok := extractOpenAISSEDataLine(line); !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
||||||
@@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
|
|||||||
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
|
||||||
|
if len(body) == 0 {
|
||||||
|
return "", false, ""
|
||||||
|
}
|
||||||
|
|
||||||
|
model = strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||||
|
stream = gjson.GetBytes(body, "stream").Bool()
|
||||||
|
promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||||
|
return model, stream, promptCacheKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||||
|
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||||
|
if reasoningEffort == "" {
|
||||||
|
reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
|
||||||
|
}
|
||||||
|
if reasoningEffort != "" {
|
||||||
|
normalized := normalizeOpenAIReasoningEffort(reasoningEffort)
|
||||||
|
if normalized == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &normalized
|
||||||
|
}
|
||||||
|
|
||||||
|
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
||||||
|
if value == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &value
|
||||||
|
}
|
||||||
|
|
||||||
|
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
|
||||||
|
if c != nil {
|
||||||
|
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
|
||||||
|
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil {
|
||||||
|
return reqBody, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var reqBody map[string]any
|
||||||
|
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse request: %w", err)
|
||||||
|
}
|
||||||
|
return reqBody, nil
|
||||||
|
}
|
||||||
|
|
||||||
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
||||||
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
||||||
if value == "" {
|
if value == "" {
|
||||||
|
|||||||
125
backend/internal/service/openai_gateway_service_hotpath_test.go
Normal file
125
backend/internal/service/openai_gateway_service_hotpath_test.go
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractOpenAIRequestMetaFromBody(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
wantModel string
|
||||||
|
wantStream bool
|
||||||
|
wantPromptKey string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "完整字段",
|
||||||
|
body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`),
|
||||||
|
wantModel: "gpt-5",
|
||||||
|
wantStream: true,
|
||||||
|
wantPromptKey: "ses-1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺失可选字段",
|
||||||
|
body: []byte(`{"model":"gpt-4"}`),
|
||||||
|
wantModel: "gpt-4",
|
||||||
|
wantStream: false,
|
||||||
|
wantPromptKey: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "空请求体",
|
||||||
|
body: nil,
|
||||||
|
wantModel: "",
|
||||||
|
wantStream: false,
|
||||||
|
wantPromptKey: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body)
|
||||||
|
require.Equal(t, tt.wantModel, model)
|
||||||
|
require.Equal(t, tt.wantStream, stream)
|
||||||
|
require.Equal(t, tt.wantPromptKey, promptKey)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
model string
|
||||||
|
wantNil bool
|
||||||
|
wantValue string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "优先读取 reasoning.effort",
|
||||||
|
body: []byte(`{"reasoning":{"effort":"medium"}}`),
|
||||||
|
model: "gpt-5-high",
|
||||||
|
wantNil: false,
|
||||||
|
wantValue: "medium",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "兼容 reasoning_effort",
|
||||||
|
body: []byte(`{"reasoning_effort":"x-high"}`),
|
||||||
|
model: "",
|
||||||
|
wantNil: false,
|
||||||
|
wantValue: "xhigh",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "minimal 归一化为空",
|
||||||
|
body: []byte(`{"reasoning":{"effort":"minimal"}}`),
|
||||||
|
model: "gpt-5-high",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "缺失字段时从模型后缀推导",
|
||||||
|
body: []byte(`{"input":"hi"}`),
|
||||||
|
model: "gpt-5-high",
|
||||||
|
wantNil: false,
|
||||||
|
wantValue: "high",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "未知后缀不返回",
|
||||||
|
body: []byte(`{"input":"hi"}`),
|
||||||
|
model: "gpt-5-unknown",
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model)
|
||||||
|
if tt.wantNil {
|
||||||
|
require.Nil(t, got)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NotNil(t, got)
|
||||||
|
require.Equal(t, tt.wantValue, *got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
cached := map[string]any{"model": "cached-model", "stream": true}
|
||||||
|
c.Set(OpenAIParsedRequestBodyKey, cached)
|
||||||
|
|
||||||
|
got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, cached, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
|
||||||
|
_, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`))
|
||||||
|
require.Error(t, err)
|
||||||
|
require.Contains(t, err.Error(), "parse request")
|
||||||
|
}
|
||||||
@@ -1416,3 +1416,109 @@ func TestReplaceModelInResponseBody(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtractOpenAISSEDataLine(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
line string
|
||||||
|
wantData string
|
||||||
|
wantOK bool
|
||||||
|
}{
|
||||||
|
{name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
|
||||||
|
{name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
|
||||||
|
{name: "纯空数据", line: `data: `, wantData: ``, wantOK: true},
|
||||||
|
{name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, ok := extractOpenAISSEDataLine(tt.line)
|
||||||
|
require.Equal(t, tt.wantOK, ok)
|
||||||
|
require.Equal(t, tt.wantData, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
||||||
|
svc := &OpenAIGatewayService{}
|
||||||
|
usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7}
|
||||||
|
|
||||||
|
// 非 completed 事件,不应覆盖 usage
|
||||||
|
svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage)
|
||||||
|
require.Equal(t, 9, usage.InputTokens)
|
||||||
|
require.Equal(t, 8, usage.OutputTokens)
|
||||||
|
require.Equal(t, 7, usage.CacheReadInputTokens)
|
||||||
|
|
||||||
|
// completed 事件,应提取 usage
|
||||||
|
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage)
|
||||||
|
require.Equal(t, 3, usage.InputTokens)
|
||||||
|
require.Equal(t, 5, usage.OutputTokens)
|
||||||
|
require.Equal(t, 2, usage.CacheReadInputTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
|
||||||
|
body := strings.Join([]string{
|
||||||
|
`event: message`,
|
||||||
|
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}, "\n")
|
||||||
|
|
||||||
|
finalResp, ok := extractCodexFinalResponse(body)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Contains(t, string(finalResp), `"id":"resp_1"`)
|
||||||
|
require.Contains(t, string(finalResp), `"input_tokens":11`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
}
|
||||||
|
body := []byte(strings.Join([]string{
|
||||||
|
`data: {"type":"response.in_progress","response":{"id":"resp_2"}}`,
|
||||||
|
`data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}, "\n"))
|
||||||
|
|
||||||
|
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 7, usage.InputTokens)
|
||||||
|
require.Equal(t, 9, usage.OutputTokens)
|
||||||
|
require.Equal(t, 1, usage.CacheReadInputTokens)
|
||||||
|
// Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。
|
||||||
|
require.NotContains(t, rec.Body.String(), "event:")
|
||||||
|
require.Contains(t, rec.Body.String(), `"id":"resp_2"`)
|
||||||
|
require.NotContains(t, rec.Body.String(), "data:")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
|
||||||
|
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||||
|
resp := &http.Response{
|
||||||
|
StatusCode: http.StatusOK,
|
||||||
|
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||||
|
}
|
||||||
|
body := []byte(strings.Join([]string{
|
||||||
|
`data: {"type":"response.in_progress","response":{"id":"resp_3"}}`,
|
||||||
|
`data: [DONE]`,
|
||||||
|
}, "\n"))
|
||||||
|
|
||||||
|
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, usage)
|
||||||
|
require.Equal(t, 0, usage.InputTokens)
|
||||||
|
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||||
|
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,16 +4,74 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math/rand/v2"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
openAITokenRefreshSkew = 3 * time.Minute
|
openAITokenRefreshSkew = 3 * time.Minute
|
||||||
openAITokenCacheSkew = 5 * time.Minute
|
openAITokenCacheSkew = 5 * time.Minute
|
||||||
openAILockWaitTime = 200 * time.Millisecond
|
openAILockInitialWait = 20 * time.Millisecond
|
||||||
|
openAILockMaxWait = 120 * time.Millisecond
|
||||||
|
openAILockMaxAttempts = 5
|
||||||
|
openAILockJitterRatio = 0.2
|
||||||
|
openAILockWarnThresholdMs = 250
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// OpenAITokenRuntimeMetrics 表示 OpenAI token 刷新与锁竞争保护指标快照。
|
||||||
|
type OpenAITokenRuntimeMetrics struct {
|
||||||
|
RefreshRequests int64
|
||||||
|
RefreshSuccess int64
|
||||||
|
RefreshFailure int64
|
||||||
|
LockAcquireFailure int64
|
||||||
|
LockContention int64
|
||||||
|
LockWaitSamples int64
|
||||||
|
LockWaitTotalMs int64
|
||||||
|
LockWaitHit int64
|
||||||
|
LockWaitMiss int64
|
||||||
|
LastObservedUnixMs int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type openAITokenRuntimeMetricsStore struct {
|
||||||
|
refreshRequests atomic.Int64
|
||||||
|
refreshSuccess atomic.Int64
|
||||||
|
refreshFailure atomic.Int64
|
||||||
|
lockAcquireFailure atomic.Int64
|
||||||
|
lockContention atomic.Int64
|
||||||
|
lockWaitSamples atomic.Int64
|
||||||
|
lockWaitTotalMs atomic.Int64
|
||||||
|
lockWaitHit atomic.Int64
|
||||||
|
lockWaitMiss atomic.Int64
|
||||||
|
lastObservedUnixMs atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics {
|
||||||
|
if m == nil {
|
||||||
|
return OpenAITokenRuntimeMetrics{}
|
||||||
|
}
|
||||||
|
return OpenAITokenRuntimeMetrics{
|
||||||
|
RefreshRequests: m.refreshRequests.Load(),
|
||||||
|
RefreshSuccess: m.refreshSuccess.Load(),
|
||||||
|
RefreshFailure: m.refreshFailure.Load(),
|
||||||
|
LockAcquireFailure: m.lockAcquireFailure.Load(),
|
||||||
|
LockContention: m.lockContention.Load(),
|
||||||
|
LockWaitSamples: m.lockWaitSamples.Load(),
|
||||||
|
LockWaitTotalMs: m.lockWaitTotalMs.Load(),
|
||||||
|
LockWaitHit: m.lockWaitHit.Load(),
|
||||||
|
LockWaitMiss: m.lockWaitMiss.Load(),
|
||||||
|
LastObservedUnixMs: m.lastObservedUnixMs.Load(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *openAITokenRuntimeMetricsStore) touchNow() {
|
||||||
|
if m == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.lastObservedUnixMs.Store(time.Now().UnixMilli())
|
||||||
|
}
|
||||||
|
|
||||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||||
type OpenAITokenCache = GeminiTokenCache
|
type OpenAITokenCache = GeminiTokenCache
|
||||||
|
|
||||||
@@ -22,6 +80,7 @@ type OpenAITokenProvider struct {
|
|||||||
accountRepo AccountRepository
|
accountRepo AccountRepository
|
||||||
tokenCache OpenAITokenCache
|
tokenCache OpenAITokenCache
|
||||||
openAIOAuthService *OpenAIOAuthService
|
openAIOAuthService *OpenAIOAuthService
|
||||||
|
metrics *openAITokenRuntimeMetricsStore
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOpenAITokenProvider(
|
func NewOpenAITokenProvider(
|
||||||
@@ -33,11 +92,27 @@ func NewOpenAITokenProvider(
|
|||||||
accountRepo: accountRepo,
|
accountRepo: accountRepo,
|
||||||
tokenCache: tokenCache,
|
tokenCache: tokenCache,
|
||||||
openAIOAuthService: openAIOAuthService,
|
openAIOAuthService: openAIOAuthService,
|
||||||
|
metrics: &openAITokenRuntimeMetricsStore{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||||
|
if p == nil {
|
||||||
|
return OpenAITokenRuntimeMetrics{}
|
||||||
|
}
|
||||||
|
p.ensureMetrics()
|
||||||
|
return p.metrics.snapshot()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *OpenAITokenProvider) ensureMetrics() {
|
||||||
|
if p != nil && p.metrics == nil {
|
||||||
|
p.metrics = &openAITokenRuntimeMetricsStore{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAccessToken 获取有效的 access_token
|
// GetAccessToken 获取有效的 access_token
|
||||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||||
|
p.ensureMetrics()
|
||||||
if account == nil {
|
if account == nil {
|
||||||
return "", errors.New("account is nil")
|
return "", errors.New("account is nil")
|
||||||
}
|
}
|
||||||
@@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||||
refreshFailed := false
|
refreshFailed := false
|
||||||
if needsRefresh && p.tokenCache != nil {
|
if needsRefresh && p.tokenCache != nil {
|
||||||
|
p.metrics.refreshRequests.Add(1)
|
||||||
|
p.metrics.touchNow()
|
||||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||||
if lockErr == nil && locked {
|
if lockErr == nil && locked {
|
||||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||||
@@ -82,14 +159,17 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
if p.openAIOAuthService == nil {
|
if p.openAIOAuthService == nil {
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true // 无法刷新,标记失败
|
refreshFailed = true // 无法刷新,标记失败
|
||||||
} else {
|
} else {
|
||||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||||
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||||
} else {
|
} else {
|
||||||
|
p.metrics.refreshSuccess.Add(1)
|
||||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
if _, exists := newCredentials[k]; !exists {
|
if _, exists := newCredentials[k]; !exists {
|
||||||
@@ -106,6 +186,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
} else if lockErr != nil {
|
} else if lockErr != nil {
|
||||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||||
|
p.metrics.lockAcquireFailure.Add(1)
|
||||||
|
p.metrics.touchNow()
|
||||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||||
|
|
||||||
// 检查 ctx 是否已取消
|
// 检查 ctx 是否已取消
|
||||||
@@ -126,13 +208,16 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||||
if p.openAIOAuthService == nil {
|
if p.openAIOAuthService == nil {
|
||||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||||
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true
|
refreshFailed = true
|
||||||
} else {
|
} else {
|
||||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||||
|
p.metrics.refreshFailure.Add(1)
|
||||||
refreshFailed = true
|
refreshFailed = true
|
||||||
} else {
|
} else {
|
||||||
|
p.metrics.refreshSuccess.Add(1)
|
||||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||||
for k, v := range account.Credentials {
|
for k, v := range account.Credentials {
|
||||||
if _, exists := newCredentials[k]; !exists {
|
if _, exists := newCredentials[k]; !exists {
|
||||||
@@ -148,9 +233,14 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
// 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。
|
||||||
time.Sleep(openAILockWaitTime)
|
p.metrics.lockContention.Add(1)
|
||||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
p.metrics.touchNow()
|
||||||
|
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||||
|
if waitErr != nil {
|
||||||
|
return "", waitErr
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(token) != "" {
|
||||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
@@ -198,3 +288,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
|||||||
|
|
||||||
return accessToken, nil
|
return accessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
|
||||||
|
wait := openAILockInitialWait
|
||||||
|
totalWaitMs := int64(0)
|
||||||
|
for i := 0; i < openAILockMaxAttempts; i++ {
|
||||||
|
actualWait := jitterLockWait(wait)
|
||||||
|
timer := time.NewTimer(actualWait)
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
if !timer.Stop() {
|
||||||
|
select {
|
||||||
|
case <-timer.C:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
|
||||||
|
waitMs := actualWait.Milliseconds()
|
||||||
|
if waitMs < 0 {
|
||||||
|
waitMs = 0
|
||||||
|
}
|
||||||
|
totalWaitMs += waitMs
|
||||||
|
p.metrics.lockWaitSamples.Add(1)
|
||||||
|
p.metrics.lockWaitTotalMs.Add(waitMs)
|
||||||
|
p.metrics.touchNow()
|
||||||
|
|
||||||
|
token, err := p.tokenCache.GetAccessToken(ctx, cacheKey)
|
||||||
|
if err == nil && strings.TrimSpace(token) != "" {
|
||||||
|
p.metrics.lockWaitHit.Add(1)
|
||||||
|
if totalWaitMs >= openAILockWarnThresholdMs {
|
||||||
|
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1)
|
||||||
|
}
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if wait < openAILockMaxWait {
|
||||||
|
wait *= 2
|
||||||
|
if wait > openAILockMaxWait {
|
||||||
|
wait = openAILockMaxWait
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
p.metrics.lockWaitMiss.Add(1)
|
||||||
|
if totalWaitMs >= openAILockWarnThresholdMs {
|
||||||
|
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts)
|
||||||
|
}
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func jitterLockWait(base time.Duration) time.Duration {
|
||||||
|
if base <= 0 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
minFactor := 1 - openAILockJitterRatio
|
||||||
|
maxFactor := 1 + openAILockJitterRatio
|
||||||
|
factor := minFactor + rand.Float64()*(maxFactor-minFactor)
|
||||||
|
return time.Duration(float64(base) * factor)
|
||||||
|
}
|
||||||
|
|||||||
@@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
|
|||||||
require.Contains(t, err.Error(), "access_token not found")
|
require.Contains(t, err.Error(), "access_token not found")
|
||||||
require.Empty(t, token)
|
require.Empty(t, token)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
cache.lockAcquired = false // 模拟锁被其他 worker 持有
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
ID: 207,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "fallback-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(5 * time.Millisecond)
|
||||||
|
cache.mu.Lock()
|
||||||
|
cache.tokens[cacheKey] = "winner-token"
|
||||||
|
cache.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||||
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "winner-token", token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
cache.lockAcquired = false // 模拟锁被其他 worker 持有
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
ID: 208,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "fallback-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||||
|
start := time.Now()
|
||||||
|
token, err := provider.GetAccessToken(ctx, account)
|
||||||
|
require.Error(t, err)
|
||||||
|
require.ErrorIs(t, err, context.Canceled)
|
||||||
|
require.Empty(t, token)
|
||||||
|
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) {
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
cache.lockAcquired = false
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
ID: 209,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "fallback-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
cacheKey := OpenAITokenCacheKey(account)
|
||||||
|
go func() {
|
||||||
|
time.Sleep(10 * time.Millisecond)
|
||||||
|
cache.mu.Lock()
|
||||||
|
cache.tokens[cacheKey] = "winner-token"
|
||||||
|
cache.mu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
|
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||||
|
token, err := provider.GetAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, "winner-token", token)
|
||||||
|
|
||||||
|
metrics := provider.SnapshotRuntimeMetrics()
|
||||||
|
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
|
||||||
|
require.GreaterOrEqual(t, metrics.LockContention, int64(1))
|
||||||
|
require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1))
|
||||||
|
require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1))
|
||||||
|
require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0))
|
||||||
|
require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
|
||||||
|
cache := newOpenAITokenCacheStub()
|
||||||
|
cache.lockErr = errors.New("redis lock error")
|
||||||
|
|
||||||
|
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||||
|
account := &Account{
|
||||||
|
ID: 210,
|
||||||
|
Platform: PlatformOpenAI,
|
||||||
|
Type: AccountTypeOAuth,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"access_token": "fallback-token",
|
||||||
|
"expires_at": expiresAt,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||||
|
_, err := provider.GetAccessToken(context.Background(), account)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
metrics := provider.SnapshotRuntimeMetrics()
|
||||||
|
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
|
||||||
|
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
|
||||||
|
}
|
||||||
|
|||||||
@@ -98,6 +98,10 @@ type OpsInsertErrorLogInput struct {
|
|||||||
// It is set by OpsService.RecordError before persisting.
|
// It is set by OpsService.RecordError before persisting.
|
||||||
UpstreamErrorsJSON *string
|
UpstreamErrorsJSON *string
|
||||||
|
|
||||||
|
AuthLatencyMs *int64
|
||||||
|
RoutingLatencyMs *int64
|
||||||
|
UpstreamLatencyMs *int64
|
||||||
|
ResponseLatencyMs *int64
|
||||||
TimeToFirstTokenMs *int64
|
TimeToFirstTokenMs *int64
|
||||||
|
|
||||||
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
||||||
|
|||||||
@@ -20,8 +20,30 @@ const (
|
|||||||
// retry the specific upstream attempt (not just the client request).
|
// retry the specific upstream attempt (not just the client request).
|
||||||
// This value is sanitized+trimmed before being persisted.
|
// This value is sanitized+trimmed before being persisted.
|
||||||
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
||||||
|
|
||||||
|
// Optional stage latencies (milliseconds) for troubleshooting and alerting.
|
||||||
|
OpsAuthLatencyMsKey = "ops_auth_latency_ms"
|
||||||
|
OpsRoutingLatencyMsKey = "ops_routing_latency_ms"
|
||||||
|
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
|
||||||
|
OpsResponseLatencyMsKey = "ops_response_latency_ms"
|
||||||
|
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func setOpsUpstreamRequestBody(c *gin.Context, body []byte) {
|
||||||
|
if c == nil || len(body) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// 热路径避免 string(body) 额外分配,按需在落库前再转换。
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
|
||||||
|
if c == nil || strings.TrimSpace(key) == "" || value < 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(key, value)
|
||||||
|
}
|
||||||
|
|
||||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return
|
return
|
||||||
@@ -91,8 +113,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
|||||||
// stored it on the context, attach it so ops can retry this specific attempt.
|
// stored it on the context, attach it so ops can retry this specific attempt.
|
||||||
if ev.UpstreamRequestBody == "" {
|
if ev.UpstreamRequestBody == "" {
|
||||||
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
|
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
|
||||||
if s, ok := v.(string); ok {
|
switch raw := v.(type) {
|
||||||
ev.UpstreamRequestBody = strings.TrimSpace(s)
|
case string:
|
||||||
|
ev.UpstreamRequestBody = strings.TrimSpace(raw)
|
||||||
|
case []byte:
|
||||||
|
ev.UpstreamRequestBody = strings.TrimSpace(string(raw))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
47
backend/internal/service/ops_upstream_context_test.go
Normal file
47
backend/internal/service/ops_upstream_context_test.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`))
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Kind: "http_error",
|
||||||
|
Message: "upstream failed",
|
||||||
|
})
|
||||||
|
|
||||||
|
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := v.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 1)
|
||||||
|
require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(rec)
|
||||||
|
|
||||||
|
c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`)
|
||||||
|
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||||
|
Kind: "request_error",
|
||||||
|
Message: "dial timeout",
|
||||||
|
})
|
||||||
|
|
||||||
|
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||||
|
require.True(t, ok)
|
||||||
|
events, ok := v.([]*OpsUpstreamErrorEvent)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Len(t, events, 1)
|
||||||
|
require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody)
|
||||||
|
}
|
||||||
164
tools/perf/openai_oauth_gray_drill.py
Executable file
164
tools/perf/openai_oauth_gray_drill.py
Executable file
@@ -0,0 +1,164 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""OpenAI OAuth 灰度发布演练脚本(本地模拟)。
|
||||||
|
|
||||||
|
该脚本会启动本地 mock Ops API,调用 openai_oauth_gray_guard.py,
|
||||||
|
验证以下场景:
|
||||||
|
1) A/B/C/D 四个灰度批次均通过
|
||||||
|
2) 注入异常场景触发阈值告警并返回退出码 2(模拟自动回滚触发)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import subprocess
|
||||||
|
import threading
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Tuple
|
||||||
|
from urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
|
ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
GUARD_SCRIPT = ROOT / "tools" / "perf" / "openai_oauth_gray_guard.py"
|
||||||
|
REPORT_PATH = ROOT / "docs" / "perf" / "openai-oauth-gray-drill-report.md"
|
||||||
|
|
||||||
|
|
||||||
|
THRESHOLDS = {
|
||||||
|
"sla_percent_min": 99.5,
|
||||||
|
"ttft_p99_ms_max": 900,
|
||||||
|
"request_error_rate_percent_max": 2.0,
|
||||||
|
"upstream_error_rate_percent_max": 2.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
STAGE_SNAPSHOTS: Dict[str, Dict[str, float]] = {
|
||||||
|
"A": {"sla": 99.78, "ttft": 780, "error_rate": 1.20, "upstream_error_rate": 1.05},
|
||||||
|
"B": {"sla": 99.82, "ttft": 730, "error_rate": 1.05, "upstream_error_rate": 0.92},
|
||||||
|
"C": {"sla": 99.86, "ttft": 680, "error_rate": 0.88, "upstream_error_rate": 0.80},
|
||||||
|
"D": {"sla": 99.89, "ttft": 640, "error_rate": 0.72, "upstream_error_rate": 0.67},
|
||||||
|
"rollback": {"sla": 97.10, "ttft": 1550, "error_rate": 6.30, "upstream_error_rate": 5.60},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class _MockHandler(BaseHTTPRequestHandler):
    """Minimal mock of the Ops admin API used by the gray-guard drill."""

    def _write_json(self, payload: dict) -> None:
        # Serialize once up front so Content-Length matches the body exactly.
        raw = json.dumps(payload, ensure_ascii=False).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(raw)))
        self.end_headers()
        self.wfile.write(raw)

    def log_message(self, format: str, *args):  # noqa: A003
        # Suppress per-request access logging to keep drill output clean.
        return

    def do_GET(self):  # noqa: N802
        parsed = urlparse(self.path)
        # Threshold-settings endpoint: serve the static drill thresholds.
        if parsed.path.endswith("/api/v1/admin/ops/settings/metric-thresholds"):
            self._write_json({"code": 0, "message": "success", "data": THRESHOLDS})
            return

        # Dashboard-overview endpoint: pick a canned snapshot by group_id
        # (the drill encodes the gray stage in group_id; unknown stages
        # fall back to stage "A").
        if parsed.path.endswith("/api/v1/admin/ops/dashboard/overview"):
            q = parse_qs(parsed.query)
            stage = (q.get("group_id") or ["A"])[0]
            snapshot = STAGE_SNAPSHOTS.get(stage, STAGE_SNAPSHOTS["A"])
            self._write_json(
                {
                    "code": 0,
                    "message": "success",
                    "data": {
                        "sla": snapshot["sla"],
                        "error_rate": snapshot["error_rate"],
                        "upstream_error_rate": snapshot["upstream_error_rate"],
                        "ttft": {"p99_ms": snapshot["ttft"]},
                    },
                }
            )
            return

        # Anything else is not part of the mocked API surface.
        self.send_response(404)
        self.end_headers()
|
||||||
|
|
||||||
|
|
||||||
|
def run_guard(base_url: str, stage: str) -> Tuple[int, str]:
    """Invoke the gray-guard script against ``base_url`` for one gray stage.

    Args:
        base_url: Base URL of the (mock) Ops API.
        stage: Gray stage identifier, passed through as ``--group-id``.

    Returns:
        ``(exit_code, output)`` where ``output`` is the combined, stripped
        stdout + stderr of the guard process.
    """
    # Local import keeps the module's top-level import block untouched.
    import sys

    cmd = [
        # Fix: use the current interpreter instead of a bare "python", which
        # may not exist on hosts that only ship "python3".
        sys.executable,
        str(GUARD_SCRIPT),
        "--base-url",
        base_url,
        "--platform",
        "openai",
        "--time-range",
        "30m",
        "--group-id",
        stage,
    ]
    proc = subprocess.run(cmd, cwd=str(ROOT), capture_output=True, text=True)
    output = (proc.stdout + "\n" + proc.stderr).strip()
    return proc.returncode, output
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """Run the gray-guard drill end to end.

    Starts a local mock Ops API, runs the guard for gray batches A-D (all
    must pass, exit code 0), then runs the injected "rollback" scenario
    (which must trip the guard with exit code 2), and writes a markdown
    report to REPORT_PATH.

    Returns:
        0 when every batch passes AND the rollback trigger fires; 1 otherwise.
    """
    # Bind to an ephemeral port so concurrent drills never collide.
    server = HTTPServer(("127.0.0.1", 0), _MockHandler)
    host, port = server.server_address
    base_url = f"http://{host}:{port}"

    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()

    lines = [
        "# OpenAI OAuth 灰度守护演练报告",
        "",
        "> 类型:本地 mock 演练(用于验证灰度守护与回滚触发机制)",
        # Fix: this line had an f-string prefix with no placeholders (F541).
        "> 生成脚本:`tools/perf/openai_oauth_gray_drill.py`",
        "",
        "## 1. 灰度批次结果(6.1)",
        "",
        "| 批次 | 流量比例 | 守护脚本退出码 | 结果 |",
        "|---|---:|---:|---|",
    ]

    # Gray batches: each stage must exit 0 for the drill to pass.
    batch_plan = [("A", "5%"), ("B", "20%"), ("C", "50%"), ("D", "100%")]
    all_pass = True
    for stage, ratio in batch_plan:
        code, _ = run_guard(base_url, stage)
        ok = code == 0
        all_pass = all_pass and ok
        lines.append(f"| {stage} | {ratio} | {code} | {'通过' if ok else '失败'} |")

    lines.extend([
        "",
        "## 2. 回滚触发演练(6.2)",
        "",
    ])

    # Injected failure scenario: the guard must exit 2 (rollback trigger).
    rollback_code, rollback_output = run_guard(base_url, "rollback")
    rollback_triggered = rollback_code == 2
    lines.append(f"- 注入异常场景退出码:`{rollback_code}`")
    lines.append(f"- 是否触发回滚条件:`{'是' if rollback_triggered else '否'}`")
    lines.append("- 关键信息摘录:")
    excerpt = "\n".join(rollback_output.splitlines()[:8])
    lines.append("```text")
    lines.append(excerpt)
    lines.append("```")

    lines.extend([
        "",
        "## 3. 验收结论(6.3)",
        "",
        f"- 批次灰度结果:`{'通过' if all_pass else '不通过'}`",
        f"- 回滚触发机制:`{'通过' if rollback_triggered else '不通过'}`",
        f"- 结论:`{'通过(可进入真实环境灰度)' if all_pass and rollback_triggered else '不通过(需修复后复测)'}`",
    ])

    REPORT_PATH.parent.mkdir(parents=True, exist_ok=True)
    REPORT_PATH.write_text("\n".join(lines) + "\n", encoding="utf-8")

    server.shutdown()
    server.server_close()

    print(f"drill report generated: {REPORT_PATH}")
    return 0 if all_pass and rollback_triggered else 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
213
tools/perf/openai_oauth_gray_guard.py
Executable file
213
tools/perf/openai_oauth_gray_guard.py
Executable file
@@ -0,0 +1,213 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""OpenAI OAuth 灰度阈值守护脚本。
|
||||||
|
|
||||||
|
用途:
|
||||||
|
- 拉取 Ops 指标阈值配置与 Dashboard Overview 实时数据
|
||||||
|
- 对比 P99 TTFT / 错误率 / SLA
|
||||||
|
- 作为 6.2 灰度守护的自动化门禁(退出码可直接用于 CI/CD)
|
||||||
|
|
||||||
|
退出码:
|
||||||
|
- 0: 指标通过
|
||||||
|
- 1: 请求失败/参数错误
|
||||||
|
- 2: 指标超阈值(建议停止扩量并回滚)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import urllib.error
|
||||||
|
import urllib.parse
|
||||||
|
import urllib.request
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuardThresholds:
|
||||||
|
sla_percent_min: Optional[float]
|
||||||
|
ttft_p99_ms_max: Optional[float]
|
||||||
|
request_error_rate_percent_max: Optional[float]
|
||||||
|
upstream_error_rate_percent_max: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GuardSnapshot:
|
||||||
|
sla: Optional[float]
|
||||||
|
ttft_p99_ms: Optional[float]
|
||||||
|
request_error_rate_percent: Optional[float]
|
||||||
|
upstream_error_rate_percent: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
|
def build_headers(token: str) -> Dict[str, str]:
    """Build request headers, attaching a Bearer token when one is supplied."""
    cleaned = token.strip()
    headers: Dict[str, str] = {"Accept": "application/json"}
    if cleaned:
        headers["Authorization"] = f"Bearer {cleaned}"
    return headers
|
||||||
|
|
||||||
|
|
||||||
|
def request_json(url: str, headers: Dict[str, str]) -> Dict[str, Any]:
    """GET ``url`` and decode its JSON body, normalizing urllib failures.

    Raises:
        RuntimeError: on HTTP error status or transport failure.
    """
    request = urllib.request.Request(url=url, method="GET", headers=headers)
    try:
        with urllib.request.urlopen(request, timeout=15) as response:
            return json.loads(response.read().decode("utf-8"))
    except urllib.error.HTTPError as e:
        # Include the error body so the caller can see the server's message.
        body = e.read().decode("utf-8", errors="replace")
        raise RuntimeError(f"HTTP {e.code}: {body}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"request failed: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
|
def parse_envelope_data(payload: Dict[str, Any]) -> Dict[str, Any]:
    """Unwrap the standard ``{code, message, data}`` API envelope.

    Raises:
        RuntimeError: if the payload is not a dict, reports a non-zero
            ``code``, or carries a non-dict ``data`` field.
    """
    if not isinstance(payload, dict):
        raise RuntimeError("invalid response payload")
    if payload.get("code") != 0:
        raise RuntimeError(f"api error: code={payload.get('code')} message={payload.get('message')}")
    data = payload.get("data")
    if not isinstance(data, dict):
        raise RuntimeError("invalid response data")
    return data
|
||||||
|
|
||||||
|
|
||||||
|
def parse_thresholds(data: Dict[str, Any]) -> GuardThresholds:
    """Map the threshold-settings payload onto a GuardThresholds record."""
    field_names = (
        "sla_percent_min",
        "ttft_p99_ms_max",
        "request_error_rate_percent_max",
        "upstream_error_rate_percent_max",
    )
    # Each missing/unparseable value becomes None and is skipped by evaluate().
    return GuardThresholds(**{name: to_float_or_none(data.get(name)) for name in field_names})
|
||||||
|
|
||||||
|
|
||||||
|
def parse_snapshot(data: Dict[str, Any]) -> GuardSnapshot:
    """Map a dashboard-overview payload onto a GuardSnapshot record."""
    # The TTFT block is nested; tolerate it being absent or the wrong shape.
    ttft_block = data.get("ttft")
    if not isinstance(ttft_block, dict):
        ttft_block = {}
    return GuardSnapshot(
        sla=to_float_or_none(data.get("sla")),
        ttft_p99_ms=to_float_or_none(ttft_block.get("p99_ms")),
        request_error_rate_percent=to_float_or_none(data.get("error_rate")),
        upstream_error_rate_percent=to_float_or_none(data.get("upstream_error_rate")),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def to_float_or_none(v: Any) -> Optional[float]:
    """Coerce ``v`` to float; return None when it is absent or not numeric."""
    try:
        return None if v is None else float(v)
    except (TypeError, ValueError):
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(snapshot: GuardSnapshot, thresholds: GuardThresholds) -> List[str]:
    """Compare a metrics snapshot against thresholds and collect violations.

    A metric is only checked when both its observed value and its threshold
    are present (non-None). Returns an empty list when everything passes.
    """
    violations: List[str] = []

    sla_min = thresholds.sla_percent_min
    if sla_min is not None and snapshot.sla is not None and snapshot.sla < sla_min:
        violations.append(
            f"SLA 低于阈值: actual={snapshot.sla:.2f}% threshold={sla_min:.2f}%"
        )

    ttft_max = thresholds.ttft_p99_ms_max
    ttft_actual = snapshot.ttft_p99_ms
    if ttft_max is not None and ttft_actual is not None and ttft_actual > ttft_max:
        violations.append(
            f"TTFT P99 超阈值: actual={ttft_actual:.2f}ms threshold={ttft_max:.2f}ms"
        )

    req_max = thresholds.request_error_rate_percent_max
    req_actual = snapshot.request_error_rate_percent
    if req_max is not None and req_actual is not None and req_actual > req_max:
        violations.append(
            "请求错误率超阈值: "
            f"actual={req_actual:.2f}% "
            f"threshold={req_max:.2f}%"
        )

    up_max = thresholds.upstream_error_rate_percent_max
    up_actual = snapshot.upstream_error_rate_percent
    if up_max is not None and up_actual is not None and up_actual > up_max:
        violations.append(
            "上游错误率超阈值: "
            f"actual={up_actual:.2f}% "
            f"threshold={up_max:.2f}%"
        )

    return violations
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
    """CLI entry point: fetch thresholds and a live snapshot, then evaluate.

    Exit codes:
        0: metrics pass.
        1: request failure or bad parameters.
        2: threshold violation (stop rollout; consider rollback).
    """
    parser = argparse.ArgumentParser(description="OpenAI OAuth 灰度阈值守护")
    parser.add_argument("--base-url", required=True, help="服务地址,例如 http://127.0.0.1:5231")
    parser.add_argument("--admin-token", default="", help="Admin JWT(可选,按部署策略)")
    parser.add_argument("--platform", default="openai", help="平台过滤,默认 openai")
    parser.add_argument("--time-range", default="30m", help="时间窗口: 5m/30m/1h/6h/24h/7d/30d")
    parser.add_argument("--group-id", default="", help="可选 group_id")
    args = parser.parse_args()

    base = args.base_url.rstrip("/")
    headers = build_headers(args.admin_token)

    try:
        # Threshold configuration comes from the Ops settings endpoint.
        threshold_url = f"{base}/api/v1/admin/ops/settings/metric-thresholds"
        thresholds_raw = request_json(threshold_url, headers)
        thresholds = parse_thresholds(parse_envelope_data(thresholds_raw))

        # Live metrics come from the dashboard overview endpoint; group_id
        # is only appended when explicitly supplied.
        query = {"platform": args.platform, "time_range": args.time_range}
        if args.group_id.strip():
            query["group_id"] = args.group_id.strip()
        overview_url = (
            f"{base}/api/v1/admin/ops/dashboard/overview?"
            + urllib.parse.urlencode(query)
        )
        overview_raw = request_json(overview_url, headers)
        snapshot = parse_snapshot(parse_envelope_data(overview_raw))

        # Echo both sides of the comparison for audit trails in CI logs.
        print("[OpenAI OAuth Gray Guard] 当前快照:")
        print(
            json.dumps(
                {
                    "sla": snapshot.sla,
                    "ttft_p99_ms": snapshot.ttft_p99_ms,
                    "request_error_rate_percent": snapshot.request_error_rate_percent,
                    "upstream_error_rate_percent": snapshot.upstream_error_rate_percent,
                },
                ensure_ascii=False,
                indent=2,
            )
        )
        print("[OpenAI OAuth Gray Guard] 阈值配置:")
        print(
            json.dumps(
                {
                    "sla_percent_min": thresholds.sla_percent_min,
                    "ttft_p99_ms_max": thresholds.ttft_p99_ms_max,
                    "request_error_rate_percent_max": thresholds.request_error_rate_percent_max,
                    "upstream_error_rate_percent_max": thresholds.upstream_error_rate_percent_max,
                },
                ensure_ascii=False,
                indent=2,
            )
        )

        # Any violation maps to exit code 2 so CI/CD can gate on it directly.
        violations = evaluate(snapshot, thresholds)
        if violations:
            print("[OpenAI OAuth Gray Guard] 检测到阈值违例:")
            for idx, line in enumerate(violations, start=1):
                print(f" {idx}. {line}")
            print("[OpenAI OAuth Gray Guard] 建议:停止扩量并执行回滚。")
            return 2

        print("[OpenAI OAuth Gray Guard] 指标通过,可继续观察或按计划扩量。")
        return 0

    except Exception as exc:
        # Broad catch is intentional for a CLI gate: any failure maps to 1.
        print(f"[OpenAI OAuth Gray Guard] 执行失败: {exc}", file=sys.stderr)
        return 1
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
122
tools/perf/openai_oauth_responses_k6.js
Normal file
122
tools/perf/openai_oauth_responses_k6.js
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
// k6 baseline load test for the OpenAI OAuth /v1/responses endpoint.
import http from 'k6/http';
import { check } from 'k6';
import { Rate, Trend } from 'k6/metrics';

// Target and request knobs, all overridable via environment variables.
const baseURL = __ENV.BASE_URL || 'http://127.0.0.1:5231';
const apiKey = __ENV.API_KEY || '';
const model = __ENV.MODEL || 'gpt-5';
const timeout = __ENV.TIMEOUT || '180s';

// Arrival rates and scenario sizing.
const nonStreamRPS = Number(__ENV.NON_STREAM_RPS || 8);
const streamRPS = Number(__ENV.STREAM_RPS || 4);
const duration = __ENV.DURATION || '3m';
const preAllocatedVUs = Number(__ENV.PRE_ALLOCATED_VUS || 30);
const maxVUs = Number(__ENV.MAX_VUS || 200);

// Custom metrics: latency trends (with time units) plus failure/completion rates.
const reqDurationMs = new Trend('openai_oauth_req_duration_ms', true);
const ttftMs = new Trend('openai_oauth_ttft_ms', true);
const non2xxRate = new Rate('openai_oauth_non2xx_rate');
const streamDoneRate = new Rate('openai_oauth_stream_done_rate');

export const options = {
  scenarios: {
    // Constant-arrival non-streaming traffic.
    non_stream: {
      executor: 'constant-arrival-rate',
      rate: nonStreamRPS,
      timeUnit: '1s',
      duration,
      preAllocatedVUs,
      maxVUs,
      exec: 'runNonStream',
      tags: { request_type: 'non_stream' },
    },
    // Constant-arrival streaming (SSE) traffic.
    stream: {
      executor: 'constant-arrival-rate',
      rate: streamRPS,
      timeUnit: '1s',
      duration,
      preAllocatedVUs,
      maxVUs,
      exec: 'runStream',
      tags: { request_type: 'stream' },
    },
  },
  // Pass/fail gates for the baseline run.
  thresholds: {
    openai_oauth_non2xx_rate: ['rate<0.01'],
    openai_oauth_req_duration_ms: ['p(95)<3000', 'p(99)<6000'],
    openai_oauth_ttft_ms: ['p(99)<1200'],
    openai_oauth_stream_done_rate: ['rate>0.99'],
  },
};
|
||||||
|
|
||||||
|
// Build request headers; Authorization is only attached when an API key was
// supplied via the API_KEY environment variable.
function buildHeaders() {
  const headers = {
    'Content-Type': 'application/json',
    'User-Agent': 'codex_cli_rs/0.1.0',
  };
  if (apiKey !== '') {
    headers.Authorization = `Bearer ${apiKey}`;
  }
  return headers;
}
|
||||||
|
|
||||||
|
// Serialize the /v1/responses payload; `stream` toggles SSE streaming.
function buildBody(stream) {
  const payload = {
    model,
    stream,
    input: [
      {
        role: 'user',
        content: [
          { type: 'input_text', text: '请返回一句极短的话:pong' },
        ],
      },
    ],
    max_output_tokens: 32,
  };
  return JSON.stringify(payload);
}
|
||||||
|
|
||||||
|
// Fold one response into the custom trends/rates, tagged by request type.
function recordMetrics(res, stream) {
  const tag = { request_type: stream ? 'stream' : 'non_stream' };
  reqDurationMs.add(res.timings.duration, tag);
  // TTFT is approximated by the time-to-first-byte (`waiting`) timing.
  ttftMs.add(res.timings.waiting, tag);
  non2xxRate.add(res.status < 200 || res.status >= 300, tag);

  if (stream) {
    // A stream only counts as complete when the SSE terminator arrived.
    const finished = !!res.body && res.body.indexOf('[DONE]') >= 0;
    streamDoneRate.add(finished, { request_type: 'stream' });
  }
}
|
||||||
|
|
||||||
|
// POST /v1/responses once, check the status, record metrics, and return the
// k6 response object.
function postResponses(stream) {
  const requestType = stream ? 'stream' : 'non_stream';
  const params = {
    headers: buildHeaders(),
    timeout,
    tags: { endpoint: '/v1/responses', request_type: requestType },
  };
  const res = http.post(`${baseURL}/v1/responses`, buildBody(stream), params);

  check(res, {
    'status is 2xx': (r) => r.status >= 200 && r.status < 300,
  });

  recordMetrics(res, stream);
  return res;
}
|
||||||
|
|
||||||
|
// Scenario entry point: one non-streaming request per iteration.
export function runNonStream() {
  postResponses(false);
}
|
||||||
|
|
||||||
|
// Scenario entry point: one streaming request per iteration.
export function runStream() {
  postResponses(true);
}
|
||||||
|
|
||||||
|
// Emit a console summary plus a JSON artifact for the perf docs.
export function handleSummary(data) {
  const consoleText = `\nOpenAI OAuth /v1/responses 基线完成\n${JSON.stringify(data.metrics, null, 2)}\n`;
  return {
    stdout: consoleText,
    'docs/perf/openai-oauth-k6-summary.json': JSON.stringify(data, null, 2),
  };
}
|
||||||
Reference in New Issue
Block a user