feat(backend): 提交后端审计修复与配套测试改动

This commit is contained in:
yangjianbo
2026-02-14 11:23:10 +08:00
parent 862199143e
commit d04b47b3ca
22 changed files with 653 additions and 55 deletions

View File

@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header {
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
// 更新5h窗口状态
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
body, err := io.ReadAll(resp.Body)
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"type": "error",
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err
}
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
_ = resp.Body.Close()
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
return err
}
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
_ = resp.Body.Close()
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
return err
}
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}

View File

@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
}
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
}
}
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
}
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body)
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err
}
var parsed map[string]any
if isOAuth {
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
if uwErr == nil {
respBody = unwrappedBody
}
_ = json.Unmarshal(respBody, &parsed)
} else {
_ = json.Unmarshal(respBody, &parsed)
}
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
}
}
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
}
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
if s.cfg != nil {
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)

View File

@@ -3,10 +3,15 @@ package service
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
}
}
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
gin.SetMode(gin.TestMode)
logSink, restore := captureStructuredLog(t)
defer restore()
svc := &GeminiMessagesCompatService{
cfg: &config.Config{
Gateway: config.GatewayConfig{
GeminiDebugResponseHeaders: false,
},
},
}
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
resp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{
"Content-Type": []string{"application/json"},
"X-RateLimit-Limit": []string{"60"},
},
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
}
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
require.NoError(t, err)
require.NotNil(t, usage)
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
}
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
claudeReq := map[string]any{
"model": "claude-haiku-4-5-20251001",

View File

@@ -1741,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
resp *http.Response,
c *gin.Context,
) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err
}
@@ -2371,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
}
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
body, err := io.ReadAll(resp.Body)
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
if err != nil {
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
c.JSON(http.StatusBadGateway, gin.H{
"error": gin.H{
"type": "upstream_error",
"message": "Upstream response too large",
},
})
}
return nil, err
}
@@ -2930,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
return normalized, changed, nil
}
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
model := strings.ToLower(strings.TrimSpace(reqModel))
if !strings.Contains(model, "codex") {
return ""
}
instructions := gjson.GetBytes(body, "instructions")
if !instructions.Exists() {
return "instructions_missing"
}
if instructions.Type != gjson.String {
return "instructions_not_string"
}
if strings.TrimSpace(instructions.String()) == "" {
return "instructions_empty"
}
return ""
}
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
if reasoningEffort == "" {
@@ -3002,22 +3041,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
return ""
}
}
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
model := strings.ToLower(strings.TrimSpace(reqModel))
if !strings.Contains(model, "codex") {
return ""
}
instructions := gjson.GetBytes(body, "instructions")
if !instructions.Exists() {
return "instructions_missing"
}
if instructions.Type != gjson.String {
return "instructions_not_string"
}
if strings.TrimSpace(instructions.String()) == "" {
return "instructions_empty"
}
return ""
}

View File

@@ -0,0 +1,38 @@
package service
import (
"errors"
"fmt"
"io"
"github.com/Wei-Shaw/sub2api/internal/config"
)
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
return cfg.Gateway.UpstreamResponseReadMaxBytes
}
return defaultUpstreamResponseReadMaxBytes
}
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
if reader == nil {
return nil, errors.New("response body is nil")
}
if maxBytes <= 0 {
maxBytes = defaultUpstreamResponseReadMaxBytes
}
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
if err != nil {
return nil, err
}
if int64(len(body)) > maxBytes {
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
}
return body, nil
}

View File

@@ -0,0 +1,37 @@
package service
import (
"bytes"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
t.Run("use default when config missing", func(t *testing.T) {
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
})
t.Run("use configured value", func(t *testing.T) {
cfg := &config.Config{}
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
})
}
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
t.Run("within limit", func(t *testing.T) {
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
require.NoError(t, err)
require.Equal(t, []byte("ok"), body)
})
t.Run("exceeds limit", func(t *testing.T) {
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
require.Nil(t, body)
require.Error(t, err)
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
})
}