perf(后端): 完成性能优化与连接池配置

新增 DB/Redis 连接池配置与校验,并补充单测

网关请求体大小限制与 413 处理

HTTP/req 客户端池化并调整上游连接池默认值

并发槽位改为 ZSET+Lua 与指数退避

用量统计改 SQL 聚合并新增索引迁移

计费缓存写入改工作池并补测试/基准

测试: 在 backend/ 下运行 go test ./...
This commit is contained in:
yangjianbo
2025-12-31 08:50:12 +08:00
parent 5376786694
commit 7efa8b54c4
53 changed files with 1805 additions and 449 deletions

View File

@@ -19,7 +19,6 @@ import (
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
@@ -33,7 +32,10 @@ const (
// sseDataRe matches SSE data lines with optional whitespace after colon.
// Some upstream APIs return non-standard "data:" without space (should be "data: ").
var sseDataRe = regexp.MustCompile(`^data:\s*`)
var (
sseDataRe = regexp.MustCompile(`^data:\s*`)
sessionIDRegex = regexp.MustCompile(`session_([a-f0-9-]{36})`)
)
// allowedHeaders 白名单headers参考CRS项目
var allowedHeaders = map[string]bool{
@@ -141,40 +143,36 @@ func NewGatewayService(
}
}
// GenerateSessionHash 从请求计算粘性会话hash
func (s *GatewayService) GenerateSessionHash(body []byte) string {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
// GenerateSessionHash 从预解析请求计算粘性会话 hash
func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
// 1. 最高优先级从metadata.user_id提取session_xxx
if metadata, ok := req["metadata"].(map[string]any); ok {
if userID, ok := metadata["user_id"].(string); ok {
re := regexp.MustCompile(`session_([a-f0-9-]{36})`)
if match := re.FindStringSubmatch(userID); len(match) > 1 {
return match[1]
}
// 1. 最高优先级:从 metadata.user_id 提取 session_xxx
if parsed.MetadataUserID != "" {
if match := sessionIDRegex.FindStringSubmatch(parsed.MetadataUserID); len(match) > 1 {
return match[1]
}
}
// 2. 提取带cache_control: {type: "ephemeral"}的内容
cacheableContent := s.extractCacheableContent(req)
// 2. 提取带 cache_control: {type: "ephemeral"} 的内容
cacheableContent := s.extractCacheableContent(parsed)
if cacheableContent != "" {
return s.hashContent(cacheableContent)
}
// 3. Fallback: 使用system内容
if system := req["system"]; system != nil {
systemText := s.extractTextFromSystem(system)
// 3. Fallback: 使用 system 内容
if parsed.System != nil {
systemText := s.extractTextFromSystem(parsed.System)
if systemText != "" {
return s.hashContent(systemText)
}
}
// 4. 最后fallback: 使用第一条消息
if messages, ok := req["messages"].([]any); ok && len(messages) > 0 {
if firstMsg, ok := messages[0].(map[string]any); ok {
// 4. 最后 fallback: 使用第一条消息
if len(parsed.Messages) > 0 {
if firstMsg, ok := parsed.Messages[0].(map[string]any); ok {
msgText := s.extractTextFromContent(firstMsg["content"])
if msgText != "" {
return s.hashContent(msgText)
@@ -185,36 +183,38 @@ func (s *GatewayService) GenerateSessionHash(body []byte) string {
return ""
}
func (s *GatewayService) extractCacheableContent(req map[string]any) string {
var content string
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
// 检查system中的cacheable内容
if system, ok := req["system"].([]any); ok {
var builder strings.Builder
// 检查 system 中的 cacheable 内容
if system, ok := parsed.System.([]any); ok {
for _, part := range system {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
if text, ok := partMap["text"].(string); ok {
content += text
builder.WriteString(text)
}
}
}
}
}
}
systemText := builder.String()
// 检查messages中的cacheable内容
if messages, ok := req["messages"].([]any); ok {
for _, msg := range messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
// 找到cacheable内容提取第一条消息的文本
return s.extractTextFromContent(msgMap["content"])
}
// 检查 messages 中的 cacheable 内容
for _, msg := range parsed.Messages {
if msgMap, ok := msg.(map[string]any); ok {
if msgContent, ok := msgMap["content"].([]any); ok {
for _, part := range msgContent {
if partMap, ok := part.(map[string]any); ok {
if cc, ok := partMap["cache_control"].(map[string]any); ok {
if cc["type"] == "ephemeral" {
return s.extractTextFromContent(msgMap["content"])
}
}
}
@@ -223,7 +223,7 @@ func (s *GatewayService) extractCacheableContent(req map[string]any) string {
}
}
return content
return systemText
}
func (s *GatewayService) extractTextFromSystem(system any) string {
@@ -588,19 +588,17 @@ func (s *GatewayService) shouldFailoverUpstreamError(statusCode int) bool {
}
// Forward 转发请求到Claude API
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) (*ForwardResult, error) {
startTime := time.Now()
// 解析请求获取model和stream
var req struct {
Model string `json:"model"`
Stream bool `json:"stream"`
}
if err := json.Unmarshal(body, &req); err != nil {
return nil, fmt.Errorf("parse request: %w", err)
if parsed == nil {
return nil, fmt.Errorf("parse request: empty request")
}
if !gjson.GetBytes(body, "system").Exists() {
body := parsed.Body
reqModel := parsed.Model
reqStream := parsed.Stream
if !parsed.HasSystem {
body, _ = sjson.SetBytes(body, "system", []any{
map[string]any{
"type": "text",
@@ -613,13 +611,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 应用模型映射仅对apikey类型账号
originalModel := req.Model
originalModel := reqModel
if account.Type == AccountTypeApiKey {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
body = s.replaceModelInBody(body, mappedModel)
req.Model = mappedModel
reqModel = mappedModel
log.Printf("Model mapping applied: %s -> %s (account: %s)", originalModel, mappedModel, account.Name)
}
}
@@ -640,7 +638,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
var resp *http.Response
for attempt := 1; attempt <= maxRetries; attempt++ {
// 构建上游请求(每次重试需要重新构建,因为请求体需要重新读取)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType)
upstreamReq, err := s.buildUpstreamRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
return nil, err
}
@@ -692,8 +690,8 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理正常响应
var usage *ClaudeUsage
var firstTokenMs *int
if req.Stream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, req.Model)
if reqStream {
streamResult, err := s.handleStreamingResponse(ctx, resp, c, account, startTime, originalModel, reqModel)
if err != nil {
if err.Error() == "have error in stream" {
return nil, &UpstreamFailoverError{
@@ -705,7 +703,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
usage = streamResult.usage
firstTokenMs = streamResult.firstTokenMs
} else {
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, req.Model)
usage, err = s.handleNonStreamingResponse(ctx, resp, c, account, originalModel, reqModel)
if err != nil {
return nil, err
}
@@ -715,13 +713,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
RequestID: resp.Header.Get("x-request-id"),
Usage: *usage,
Model: originalModel, // 使用原始模型用于计费和日志
Stream: req.Stream,
Stream: reqStream,
Duration: time.Since(startTime),
FirstTokenMs: firstTokenMs,
}, nil
}
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
@@ -787,7 +785,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
}
return req, nil
@@ -795,7 +793,7 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// getBetaHeader 处理anthropic-beta header
// 对于OAuth账号需要确保包含oauth-2025-04-20
func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) string {
func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string) string {
// 如果客户端传了anthropic-beta
if clientBetaHeader != "" {
// 已包含oauth beta则直接返回
@@ -832,15 +830,7 @@ func (s *GatewayService) getBetaHeader(body []byte, clientBetaHeader string) str
}
// 客户端没传,根据模型生成
var modelID string
var reqMap map[string]any
if json.Unmarshal(body, &reqMap) == nil {
if m, ok := reqMap["model"].(string); ok {
modelID = m
}
}
// haiku模型不需要claude-code beta
// haiku 模型不需要 claude-code beta
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.HaikuBetaHeader
}
@@ -1248,13 +1238,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Increment subscription usage failed: %v", err)
}
// 异步更新订阅缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.UpdateSubscriptionUsage(cacheCtx, user.ID, *apiKey.GroupID, cost.TotalCost); err != nil {
log.Printf("Update subscription cache failed: %v", err)
}
}()
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
@@ -1263,13 +1247,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
log.Printf("Deduct balance failed: %v", err)
}
// 异步更新余额缓存
go func() {
cacheCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.billingCacheService.DeductBalanceCache(cacheCtx, user.ID, cost.ActualCost); err != nil {
log.Printf("Update balance cache failed: %v", err)
}
}()
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}
}
@@ -1281,7 +1259,15 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
// ForwardCountTokens 转发 count_tokens 请求到上游 API
// 特点:不记录使用量、仅支持非流式响应
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, body []byte) error {
func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context, account *Account, parsed *ParsedRequest) error {
if parsed == nil {
s.countTokensError(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
return fmt.Errorf("parse request: empty request")
}
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算值
// 参考 Antigravity-Manager 和 proxycast 实现
if account.Platform == PlatformAntigravity {
@@ -1291,14 +1277,12 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
var req struct {
Model string `json:"model"`
}
if err := json.Unmarshal(body, &req); err == nil && req.Model != "" {
mappedModel := account.GetMappedModel(req.Model)
if mappedModel != req.Model {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
body = s.replaceModelInBody(body, mappedModel)
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", req.Model, mappedModel, account.Name)
reqModel = mappedModel
log.Printf("CountTokens model mapping applied: %s -> %s (account: %s)", parsed.Model, mappedModel, account.Name)
}
}
}
@@ -1311,7 +1295,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 构建上游请求
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType)
upstreamReq, err := s.buildCountTokensRequest(ctx, c, account, body, token, tokenType, reqModel)
if err != nil {
s.countTokensError(c, http.StatusInternalServerError, "api_error", "Failed to build request")
return err
@@ -1363,7 +1347,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// buildCountTokensRequest 构建 count_tokens 上游请求
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType string) (*http.Request, error) {
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
@@ -1424,7 +1408,7 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(body, c.GetHeader("anthropic-beta")))
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
}
return req, nil