feat(upstream): passthrough all client headers instead of manual header setting
Replace manual header setting (Content-Type, anthropic-version, anthropic-beta) with full client header passthrough in ForwardUpstream/ForwardUpstreamGemini. Only authentication headers (Authorization, x-api-key) are overridden with upstream account credentials. Hop-by-hop headers are excluded. Add unit tests covering header passthrough, auth override, and hop-by-hop filtering.
This commit is contained in:
@@ -47,6 +47,21 @@ const (
|
|||||||
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
|
googleRPCReasonRateLimitExceeded = "RATE_LIMIT_EXCEEDED"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// upstreamHopByHopHeaders 透传请求头时需要排除的 hop-by-hop 头
|
||||||
|
var upstreamHopByHopHeaders = map[string]bool{
|
||||||
|
"connection": true,
|
||||||
|
"keep-alive": true,
|
||||||
|
"proxy-authenticate": true,
|
||||||
|
"proxy-authorization": true,
|
||||||
|
"proxy-connection": true,
|
||||||
|
"te": true,
|
||||||
|
"trailer": true,
|
||||||
|
"transfer-encoding": true,
|
||||||
|
"upgrade": true,
|
||||||
|
"host": true,
|
||||||
|
"content-length": true,
|
||||||
|
}
|
||||||
|
|
||||||
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
// antigravityPassthroughErrorMessages 透传给客户端的错误消息白名单(小写)
|
||||||
// 匹配时使用 strings.Contains,无需完全匹配
|
// 匹配时使用 strings.Contains,无需完全匹配
|
||||||
var antigravityPassthroughErrorMessages = []string{
|
var antigravityPassthroughErrorMessages = []string{
|
||||||
@@ -3456,10 +3471,6 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
if mappedModel == "" {
|
if mappedModel == "" {
|
||||||
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
return nil, s.writeClaudeError(c, http.StatusForbidden, "permission_error", fmt.Sprintf("model %s not in whitelist", claudeReq.Model))
|
||||||
}
|
}
|
||||||
loadModel := mappedModel
|
|
||||||
thinkingEnabled := claudeReq.Thinking != nil && claudeReq.Thinking.Type == "enabled"
|
|
||||||
mappedModel = applyThinkingModelSuffix(mappedModel, thinkingEnabled)
|
|
||||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
|
||||||
|
|
||||||
// 代理 URL
|
// 代理 URL
|
||||||
proxyURL := ""
|
proxyURL := ""
|
||||||
@@ -3469,98 +3480,38 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
|
|
||||||
// 统计模型调用次数
|
// 统计模型调用次数
|
||||||
if s.cache != nil {
|
if s.cache != nil {
|
||||||
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, loadModel)
|
_, _ = s.cache.IncrModelCallCount(ctx, account.ID, mappedModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
apiURL := baseURL + "/antigravity/v1/messages"
|
apiURL := baseURL + "/antigravity/v1/messages"
|
||||||
log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel)
|
log.Printf("%s upstream_forward url=%s model=%s", prefix, apiURL, mappedModel)
|
||||||
|
|
||||||
// 预检查:模型级限流
|
// 构建请求:body 原样透传
|
||||||
if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 {
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||||
if remaining < antigravityRateLimitThreshold {
|
if err != nil {
|
||||||
select {
|
return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
||||||
case <-ctx.Done():
|
}
|
||||||
return nil, ctx.Err()
|
// 透传客户端所有请求头(排除 hop-by-hop 和认证头)
|
||||||
case <-time.After(remaining):
|
for key, values := range c.Request.Header {
|
||||||
}
|
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
||||||
} else {
|
continue
|
||||||
return nil, &UpstreamFailoverError{
|
}
|
||||||
StatusCode: http.StatusServiceUnavailable,
|
for _, v := range values {
|
||||||
ForceCacheBilling: isStickySession,
|
req.Header.Add(key, v)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 覆盖认证头
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
req.Header.Set("x-api-key", apiKey)
|
||||||
|
|
||||||
// 重试循环
|
if c != nil && len(body) > 0 {
|
||||||
var resp *http.Response
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
var lastErr error
|
|
||||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
req.Header.Set("x-api-key", apiKey)
|
|
||||||
|
|
||||||
// 透传 anthropic headers
|
|
||||||
if v := c.GetHeader("anthropic-version"); v != "" {
|
|
||||||
req.Header.Set("anthropic-version", v)
|
|
||||||
} else {
|
|
||||||
req.Header.Set("anthropic-version", "2023-06-01")
|
|
||||||
}
|
|
||||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
|
||||||
req.Header.Set("anthropic-beta", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
if c != nil && len(body) > 0 {
|
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
if attempt < antigravityMaxRetries {
|
|
||||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
|
||||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", "Upstream request failed after retries")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 429/503 重试
|
|
||||||
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
|
|
||||||
|
|
||||||
if attempt < antigravityMaxRetries {
|
|
||||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
|
||||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, &UpstreamFailoverError{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
ForceCacheBilling: isStickySession,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
break // 成功或非限流错误,跳出重试
|
|
||||||
}
|
}
|
||||||
if resp == nil {
|
|
||||||
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("upstream request failed: %v", lastErr))
|
// 单次发送,不重试
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, s.writeClaudeError(c, http.StatusBadGateway, "upstream_error", fmt.Sprintf("Upstream request failed: %v", err))
|
||||||
}
|
}
|
||||||
defer func() { _ = resp.Body.Close() }()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
@@ -3568,44 +3519,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
|
|
||||||
// signature 重试
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession)
|
||||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
|
||||||
log.Printf("%s upstream signature error, retrying with thinking stripped", prefix)
|
|
||||||
retryClaudeReq := claudeReq
|
|
||||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
|
||||||
if stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq); stripErr == nil && stripped {
|
|
||||||
retryBody, _ := json.Marshal(&retryClaudeReq)
|
|
||||||
retryReq, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(retryBody))
|
|
||||||
if err == nil {
|
|
||||||
retryReq.Header.Set("Content-Type", "application/json")
|
|
||||||
retryReq.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
retryReq.Header.Set("x-api-key", apiKey)
|
|
||||||
retryReq.Header.Set("anthropic-version", "2023-06-01")
|
|
||||||
if v := c.GetHeader("anthropic-beta"); v != "" {
|
|
||||||
retryReq.Header.Set("anthropic-beta", v)
|
|
||||||
}
|
|
||||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if retryErr == nil && retryResp != nil && retryResp.StatusCode < 400 {
|
|
||||||
resp = retryResp
|
|
||||||
goto upstreamClaudeSuccess
|
|
||||||
}
|
|
||||||
if retryResp != nil {
|
|
||||||
_ = retryResp.Body.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// prompt too long
|
|
||||||
if resp.StatusCode == http.StatusBadRequest && isPromptTooLongError(respBody) {
|
|
||||||
return nil, &PromptTooLongError{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
RequestID: resp.Header.Get("x-request-id"),
|
|
||||||
Body: respBody,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
|
|
||||||
|
|
||||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode, ResponseBody: respBody}
|
||||||
@@ -3614,7 +3528,7 @@ func (s *AntigravityGatewayService) ForwardUpstream(ctx context.Context, c *gin.
|
|||||||
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
return nil, s.writeMappedClaudeError(c, account, resp.StatusCode, resp.Header.Get("x-request-id"), respBody)
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamClaudeSuccess:
|
// 成功响应
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
c.Header("x-request-id", requestID)
|
c.Header("x-request-id", requestID)
|
||||||
@@ -3674,7 +3588,6 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c
|
|||||||
if len(body) == 0 {
|
if len(body) == 0 {
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
return nil, s.writeGoogleError(c, http.StatusBadRequest, "Request body is empty")
|
||||||
}
|
}
|
||||||
quotaScope, _ := resolveAntigravityQuotaScope(originalModel)
|
|
||||||
|
|
||||||
imageSize := s.extractImageSize(body)
|
imageSize := s.extractImageSize(body)
|
||||||
|
|
||||||
@@ -3712,143 +3625,52 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION
|
// 构建 upstream URL: base_url + /antigravity/v1beta/models/MODEL:ACTION
|
||||||
upstreamAction := action
|
apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, action)
|
||||||
if action == "generateContent" && !stream {
|
if stream || action == "streamGenerateContent" {
|
||||||
// 非流式也用 streamGenerateContent,与 OAuth 路径行为一致
|
|
||||||
upstreamAction = action
|
|
||||||
}
|
|
||||||
apiURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, mappedModel, upstreamAction)
|
|
||||||
if stream || upstreamAction == "streamGenerateContent" {
|
|
||||||
apiURL += "?alt=sse"
|
apiURL += "?alt=sse"
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, upstreamAction)
|
log.Printf("%s upstream_forward_gemini url=%s model=%s action=%s", prefix, apiURL, mappedModel, action)
|
||||||
|
|
||||||
// 预检查:模型级限流
|
// 构建请求:body 原样透传
|
||||||
if remaining := account.GetRateLimitRemainingTimeWithContext(ctx, originalModel); remaining > 0 {
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
||||||
if remaining < antigravityRateLimitThreshold {
|
if err != nil {
|
||||||
select {
|
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request")
|
||||||
case <-ctx.Done():
|
}
|
||||||
return nil, ctx.Err()
|
// 透传客户端所有请求头(排除 hop-by-hop 和认证头)
|
||||||
case <-time.After(remaining):
|
for key, values := range c.Request.Header {
|
||||||
}
|
if upstreamHopByHopHeaders[strings.ToLower(key)] {
|
||||||
} else {
|
continue
|
||||||
return nil, &UpstreamFailoverError{
|
}
|
||||||
StatusCode: http.StatusServiceUnavailable,
|
for _, v := range values {
|
||||||
ForceCacheBilling: isStickySession,
|
req.Header.Add(key, v)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// 覆盖认证头
|
||||||
|
req.Header.Set("Authorization", "Bearer "+apiKey)
|
||||||
|
|
||||||
// 重试循环
|
if c != nil && len(body) > 0 {
|
||||||
var resp *http.Response
|
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
||||||
var lastErr error
|
|
||||||
for attempt := 1; attempt <= antigravityMaxRetries; attempt++ {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, apiURL, bytes.NewReader(body))
|
|
||||||
if err != nil {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusInternalServerError, "Failed to build request")
|
|
||||||
}
|
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
|
|
||||||
if c != nil && len(body) > 0 {
|
|
||||||
c.Set(OpsUpstreamRequestBodyKey, string(body))
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err = s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err != nil {
|
|
||||||
lastErr = err
|
|
||||||
if attempt < antigravityMaxRetries {
|
|
||||||
log.Printf("%s status=request_failed retry=%d/%d error=%v", prefix, attempt, antigravityMaxRetries, err)
|
|
||||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
|
|
||||||
}
|
|
||||||
|
|
||||||
// 429/503 重试
|
|
||||||
if resp.StatusCode == http.StatusTooManyRequests || resp.StatusCode == http.StatusServiceUnavailable {
|
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
|
|
||||||
|
|
||||||
if attempt < antigravityMaxRetries {
|
|
||||||
log.Printf("%s status=%d retry=%d/%d body=%s", prefix, resp.StatusCode, attempt, antigravityMaxRetries, truncateForLog(respBody, 200))
|
|
||||||
if !sleepAntigravityBackoffWithContext(ctx, attempt) {
|
|
||||||
return nil, ctx.Err()
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, &UpstreamFailoverError{
|
|
||||||
StatusCode: resp.StatusCode,
|
|
||||||
ForceCacheBilling: isStickySession,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
if resp == nil {
|
|
||||||
return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("upstream request failed: %v", lastErr))
|
// 单次发送,不重试
|
||||||
|
resp, err := s.httpUpstream.Do(req, proxyURL, account.ID, account.Concurrency)
|
||||||
|
if err != nil {
|
||||||
|
return nil, s.writeGoogleError(c, http.StatusBadGateway, fmt.Sprintf("Upstream request failed: %v", err))
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() { _ = resp.Body.Close() }()
|
||||||
if resp != nil && resp.Body != nil {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// 错误响应处理
|
// 错误响应处理
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||||
contentType := resp.Header.Get("Content-Type")
|
contentType := resp.Header.Get("Content-Type")
|
||||||
_ = resp.Body.Close()
|
|
||||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
|
||||||
|
|
||||||
// 模型兜底
|
|
||||||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
|
||||||
isModelNotFoundError(resp.StatusCode, respBody) {
|
|
||||||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
|
||||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
|
||||||
log.Printf("[Antigravity-Upstream] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
|
||||||
fallbackURL := fmt.Sprintf("%s/antigravity/v1beta/models/%s:%s", baseURL, fallbackModel, upstreamAction)
|
|
||||||
if stream || upstreamAction == "streamGenerateContent" {
|
|
||||||
fallbackURL += "?alt=sse"
|
|
||||||
}
|
|
||||||
fallbackReq, err := http.NewRequestWithContext(ctx, http.MethodPost, fallbackURL, bytes.NewReader(body))
|
|
||||||
if err == nil {
|
|
||||||
fallbackReq.Header.Set("Content-Type", "application/json")
|
|
||||||
fallbackReq.Header.Set("Authorization", "Bearer "+apiKey)
|
|
||||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
|
||||||
if err == nil && fallbackResp.StatusCode < 400 {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
resp = fallbackResp
|
|
||||||
} else if fallbackResp != nil {
|
|
||||||
_ = fallbackResp.Body.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// fallback 成功
|
|
||||||
if resp.StatusCode < 400 {
|
|
||||||
goto upstreamGeminiSuccess
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
c.Header("x-request-id", requestID)
|
c.Header("x-request-id", requestID)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, quotaScope, 0, "", isStickySession)
|
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody, "", 0, "", isStickySession)
|
||||||
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
upstreamMsg := strings.TrimSpace(extractAntigravityErrorMessage(respBody))
|
||||||
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
upstreamMsg = sanitizeUpstreamErrorMessage(upstreamMsg)
|
||||||
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
upstreamDetail := s.getUpstreamErrorDetail(respBody)
|
||||||
@@ -3886,7 +3708,7 @@ func (s *AntigravityGatewayService) ForwardUpstreamGemini(ctx context.Context, c
|
|||||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
upstreamGeminiSuccess:
|
// 成功响应
|
||||||
requestID := resp.Header.Get("x-request-id")
|
requestID := resp.Header.Get("x-request-id")
|
||||||
if requestID != "" {
|
if requestID != "" {
|
||||||
c.Header("x-request-id", requestID)
|
c.Header("x-request-id", requestID)
|
||||||
|
|||||||
285
backend/internal/service/upstream_header_passthrough_test.go
Normal file
285
backend/internal/service/upstream_header_passthrough_test.go
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
//go:build unit
|
||||||
|
|
||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
// httpUpstreamCapture captures the outgoing *http.Request for assertion.
|
||||||
|
type httpUpstreamCapture struct {
|
||||||
|
capturedReq *http.Request
|
||||||
|
resp *http.Response
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamCapture) Do(req *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||||
|
s.capturedReq = req
|
||||||
|
return s.resp, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *httpUpstreamCapture) DoWithTLS(req *http.Request, _ string, _ int64, _ int, _ bool) (*http.Response, error) {
|
||||||
|
s.capturedReq = req
|
||||||
|
return s.resp, s.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUpstreamAccount() *Account {
|
||||||
|
return &Account{
|
||||||
|
ID: 100,
|
||||||
|
Name: "upstream-test",
|
||||||
|
Platform: PlatformAntigravity,
|
||||||
|
Type: AccountTypeUpstream,
|
||||||
|
Status: StatusActive,
|
||||||
|
Concurrency: 1,
|
||||||
|
Credentials: map[string]any{
|
||||||
|
"base_url": "https://upstream.example.com",
|
||||||
|
"api_key": "sk-upstream-secret",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeSSEOKResponse builds a minimal SSE response that
|
||||||
|
// handleClaudeStreamingResponse / handleGeminiStreamingResponse
|
||||||
|
// can consume without error.
|
||||||
|
// We return 502 to bypass streaming and hit the error branch instead,
|
||||||
|
// which is sufficient for testing header passthrough.
|
||||||
|
func makeUpstreamErrorResponse() *http.Response {
|
||||||
|
body := []byte(`{"error":{"message":"test error"}}`)
|
||||||
|
return &http.Response{
|
||||||
|
StatusCode: http.StatusBadGateway,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: io.NopCloser(bytes.NewReader(body)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ForwardUpstream tests ---
|
||||||
|
|
||||||
|
func TestForwardUpstream_PassthroughHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||||||
|
"max_tokens": 1,
|
||||||
|
"stream": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("anthropic-version", "2024-10-22")
|
||||||
|
req.Header.Set("anthropic-beta", "output-128k-2025-02-19")
|
||||||
|
req.Header.Set("X-Custom-Header", "custom-value")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
||||||
|
|
||||||
|
captured := stub.capturedReq
|
||||||
|
require.NotNil(t, captured, "upstream request should have been made")
|
||||||
|
|
||||||
|
// 客户端 header 应被透传
|
||||||
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||||||
|
require.Equal(t, "2024-10-22", captured.Header.Get("anthropic-version"))
|
||||||
|
require.Equal(t, "output-128k-2025-02-19", captured.Header.Get("anthropic-beta"))
|
||||||
|
require.Equal(t, "custom-value", captured.Header.Get("X-Custom-Header"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardUpstream_OverridesAuthHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||||||
|
"max_tokens": 1,
|
||||||
|
"stream": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
// 客户端发来的认证头应被覆盖
|
||||||
|
req.Header.Set("Authorization", "Bearer client-token")
|
||||||
|
req.Header.Set("x-api-key", "client-api-key")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
||||||
|
|
||||||
|
captured := stub.capturedReq
|
||||||
|
require.NotNil(t, captured)
|
||||||
|
|
||||||
|
// 认证头应使用上游账号的 api_key,而非客户端的
|
||||||
|
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
||||||
|
require.Equal(t, "sk-upstream-secret", captured.Header.Get("x-api-key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardUpstream_ExcludesHopByHopHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"model": "claude-sonnet-4-5",
|
||||||
|
"messages": []map[string]any{{"role": "user", "content": "hi"}},
|
||||||
|
"max_tokens": 1,
|
||||||
|
"stream": false,
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
req.Header.Set("Keep-Alive", "timeout=5")
|
||||||
|
req.Header.Set("Transfer-Encoding", "chunked")
|
||||||
|
req.Header.Set("Upgrade", "websocket")
|
||||||
|
req.Header.Set("Te", "trailers")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = svc.ForwardUpstream(context.Background(), c, newUpstreamAccount(), body, false)
|
||||||
|
|
||||||
|
captured := stub.capturedReq
|
||||||
|
require.NotNil(t, captured)
|
||||||
|
|
||||||
|
// hop-by-hop header 不应出现
|
||||||
|
require.Empty(t, captured.Header.Get("Connection"))
|
||||||
|
require.Empty(t, captured.Header.Get("Keep-Alive"))
|
||||||
|
require.Empty(t, captured.Header.Get("Transfer-Encoding"))
|
||||||
|
require.Empty(t, captured.Header.Get("Upgrade"))
|
||||||
|
require.Empty(t, captured.Header.Get("Te"))
|
||||||
|
|
||||||
|
// 但普通 header 应保留
|
||||||
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- ForwardUpstreamGemini tests ---
|
||||||
|
|
||||||
|
func TestForwardUpstreamGemini_PassthroughHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("X-Custom-Gemini", "gemini-value")
|
||||||
|
req.Header.Set("X-Request-Id", "req-abc-123")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
||||||
|
|
||||||
|
captured := stub.capturedReq
|
||||||
|
require.NotNil(t, captured, "upstream request should have been made")
|
||||||
|
|
||||||
|
// 客户端 header 应被透传
|
||||||
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||||||
|
require.Equal(t, "gemini-value", captured.Header.Get("X-Custom-Gemini"))
|
||||||
|
require.Equal(t, "req-abc-123", captured.Header.Get("X-Request-Id"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardUpstreamGemini_OverridesAuthHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", "Bearer client-gemini-token")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
||||||
|
|
||||||
|
captured := stub.capturedReq
|
||||||
|
require.NotNil(t, captured)
|
||||||
|
|
||||||
|
// 认证头应使用上游账号的 api_key
|
||||||
|
require.Equal(t, "Bearer sk-upstream-secret", captured.Header.Get("Authorization"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestForwardUpstreamGemini_ExcludesHopByHopHeaders(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body, _ := json.Marshal(map[string]any{
|
||||||
|
"contents": []map[string]any{
|
||||||
|
{"role": "user", "parts": []map[string]any{{"text": "hi"}}},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/v1beta/models/gemini-2.5-flash:generateContent", bytes.NewReader(body))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Connection", "keep-alive")
|
||||||
|
req.Header.Set("Proxy-Authorization", "Basic dXNlcjpwYXNz")
|
||||||
|
req.Header.Set("Host", "evil.example.com")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
stub := &httpUpstreamCapture{resp: makeUpstreamErrorResponse()}
|
||||||
|
svc := &AntigravityGatewayService{
|
||||||
|
tokenProvider: &AntigravityTokenProvider{},
|
||||||
|
httpUpstream: stub,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = svc.ForwardUpstreamGemini(context.Background(), c, newUpstreamAccount(), "gemini-2.5-flash", "generateContent", false, body, false)
|
||||||
|
|
||||||
|
captured := stub.capturedReq
|
||||||
|
require.NotNil(t, captured)
|
||||||
|
|
||||||
|
// hop-by-hop header 不应出现
|
||||||
|
require.Empty(t, captured.Header.Get("Connection"))
|
||||||
|
require.Empty(t, captured.Header.Get("Proxy-Authorization"))
|
||||||
|
// Host header 在 Go http.Request 中特殊处理,但我们的黑名单应阻止透传
|
||||||
|
require.Empty(t, captured.Header.Values("Host"))
|
||||||
|
|
||||||
|
// 普通 header 应保留
|
||||||
|
require.Equal(t, "application/json", captured.Header.Get("Content-Type"))
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user