feat(backend): 提交后端审计修复与配套测试改动
This commit is contained in:
@@ -3332,7 +3332,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 && s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
logger.LegacyPrintf("service.gateway", "[DEBUG] %s: %v", k, v)
|
||||
@@ -4467,8 +4467,19 @@ func (s *GatewayService) handleNonStreamingResponse(ctx context.Context, resp *h
|
||||
// 更新5h窗口状态
|
||||
s.rateLimitService.UpdateSessionWindow(ctx, account, resp.Header)
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"type": "error",
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -4990,9 +5001,15 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxReadBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
@@ -5007,9 +5024,14 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
retryResp, retryErr := s.httpUpstream.DoWithTLS(retryReq, proxyURL, account.ID, account.Concurrency, account.IsTLSFingerprintEnabled())
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
respBody, err = readUpstreamResponseBodyLimited(resp.Body, maxReadBytes)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Upstream response too large")
|
||||
return err
|
||||
}
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -2358,29 +2358,36 @@ type UpstreamHTTPResult struct {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========================================")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
respBody, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var parsed map[string]any
|
||||
if isOAuth {
|
||||
unwrappedBody, uwErr := unwrapGeminiResponse(respBody)
|
||||
if uwErr == nil {
|
||||
respBody = unwrappedBody
|
||||
}
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
} else {
|
||||
_ = json.Unmarshal(respBody, &parsed)
|
||||
}
|
||||
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
@@ -2398,14 +2405,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
if s.cfg != nil && s.cfg.Gateway.GeminiDebugResponseHeaders {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
}
|
||||
logger.LegacyPrintf("service.gemini_messages_compat", "[GeminiAPI] ====================================================")
|
||||
|
||||
if s.cfg != nil {
|
||||
responseheaders.WriteFilteredHeaders(c.Writer.Header(), resp.Header, s.cfg.Security.ResponseHeaders)
|
||||
|
||||
@@ -3,10 +3,15 @@ package service
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -133,6 +138,38 @@ func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGeminiHandleNativeNonStreamingResponse_DebugDisabledDoesNotEmitHeaderLogs(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
logSink, restore := captureStructuredLog(t)
|
||||
defer restore()
|
||||
|
||||
svc := &GeminiMessagesCompatService{
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
GeminiDebugResponseHeaders: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{
|
||||
"Content-Type": []string{"application/json"},
|
||||
"X-RateLimit-Limit": []string{"60"},
|
||||
},
|
||||
Body: io.NopCloser(strings.NewReader(`{"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":2}}`)),
|
||||
}
|
||||
|
||||
usage, err := svc.handleNativeNonStreamingResponse(c, resp, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, usage)
|
||||
require.False(t, logSink.ContainsMessage("[GeminiAPI]"), "debug 关闭时不应输出 Gemini 响应头日志")
|
||||
}
|
||||
|
||||
func TestConvertClaudeMessagesToGeminiGenerateContent_AddsThoughtSignatureForToolUse(t *testing.T) {
|
||||
claudeReq := map[string]any{
|
||||
"model": "claude-haiku-4-5-20251001",
|
||||
|
||||
@@ -1741,8 +1741,18 @@ func (s *OpenAIGatewayService) handleNonStreamingResponsePassthrough(
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2371,8 +2381,18 @@ func (s *OpenAIGatewayService) parseSSEUsage(data string, usage *OpenAIUsage) {
|
||||
}
|
||||
|
||||
func (s *OpenAIGatewayService) handleNonStreamingResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account, originalModel, mappedModel string) (*OpenAIUsage, error) {
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
maxBytes := resolveUpstreamResponseReadLimit(s.cfg)
|
||||
body, err := readUpstreamResponseBodyLimited(resp.Body, maxBytes)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUpstreamResponseBodyTooLarge) {
|
||||
setOpsUpstreamError(c, http.StatusBadGateway, "upstream response too large", "")
|
||||
c.JSON(http.StatusBadGateway, gin.H{
|
||||
"error": gin.H{
|
||||
"type": "upstream_error",
|
||||
"message": "Upstream response too large",
|
||||
},
|
||||
})
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -2930,6 +2950,25 @@ func normalizeOpenAIPassthroughOAuthBody(body []byte) ([]byte, bool, error) {
|
||||
return normalized, changed, nil
|
||||
}
|
||||
|
||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||
if !strings.Contains(model, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() {
|
||||
return "instructions_missing"
|
||||
}
|
||||
if instructions.Type != gjson.String {
|
||||
return "instructions_not_string"
|
||||
}
|
||||
if strings.TrimSpace(instructions.String()) == "" {
|
||||
return "instructions_empty"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func extractOpenAIReasoningEffortFromBody(body []byte, requestedModel string) *string {
|
||||
reasoningEffort := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
|
||||
if reasoningEffort == "" {
|
||||
@@ -3002,22 +3041,3 @@ func normalizeOpenAIReasoningEffort(raw string) string {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
func detectOpenAIPassthroughInstructionsRejectReason(reqModel string, body []byte) string {
|
||||
model := strings.ToLower(strings.TrimSpace(reqModel))
|
||||
if !strings.Contains(model, "codex") {
|
||||
return ""
|
||||
}
|
||||
|
||||
instructions := gjson.GetBytes(body, "instructions")
|
||||
if !instructions.Exists() {
|
||||
return "instructions_missing"
|
||||
}
|
||||
if instructions.Type != gjson.String {
|
||||
return "instructions_not_string"
|
||||
}
|
||||
if strings.TrimSpace(instructions.String()) == "" {
|
||||
return "instructions_empty"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
|
||||
38
backend/internal/service/upstream_response_limit.go
Normal file
38
backend/internal/service/upstream_response_limit.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
var ErrUpstreamResponseBodyTooLarge = errors.New("upstream response body too large")
|
||||
|
||||
const defaultUpstreamResponseReadMaxBytes int64 = 8 * 1024 * 1024
|
||||
|
||||
func resolveUpstreamResponseReadLimit(cfg *config.Config) int64 {
|
||||
if cfg != nil && cfg.Gateway.UpstreamResponseReadMaxBytes > 0 {
|
||||
return cfg.Gateway.UpstreamResponseReadMaxBytes
|
||||
}
|
||||
return defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
func readUpstreamResponseBodyLimited(reader io.Reader, maxBytes int64) ([]byte, error) {
|
||||
if reader == nil {
|
||||
return nil, errors.New("response body is nil")
|
||||
}
|
||||
if maxBytes <= 0 {
|
||||
maxBytes = defaultUpstreamResponseReadMaxBytes
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(io.LimitReader(reader, maxBytes+1))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(len(body)) > maxBytes {
|
||||
return nil, fmt.Errorf("%w: limit=%d", ErrUpstreamResponseBodyTooLarge, maxBytes)
|
||||
}
|
||||
return body, nil
|
||||
}
|
||||
37
backend/internal/service/upstream_response_limit_test.go
Normal file
37
backend/internal/service/upstream_response_limit_test.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestResolveUpstreamResponseReadLimit(t *testing.T) {
|
||||
t.Run("use default when config missing", func(t *testing.T) {
|
||||
require.Equal(t, defaultUpstreamResponseReadMaxBytes, resolveUpstreamResponseReadLimit(nil))
|
||||
})
|
||||
|
||||
t.Run("use configured value", func(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Gateway.UpstreamResponseReadMaxBytes = 1234
|
||||
require.Equal(t, int64(1234), resolveUpstreamResponseReadLimit(cfg))
|
||||
})
|
||||
}
|
||||
|
||||
func TestReadUpstreamResponseBodyLimited(t *testing.T) {
|
||||
t.Run("within limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("ok")), 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []byte("ok"), body)
|
||||
})
|
||||
|
||||
t.Run("exceeds limit", func(t *testing.T) {
|
||||
body, err := readUpstreamResponseBodyLimited(bytes.NewReader([]byte("toolong")), 3)
|
||||
require.Nil(t, body)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, ErrUpstreamResponseBodyTooLarge))
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user