feat(openai): 极致优化 OAuth 链路并补齐性能守护
- 优化 /v1/responses 热路径,减少重复解析与不必要拷贝
- 优化并发与 token 竞争路径并补齐运行指标
- 补充 OpenAI/Ops 相关单元测试与回归用例
- 新增灰度阈值守护与压测脚本,支撑发布验收
This commit is contained in:
@@ -12,7 +12,6 @@ import (
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -34,11 +33,10 @@ const (
|
||||
// OpenAI Platform API for API Key accounts (fallback)
|
||||
openaiPlatformAPIURL = "https://api.openai.com/v1/responses"
|
||||
openaiStickySessionTTL = time.Hour // 粘性会话TTL
|
||||
)
|
||||
|
||||
// openaiSSEDataRe matches SSE data lines with optional whitespace after colon.
|
||||
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
|
||||
var openaiSSEDataRe = regexp.MustCompile(`^data:\s*`)
|
||||
// OpenAIParsedRequestBodyKey 缓存 handler 侧已解析的请求体,避免重复解析。
|
||||
OpenAIParsedRequestBodyKey = "openai_parsed_request_body"
|
||||
)
|
||||
|
||||
// OpenAI allowed headers whitelist (for non-OAuth accounts)
|
||||
var openaiAllowedHeaders = map[string]bool{
|
||||
@@ -745,32 +743,37 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
startTime := time.Now()
|
||||
|
||||
originalBody := body
|
||||
reqModel, reqStream, promptCacheKey := extractOpenAIRequestMetaFromBody(body)
|
||||
originalModel := reqModel
|
||||
|
||||
// Parse request body once (avoid multiple parse/serialize cycles)
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||
if passthroughEnabled {
|
||||
// 透传分支只需要轻量提取字段,避免热路径全量 Unmarshal。
|
||||
reasoningEffort := extractOpenAIReasoningEffortFromBody(body, reqModel)
|
||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||
}
|
||||
|
||||
// Extract model and stream from parsed body
|
||||
reqModel, _ := reqBody["model"].(string)
|
||||
reqStream, _ := reqBody["stream"].(bool)
|
||||
promptCacheKey := ""
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
reqBody, err := getOpenAIRequestBodyMap(c, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if v, ok := reqBody["model"].(string); ok {
|
||||
reqModel = v
|
||||
originalModel = reqModel
|
||||
}
|
||||
if v, ok := reqBody["stream"].(bool); ok {
|
||||
reqStream = v
|
||||
}
|
||||
if promptCacheKey == "" {
|
||||
if v, ok := reqBody["prompt_cache_key"].(string); ok {
|
||||
promptCacheKey = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
|
||||
// Track if body needs re-serialization
|
||||
bodyModified := false
|
||||
originalModel := reqModel
|
||||
|
||||
isCodexCLI := openai.IsCodexCLIRequest(c.GetHeader("User-Agent")) || (s.cfg != nil && s.cfg.Gateway.ForceCodexCLI)
|
||||
|
||||
passthroughEnabled := account.Type == AccountTypeOAuth && account.IsOpenAIOAuthPassthroughEnabled() && isCodexCLI
|
||||
if passthroughEnabled {
|
||||
reasoningEffort := extractOpenAIReasoningEffort(reqBody, reqModel)
|
||||
return s.forwardOAuthPassthrough(ctx, c, account, originalBody, reqModel, reasoningEffort, reqStream, startTime)
|
||||
}
|
||||
|
||||
// 对所有请求执行模型映射(包含 Codex CLI)。
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
@@ -888,12 +891,12 @@ func (s *OpenAIGatewayService) Forward(ctx context.Context, c *gin.Context, acco
|
||||
}
|
||||
|
||||
// Capture upstream request body for ops retry of this attempt.
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
}
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
|
||||
// Send request
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
// Ensure the client receives an error response (handlers assume Forward writes on non-failover errors).
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
@@ -1019,12 +1022,14 @@ func (s *OpenAIGatewayService) forwardOAuthPassthrough(
|
||||
proxyURL = account.Proxy.URL()
|
||||
}
|
||||
|
||||
setOpsUpstreamRequestBody(c, body)
|
||||
if c != nil {
|
||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||
c.Set("openai_passthrough", true)
|
||||
}
|
||||
|
||||
upstreamStart := time.Now()
|
||||
resp, err := s.httpUpstream.Do(upstreamReq, proxyURL, account.ID, account.Concurrency)
|
||||
SetOpsLatencyMs(c, OpsUpstreamLatencyMsKey, time.Since(upstreamStart).Milliseconds())
|
||||
if err != nil {
|
||||
safeErr := sanitizeUpstreamErrorMessage(err.Error())
|
||||
setOpsUpstreamError(c, 0, safeErr, "")
|
||||
@@ -1240,8 +1245,7 @@ func (s *OpenAIGatewayService) handleStreamingResponsePassthrough(
|
||||
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
if firstTokenMs == nil && strings.TrimSpace(data) != "" {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
@@ -1750,8 +1754,7 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
lastDataAt = time.Now()
|
||||
|
||||
// Extract data from SSE line (supports both "data: " and "data:" formats)
|
||||
if openaiSSEDataRe.MatchString(line) {
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data, ok := extractOpenAISSEDataLine(line); ok {
|
||||
|
||||
// Replace model in response if needed
|
||||
if needModelReplace {
|
||||
@@ -1827,11 +1830,27 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
|
||||
|
||||
}
|
||||
|
||||
// extractOpenAISSEDataLine extracts the payload of an SSE "data:" line without
// going through a regular expression.
//
// Both the standard "data: xxx" form and the non-standard "data:xxx" form are
// accepted. Any run of spaces and tabs directly after the colon is dropped.
// The second result is false (and the payload empty) when line does not start
// with "data:".
func extractOpenAISSEDataLine(line string) (string, bool) {
	payload, ok := strings.CutPrefix(line, "data:")
	if !ok {
		return "", false
	}
	// Strip horizontal whitespace between the colon and the payload. The
	// previous hand-rolled loop compared against ' ' twice, so a tab after
	// the colon was left in the payload.
	return strings.TrimLeft(payload, " \t"), true
}
|
||||
|
||||
func (s *OpenAIGatewayService) replaceModelInSSELine(line, fromModel, toModel string) string {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
return line
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
return line
|
||||
}
|
||||
@@ -1872,25 +1891,20 @@ func (s *OpenAIGatewayService) correctToolCallsInResponseBody(body []byte) []byt
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
// Parse response.completed event for usage (OpenAI Responses format)
|
||||
var event struct {
|
||||
Type string `json:"type"`
|
||||
Response struct {
|
||||
Usage struct {
|
||||
InputTokens int `json:"input_tokens"`
|
||||
OutputTokens int `json:"output_tokens"`
|
||||
InputTokenDetails struct {
|
||||
CachedTokens int `json:"cached_tokens"`
|
||||
} `json:"input_tokens_details"`
|
||||
} `json:"usage"`
|
||||
} `json:"response"`
|
||||
if usage == nil || data == "" || data == "[DONE]" {
|
||||
return
|
||||
}
|
||||
// 选择性解析:仅在数据中包含 completed 事件标识时才进入字段提取。
|
||||
if !strings.Contains(data, `"response.completed"`) {
|
||||
return
|
||||
}
|
||||
if gjson.Get(data, "type").String() != "response.completed" {
|
||||
return
|
||||
}
|
||||
|
||||
if json.Unmarshal([]byte(data), &event) == nil && event.Type == "response.completed" {
|
||||
usage.InputTokens = event.Response.Usage.InputTokens
|
||||
usage.OutputTokens = event.Response.Usage.OutputTokens
|
||||
usage.CacheReadInputTokens = event.Response.Usage.InputTokenDetails.CachedTokens
|
||||
}
|
||||
usage.InputTokens = int(gjson.Get(data, "response.usage.input_tokens").Int())
|
||||
usage.OutputTokens = int(gjson.Get(data, "response.usage.output_tokens").Int())
|
||||
usage.CacheReadInputTokens = int(gjson.Get(data, "response.usage.input_tokens_details.cached_tokens").Int())
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
@@ -2001,10 +2015,10 @@ func (s *OpenAIGatewayService) handleOAuthSSEToJSON(resp *http.Response, c *gin.
|
||||
func extractCodexFinalResponse(body string) ([]byte, bool) {
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
@@ -2028,10 +2042,10 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
usage := &OpenAIUsage{}
|
||||
lines := strings.Split(body, "\n")
|
||||
for _, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
data, ok := extractOpenAISSEDataLine(line)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
data := openaiSSEDataRe.ReplaceAllString(line, "")
|
||||
if data == "" || data == "[DONE]" {
|
||||
continue
|
||||
}
|
||||
@@ -2043,7 +2057,7 @@ func (s *OpenAIGatewayService) parseSSEUsageFromBody(body string) *OpenAIUsage {
|
||||
func (s *OpenAIGatewayService) replaceModelInSSEBody(body, fromModel, toModel string) string {
|
||||
lines := strings.Split(body, "\n")
|
||||
for i, line := range lines {
|
||||
if !openaiSSEDataRe.MatchString(line) {
|
||||
if _, ok := extractOpenAISSEDataLine(line); !ok {
|
||||
continue
|
||||
}
|
||||
lines[i] = s.replaceModelInSSELine(line, fromModel, toModel)
|
||||
@@ -2396,6 +2410,53 @@ func deriveOpenAIReasoningEffortFromModel(model string) string {
|
||||
return normalizeOpenAIReasoningEffort(parts[len(parts)-1])
|
||||
}
|
||||
|
||||
func extractOpenAIRequestMetaFromBody(body []byte) (model string, stream bool, promptCacheKey string) {
|
||||
if len(body) == 0 {
|
||||
return "", false, ""
|
||||
}
|
||||
|
||||
model = strings.TrimSpace(gjson.GetBytes(body, "model").String())
|
||||
stream = gjson.GetBytes(body, "stream").Bool()
|
||||
promptCacheKey = strings.TrimSpace(gjson.GetBytes(body, "prompt_cache_key").String())
|
||||
return model, stream, promptCacheKey
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if reasoningEffort == "" {
|
||||
reasoningEffort = strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
|
||||
}
|
||||
if reasoningEffort != "" {
|
||||
normalized := normalizeOpenAIReasoningEffort(reasoningEffort)
|
||||
if normalized == "" {
|
||||
return nil
|
||||
}
|
||||
return &normalized
|
||||
}
|
||||
|
||||
value := deriveOpenAIReasoningEffortFromModel(requestedModel)
|
||||
if value == "" {
|
||||
return nil
|
||||
}
|
||||
return &value
|
||||
}
|
||||
|
||||
func getOpenAIRequestBodyMap(c *gin.Context, body []byte) (map[string]any, error) {
|
||||
if c != nil {
|
||||
if cached, ok := c.Get(OpenAIParsedRequestBodyKey); ok {
|
||||
if reqBody, ok := cached.(map[string]any); ok && reqBody != nil {
|
||||
return reqBody, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var reqBody map[string]any
|
||||
if err := json.Unmarshal(body, &reqBody); err != nil {
|
||||
return nil, fmt.Errorf("parse request: %w", err)
|
||||
}
|
||||
return reqBody, nil
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffort(reqBody map[string]any, requestedModel string) *string {
|
||||
if value, present := getOpenAIReasoningEffortFromReqBody(reqBody); present {
|
||||
if value == "" {
|
||||
|
||||
125
backend/internal/service/openai_gateway_service_hotpath_test.go
Normal file
125
backend/internal/service/openai_gateway_service_hotpath_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestExtractOpenAIRequestMetaFromBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
wantModel string
|
||||
wantStream bool
|
||||
wantPromptKey string
|
||||
}{
|
||||
{
|
||||
name: "完整字段",
|
||||
body: []byte(`{"model":"gpt-5","stream":true,"prompt_cache_key":" ses-1 "}`),
|
||||
wantModel: "gpt-5",
|
||||
wantStream: true,
|
||||
wantPromptKey: "ses-1",
|
||||
},
|
||||
{
|
||||
name: "缺失可选字段",
|
||||
body: []byte(`{"model":"gpt-4"}`),
|
||||
wantModel: "gpt-4",
|
||||
wantStream: false,
|
||||
wantPromptKey: "",
|
||||
},
|
||||
{
|
||||
name: "空请求体",
|
||||
body: nil,
|
||||
wantModel: "",
|
||||
wantStream: false,
|
||||
wantPromptKey: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
model, stream, promptKey := extractOpenAIRequestMetaFromBody(tt.body)
|
||||
require.Equal(t, tt.wantModel, model)
|
||||
require.Equal(t, tt.wantStream, stream)
|
||||
require.Equal(t, tt.wantPromptKey, promptKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAIReasoningEffortFromBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
model string
|
||||
wantNil bool
|
||||
wantValue string
|
||||
}{
|
||||
{
|
||||
name: "优先读取 reasoning.effort",
|
||||
body: []byte(`{"reasoning":{"effort":"medium"}}`),
|
||||
model: "gpt-5-high",
|
||||
wantNil: false,
|
||||
wantValue: "medium",
|
||||
},
|
||||
{
|
||||
name: "兼容 reasoning_effort",
|
||||
body: []byte(`{"reasoning_effort":"x-high"}`),
|
||||
model: "",
|
||||
wantNil: false,
|
||||
wantValue: "xhigh",
|
||||
},
|
||||
{
|
||||
name: "minimal 归一化为空",
|
||||
body: []byte(`{"reasoning":{"effort":"minimal"}}`),
|
||||
model: "gpt-5-high",
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "缺失字段时从模型后缀推导",
|
||||
body: []byte(`{"input":"hi"}`),
|
||||
model: "gpt-5-high",
|
||||
wantNil: false,
|
||||
wantValue: "high",
|
||||
},
|
||||
{
|
||||
name: "未知后缀不返回",
|
||||
body: []byte(`{"input":"hi"}`),
|
||||
model: "gpt-5-unknown",
|
||||
wantNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractOpenAIReasoningEffortFromBody(tt.body, tt.model)
|
||||
if tt.wantNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, tt.wantValue, *got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOpenAIRequestBodyMap_UsesContextCache(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
cached := map[string]any{"model": "cached-model", "stream": true}
|
||||
c.Set(OpenAIParsedRequestBodyKey, cached)
|
||||
|
||||
got, err := getOpenAIRequestBodyMap(c, []byte(`{invalid-json`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, cached, got)
|
||||
}
|
||||
|
||||
func TestGetOpenAIRequestBodyMap_ParseErrorWithoutCache(t *testing.T) {
|
||||
_, err := getOpenAIRequestBodyMap(nil, []byte(`{invalid-json`))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "parse request")
|
||||
}
|
||||
@@ -1416,3 +1416,109 @@ func TestReplaceModelInResponseBody(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAISSEDataLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
line string
|
||||
wantData string
|
||||
wantOK bool
|
||||
}{
|
||||
{name: "标准格式", line: `data: {"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
|
||||
{name: "无空格格式", line: `data:{"type":"x"}`, wantData: `{"type":"x"}`, wantOK: true},
|
||||
{name: "纯空数据", line: `data: `, wantData: ``, wantOK: true},
|
||||
{name: "非 data 行", line: `event: message`, wantData: ``, wantOK: false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := extractOpenAISSEDataLine(tt.line)
|
||||
require.Equal(t, tt.wantOK, ok)
|
||||
require.Equal(t, tt.wantData, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseSSEUsage_SelectiveParsing(t *testing.T) {
|
||||
svc := &OpenAIGatewayService{}
|
||||
usage := &OpenAIUsage{InputTokens: 9, OutputTokens: 8, CacheReadInputTokens: 7}
|
||||
|
||||
// 非 completed 事件,不应覆盖 usage
|
||||
svc.parseSSEUsage(`{"type":"response.in_progress","response":{"usage":{"input_tokens":1,"output_tokens":2}}}`, usage)
|
||||
require.Equal(t, 9, usage.InputTokens)
|
||||
require.Equal(t, 8, usage.OutputTokens)
|
||||
require.Equal(t, 7, usage.CacheReadInputTokens)
|
||||
|
||||
// completed 事件,应提取 usage
|
||||
svc.parseSSEUsage(`{"type":"response.completed","response":{"usage":{"input_tokens":3,"output_tokens":5,"input_tokens_details":{"cached_tokens":2}}}}`, usage)
|
||||
require.Equal(t, 3, usage.InputTokens)
|
||||
require.Equal(t, 5, usage.OutputTokens)
|
||||
require.Equal(t, 2, usage.CacheReadInputTokens)
|
||||
}
|
||||
|
||||
func TestExtractCodexFinalResponse_SampleReplay(t *testing.T) {
|
||||
body := strings.Join([]string{
|
||||
`event: message`,
|
||||
`data: {"type":"response.in_progress","response":{"id":"resp_1"}}`,
|
||||
`data: {"type":"response.completed","response":{"id":"resp_1","model":"gpt-4o","usage":{"input_tokens":11,"output_tokens":22,"input_tokens_details":{"cached_tokens":3}}}}`,
|
||||
`data: [DONE]`,
|
||||
}, "\n")
|
||||
|
||||
finalResp, ok := extractCodexFinalResponse(body)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, string(finalResp), `"id":"resp_1"`)
|
||||
require.Contains(t, string(finalResp), `"input_tokens":11`)
|
||||
}
|
||||
|
||||
func TestHandleOAuthSSEToJSON_CompletedEventReturnsJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
}
|
||||
body := []byte(strings.Join([]string{
|
||||
`data: {"type":"response.in_progress","response":{"id":"resp_2"}}`,
|
||||
`data: {"type":"response.completed","response":{"id":"resp_2","model":"gpt-4o","usage":{"input_tokens":7,"output_tokens":9,"input_tokens_details":{"cached_tokens":1}}}}`,
|
||||
`data: [DONE]`,
|
||||
}, "\n"))
|
||||
|
||||
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 7, usage.InputTokens)
|
||||
require.Equal(t, 9, usage.OutputTokens)
|
||||
require.Equal(t, 1, usage.CacheReadInputTokens)
|
||||
// Header 可能由上游 Content-Type 透传;关键是 body 已转换为最终 JSON 响应。
|
||||
require.NotContains(t, rec.Body.String(), "event:")
|
||||
require.Contains(t, rec.Body.String(), `"id":"resp_2"`)
|
||||
require.NotContains(t, rec.Body.String(), "data:")
|
||||
}
|
||||
|
||||
func TestHandleOAuthSSEToJSON_NoFinalResponseKeepsSSEBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
|
||||
svc := &OpenAIGatewayService{cfg: &config.Config{}}
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"text/event-stream"}},
|
||||
}
|
||||
body := []byte(strings.Join([]string{
|
||||
`data: {"type":"response.in_progress","response":{"id":"resp_3"}}`,
|
||||
`data: [DONE]`,
|
||||
}, "\n"))
|
||||
|
||||
usage, err := svc.handleOAuthSSEToJSON(resp, c, body, "gpt-4o", "gpt-4o")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.Equal(t, 0, usage.InputTokens)
|
||||
require.Contains(t, rec.Header().Get("Content-Type"), "text/event-stream")
|
||||
require.Contains(t, rec.Body.String(), `data: {"type":"response.in_progress"`)
|
||||
}
|
||||
|
||||
@@ -4,16 +4,74 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math/rand/v2"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
openAITokenRefreshSkew = 3 * time.Minute
|
||||
openAITokenCacheSkew = 5 * time.Minute
|
||||
openAILockWaitTime = 200 * time.Millisecond
|
||||
openAITokenRefreshSkew = 3 * time.Minute
|
||||
openAITokenCacheSkew = 5 * time.Minute
|
||||
openAILockInitialWait = 20 * time.Millisecond
|
||||
openAILockMaxWait = 120 * time.Millisecond
|
||||
openAILockMaxAttempts = 5
|
||||
openAILockJitterRatio = 0.2
|
||||
openAILockWarnThresholdMs = 250
|
||||
)
|
||||
|
||||
// OpenAITokenRuntimeMetrics is a point-in-time snapshot of the OpenAI token
// refresh and refresh-lock contention counters.
type OpenAITokenRuntimeMetrics struct {
	// RefreshRequests counts token lookups that needed a refresh while a
	// token cache was configured.
	RefreshRequests int64
	// RefreshSuccess counts refresh calls that returned new token info.
	RefreshSuccess int64
	// RefreshFailure counts refresh attempts that failed, including the case
	// where no OAuth service is configured.
	RefreshFailure int64
	// LockAcquireFailure counts refresh-lock acquisitions that returned an error.
	LockAcquireFailure int64
	// LockContention counts lookups that found the refresh lock held elsewhere.
	LockContention int64
	// LockWaitSamples counts individual poll waits performed after losing the lock.
	LockWaitSamples int64
	// LockWaitTotalMs is the sum, in milliseconds, of those poll waits.
	LockWaitTotalMs int64
	// LockWaitHit counts polling rounds that ended with a cached token.
	LockWaitHit int64
	// LockWaitMiss counts polling rounds that used every attempt without a token.
	LockWaitMiss int64
	// LastObservedUnixMs is the Unix time, in milliseconds, of the last touchNow call.
	LastObservedUnixMs int64
}

// openAITokenRuntimeMetricsStore holds the live counters behind
// OpenAITokenRuntimeMetrics. Each field is an atomic.Int64 and mirrors the
// snapshot field of the same name.
type openAITokenRuntimeMetricsStore struct {
	refreshRequests    atomic.Int64
	refreshSuccess     atomic.Int64
	refreshFailure     atomic.Int64
	lockAcquireFailure atomic.Int64
	lockContention     atomic.Int64
	lockWaitSamples    atomic.Int64
	lockWaitTotalMs    atomic.Int64
	lockWaitHit        atomic.Int64
	lockWaitMiss       atomic.Int64
	lastObservedUnixMs atomic.Int64
}

// snapshot loads every counter and returns the values as a plain struct.
// A nil receiver yields the zero snapshot. Each counter is loaded
// independently, one after the other.
func (m *openAITokenRuntimeMetricsStore) snapshot() OpenAITokenRuntimeMetrics {
	var out OpenAITokenRuntimeMetrics
	if m == nil {
		return out
	}
	out.RefreshRequests = m.refreshRequests.Load()
	out.RefreshSuccess = m.refreshSuccess.Load()
	out.RefreshFailure = m.refreshFailure.Load()
	out.LockAcquireFailure = m.lockAcquireFailure.Load()
	out.LockContention = m.lockContention.Load()
	out.LockWaitSamples = m.lockWaitSamples.Load()
	out.LockWaitTotalMs = m.lockWaitTotalMs.Load()
	out.LockWaitHit = m.lockWaitHit.Load()
	out.LockWaitMiss = m.lockWaitMiss.Load()
	out.LastObservedUnixMs = m.lastObservedUnixMs.Load()
	return out
}

// touchNow records the current wall-clock time (Unix milliseconds) as the last
// observation time. It is a no-op on a nil receiver.
func (m *openAITokenRuntimeMetricsStore) touchNow() {
	if m != nil {
		m.lastObservedUnixMs.Store(time.Now().UnixMilli())
	}
}
|
||||
|
||||
// OpenAITokenCache Token 缓存接口(复用 GeminiTokenCache 接口定义)
|
||||
type OpenAITokenCache = GeminiTokenCache
|
||||
|
||||
@@ -22,6 +80,7 @@ type OpenAITokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache OpenAITokenCache
|
||||
openAIOAuthService *OpenAIOAuthService
|
||||
metrics *openAITokenRuntimeMetricsStore
|
||||
}
|
||||
|
||||
func NewOpenAITokenProvider(
|
||||
@@ -33,11 +92,27 @@ func NewOpenAITokenProvider(
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
openAIOAuthService: openAIOAuthService,
|
||||
metrics: &openAITokenRuntimeMetricsStore{},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) SnapshotRuntimeMetrics() OpenAITokenRuntimeMetrics {
|
||||
if p == nil {
|
||||
return OpenAITokenRuntimeMetrics{}
|
||||
}
|
||||
p.ensureMetrics()
|
||||
return p.metrics.snapshot()
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) ensureMetrics() {
|
||||
if p != nil && p.metrics == nil {
|
||||
p.metrics = &openAITokenRuntimeMetricsStore{}
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccessToken 获取有效的 access_token
|
||||
func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
p.ensureMetrics()
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
@@ -64,6 +139,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
p.metrics.refreshRequests.Add(1)
|
||||
p.metrics.touchNow()
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
@@ -82,14 +159,17 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
// 刷新失败时记录警告,但不立即返回错误,尝试使用现有 token
|
||||
slog.Warn("openai_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
@@ -106,6 +186,8 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
} else if lockErr != nil {
|
||||
// Redis 错误导致无法获取锁,降级为无锁刷新(仅在 token 接近过期时)
|
||||
p.metrics.lockAcquireFailure.Add(1)
|
||||
p.metrics.touchNow()
|
||||
slog.Warn("openai_token_lock_failed_degraded_refresh", "account_id", account.ID, "error", lockErr)
|
||||
|
||||
// 检查 ctx 是否已取消
|
||||
@@ -126,13 +208,16 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= openAITokenRefreshSkew {
|
||||
if p.openAIOAuthService == nil {
|
||||
slog.Warn("openai_oauth_service_not_configured", "account_id", account.ID)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
tokenInfo, err := p.openAIOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
slog.Warn("openai_token_refresh_failed_degraded", "account_id", account.ID, "error", err)
|
||||
p.metrics.refreshFailure.Add(1)
|
||||
refreshFailed = true
|
||||
} else {
|
||||
p.metrics.refreshSuccess.Add(1)
|
||||
newCredentials := p.openAIOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
for k, v := range account.Credentials {
|
||||
if _, exists := newCredentials[k]; !exists {
|
||||
@@ -148,9 +233,14 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 锁获取失败(被其他 worker 持有),等待 200ms 后重试读取缓存
|
||||
time.Sleep(openAILockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
// 锁被其他 worker 持有:使用短轮询+jitter,降低固定等待导致的尾延迟台阶。
|
||||
p.metrics.lockContention.Add(1)
|
||||
p.metrics.touchNow()
|
||||
token, waitErr := p.waitForTokenAfterLockRace(ctx, cacheKey)
|
||||
if waitErr != nil {
|
||||
return "", waitErr
|
||||
}
|
||||
if strings.TrimSpace(token) != "" {
|
||||
slog.Debug("openai_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
@@ -198,3 +288,64 @@ func (p *OpenAITokenProvider) GetAccessToken(ctx context.Context, account *Accou
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func (p *OpenAITokenProvider) waitForTokenAfterLockRace(ctx context.Context, cacheKey string) (string, error) {
|
||||
wait := openAILockInitialWait
|
||||
totalWaitMs := int64(0)
|
||||
for i := 0; i < openAILockMaxAttempts; i++ {
|
||||
actualWait := jitterLockWait(wait)
|
||||
timer := time.NewTimer(actualWait)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if !timer.Stop() {
|
||||
select {
|
||||
case <-timer.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return "", ctx.Err()
|
||||
case <-timer.C:
|
||||
}
|
||||
|
||||
waitMs := actualWait.Milliseconds()
|
||||
if waitMs < 0 {
|
||||
waitMs = 0
|
||||
}
|
||||
totalWaitMs += waitMs
|
||||
p.metrics.lockWaitSamples.Add(1)
|
||||
p.metrics.lockWaitTotalMs.Add(waitMs)
|
||||
p.metrics.touchNow()
|
||||
|
||||
token, err := p.tokenCache.GetAccessToken(ctx, cacheKey)
|
||||
if err == nil && strings.TrimSpace(token) != "" {
|
||||
p.metrics.lockWaitHit.Add(1)
|
||||
if totalWaitMs >= openAILockWarnThresholdMs {
|
||||
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", i+1)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
if wait < openAILockMaxWait {
|
||||
wait *= 2
|
||||
if wait > openAILockMaxWait {
|
||||
wait = openAILockMaxWait
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
p.metrics.lockWaitMiss.Add(1)
|
||||
if totalWaitMs >= openAILockWarnThresholdMs {
|
||||
slog.Warn("openai_token_lock_wait_high", "wait_ms", totalWaitMs, "attempts", openAILockMaxAttempts)
|
||||
}
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func jitterLockWait(base time.Duration) time.Duration {
|
||||
if base <= 0 {
|
||||
return 0
|
||||
}
|
||||
minFactor := 1 - openAILockJitterRatio
|
||||
maxFactor := 1 + openAILockJitterRatio
|
||||
factor := minFactor + rand.Float64()*(maxFactor-minFactor)
|
||||
return time.Duration(float64(base) * factor)
|
||||
}
|
||||
|
||||
@@ -808,3 +808,119 @@ func TestOpenAITokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_LockRace_PollingHitsCache(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // 模拟锁被其他 worker 持有
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 207,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "winner-token", token)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_Real_LockRace_ContextCanceled(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false // 模拟锁被其他 worker 持有
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 208,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
start := time.Now()
|
||||
token, err := provider.GetAccessToken(ctx, account)
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.Empty(t, token)
|
||||
require.Less(t, time.Since(start), 50*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_RuntimeMetrics_LockWaitHitAndSnapshot(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockAcquired = false
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 209,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
cacheKey := OpenAITokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "winner-token", token)
|
||||
|
||||
metrics := provider.SnapshotRuntimeMetrics()
|
||||
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
|
||||
require.GreaterOrEqual(t, metrics.LockContention, int64(1))
|
||||
require.GreaterOrEqual(t, metrics.LockWaitSamples, int64(1))
|
||||
require.GreaterOrEqual(t, metrics.LockWaitHit, int64(1))
|
||||
require.GreaterOrEqual(t, metrics.LockWaitTotalMs, int64(0))
|
||||
require.GreaterOrEqual(t, metrics.LastObservedUnixMs, int64(1))
|
||||
}
|
||||
|
||||
func TestOpenAITokenProvider_RuntimeMetrics_LockAcquireFailure(t *testing.T) {
|
||||
cache := newOpenAITokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock error")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 210,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewOpenAITokenProvider(nil, cache, nil)
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
metrics := provider.SnapshotRuntimeMetrics()
|
||||
require.GreaterOrEqual(t, metrics.LockAcquireFailure, int64(1))
|
||||
require.GreaterOrEqual(t, metrics.RefreshRequests, int64(1))
|
||||
}
|
||||
|
||||
@@ -98,6 +98,10 @@ type OpsInsertErrorLogInput struct {
|
||||
// It is set by OpsService.RecordError before persisting.
|
||||
UpstreamErrorsJSON *string
|
||||
|
||||
AuthLatencyMs *int64
|
||||
RoutingLatencyMs *int64
|
||||
UpstreamLatencyMs *int64
|
||||
ResponseLatencyMs *int64
|
||||
TimeToFirstTokenMs *int64
|
||||
|
||||
RequestBodyJSON *string // sanitized json string (not raw bytes)
|
||||
|
||||
@@ -20,8 +20,30 @@ const (
|
||||
// retry the specific upstream attempt (not just the client request).
|
||||
// This value is sanitized+trimmed before being persisted.
|
||||
OpsUpstreamRequestBodyKey = "ops_upstream_request_body"
|
||||
|
||||
// Optional stage latencies (milliseconds) for troubleshooting and alerting.
|
||||
OpsAuthLatencyMsKey = "ops_auth_latency_ms"
|
||||
OpsRoutingLatencyMsKey = "ops_routing_latency_ms"
|
||||
OpsUpstreamLatencyMsKey = "ops_upstream_latency_ms"
|
||||
OpsResponseLatencyMsKey = "ops_response_latency_ms"
|
||||
OpsTimeToFirstTokenMsKey = "ops_time_to_first_token_ms"
|
||||
)
|
||||
|
||||
func setOpsUpstreamRequestBody(c *gin.Context, body []byte) {
|
||||
if c == nil || len(body) == 0 {
|
||||
return
|
||||
}
|
||||
// 热路径避免 string(body) 额外分配,按需在落库前再转换。
|
||||
c.Set(OpsUpstreamRequestBodyKey, body)
|
||||
}
|
||||
|
||||
func SetOpsLatencyMs(c *gin.Context, key string, value int64) {
|
||||
if c == nil || strings.TrimSpace(key) == "" || value < 0 {
|
||||
return
|
||||
}
|
||||
c.Set(key, value)
|
||||
}
|
||||
|
||||
func setOpsUpstreamError(c *gin.Context, upstreamStatusCode int, upstreamMessage, upstreamDetail string) {
|
||||
if c == nil {
|
||||
return
|
||||
@@ -91,8 +113,11 @@ func appendOpsUpstreamError(c *gin.Context, ev OpsUpstreamErrorEvent) {
|
||||
// stored it on the context, attach it so ops can retry this specific attempt.
|
||||
if ev.UpstreamRequestBody == "" {
|
||||
if v, ok := c.Get(OpsUpstreamRequestBodyKey); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(s)
|
||||
switch raw := v.(type) {
|
||||
case string:
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(raw)
|
||||
case []byte:
|
||||
ev.UpstreamRequestBody = strings.TrimSpace(string(raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
47
backend/internal/service/ops_upstream_context_test.go
Normal file
47
backend/internal/service/ops_upstream_context_test.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAppendOpsUpstreamError_UsesRequestBodyBytesFromContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
setOpsUpstreamRequestBody(c, []byte(`{"model":"gpt-5"}`))
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Kind: "http_error",
|
||||
Message: "upstream failed",
|
||||
})
|
||||
|
||||
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := v.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, `{"model":"gpt-5"}`, events[0].UpstreamRequestBody)
|
||||
}
|
||||
|
||||
func TestAppendOpsUpstreamError_UsesRequestBodyStringFromContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
|
||||
c.Set(OpsUpstreamRequestBodyKey, `{"model":"gpt-4"}`)
|
||||
appendOpsUpstreamError(c, OpsUpstreamErrorEvent{
|
||||
Kind: "request_error",
|
||||
Message: "dial timeout",
|
||||
})
|
||||
|
||||
v, ok := c.Get(OpsUpstreamErrorsKey)
|
||||
require.True(t, ok)
|
||||
events, ok := v.([]*OpsUpstreamErrorEvent)
|
||||
require.True(t, ok)
|
||||
require.Len(t, events, 1)
|
||||
require.Equal(t, `{"model":"gpt-4"}`, events[0].UpstreamRequestBody)
|
||||
}
|
||||
Reference in New Issue
Block a user