feat(sync): full code sync from release

This commit is contained in:
yangjianbo
2026-02-28 15:01:20 +08:00
parent bfc7b339f7
commit bb664d9bbf
338 changed files with 54513 additions and 2011 deletions

View File

@@ -6,9 +6,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -17,6 +18,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/claude"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
pkgerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/openai"
@@ -27,6 +29,10 @@ import (
"go.uber.org/zap"
)
const gatewayCompatibilityMetricsLogInterval = 1024
var gatewayCompatibilityMetricsLogCounter atomic.Uint64
// GatewayHandler handles API gateway requests
type GatewayHandler struct {
gatewayService *service.GatewayService
@@ -109,9 +115,10 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -140,16 +147,16 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 设置 max_tokens=1 + haiku 探测请求标识到 context 中
// 必须在 SetClaudeCodeClientContext 之前设置,因为 ClaudeCodeValidator 需要读取此标识进行绕过判断
if isMaxTokensOneHaikuRequest(reqModel, parsedReq.MaxTokens, reqStream) {
ctx := context.WithValue(c.Request.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
ctx := service.WithIsMaxTokensOneHaikuRequest(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
// 检查是否为 Claude Code 客户端,设置到 context 中(复用已解析请求,避免二次反序列化)。
SetClaudeCodeClientContext(c, body, parsedReq)
isClaudeCodeClient := service.IsClaudeCodeClient(c.Request.Context())
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
setOpsRequestContext(c, reqModel, reqStream, body)
@@ -247,8 +254,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
if apiKey.GroupID != nil {
prefetchedGroupID = *apiKey.GroupID
}
ctx := context.WithValue(c.Request.Context(), ctxkey.PrefetchedStickyAccountID, sessionBoundAccountID)
ctx = context.WithValue(ctx, ctxkey.PrefetchedStickyGroupID, prefetchedGroupID)
ctx := service.WithPrefetchedStickySession(c.Request.Context(), sessionBoundAccountID, prefetchedGroupID, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
}
@@ -261,7 +267,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), apiKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -275,7 +281,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -364,7 +370,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity {
result, err = h.antigravityGatewayService.ForwardGemini(requestCtx, c, account, reqModel, "generateContent", reqStream, body, hasBoundSession)
@@ -439,7 +445,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
// 单账号分组提前设置 SingleAccountRetry 标记,让 Service 层首次 503 就不设模型限流标记。
// 避免单账号分组收到 503 (MODEL_CAPACITY_EXHAUSTED) 时设 29s 限流,导致后续请求连续快速失败。
if h.gatewayService.IsSingleAntigravityAccountGroup(c.Request.Context(), currentAPIKey.GroupID) {
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
}
@@ -458,7 +464,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
action := fs.HandleSelectionExhausted(c.Request.Context())
switch action {
case FailoverContinue:
ctx := context.WithValue(c.Request.Context(), ctxkey.SingleAccountRetry, true)
ctx := service.WithSingleAccountRetry(c.Request.Context(), true, h.metadataBridgeEnabled())
c.Request = c.Request.WithContext(ctx)
continue
case FailoverCanceled:
@@ -547,7 +553,7 @@ func (h *GatewayHandler) Messages(c *gin.Context) {
var result *service.ForwardResult
requestCtx := c.Request.Context()
if fs.SwitchCount > 0 {
requestCtx = context.WithValue(requestCtx, ctxkey.AccountSwitchCount, fs.SwitchCount)
requestCtx = service.WithAccountSwitchCount(requestCtx, fs.SwitchCount, h.metadataBridgeEnabled())
}
if account.Platform == service.PlatformAntigravity && account.Type != service.AccountTypeAPIKey {
result, err = h.antigravityGatewayService.Forward(requestCtx, c, account, body, hasBoundSession)
@@ -956,20 +962,8 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
// Stream already started, send error as SSE event then close
flusher, ok := c.Writer.(http.Flusher)
if ok {
// Send error event in SSE format with proper JSON marshaling
errorData := map[string]any{
"type": "error",
"error": map[string]string{
"type": errType,
"message": message,
},
}
jsonBytes, err := json.Marshal(errorData)
if err != nil {
_ = c.Error(err)
return
}
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
// SSE 错误事件固定 schema使用 Quote 直拼可避免额外 Marshal 分配。
errorEvent := `data: {"type":"error","error":{"type":` + strconv.Quote(errType) + `,"message":` + strconv.Quote(message) + `}}` + "\n\n"
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
_ = c.Error(err)
}
@@ -1024,9 +1018,10 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
zap.Int64("api_key_id", apiKey.ID),
zap.Any("group_id", apiKey.GroupID),
)
defer h.maybeLogCompatibilityFallbackMetrics(reqLog)
// 读取请求体
body, err := io.ReadAll(c.Request.Body)
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
if err != nil {
if maxErr, ok := extractMaxBytesError(err); ok {
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
@@ -1041,9 +1036,6 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
return
}
// 检查是否为 Claude Code 客户端,设置到 context 中
SetClaudeCodeClientContext(c, body)
setOpsRequestContext(c, "", false, body)
parsedReq, err := service.ParseGatewayRequest(body, domain.PlatformAnthropic)
@@ -1051,9 +1043,11 @@ func (h *GatewayHandler) CountTokens(c *gin.Context) {
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
return
}
// count_tokens 走 messages 严格校验时,复用已解析请求,避免二次反序列化。
SetClaudeCodeClientContext(c, body, parsedReq)
reqLog = reqLog.With(zap.String("model", parsedReq.Model), zap.Bool("stream", parsedReq.Stream))
// 在请求上下文中记录 thinking 状态,供 Antigravity 最终模型 key 推导/模型维度限流使用
c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), ctxkey.ThinkingEnabled, parsedReq.ThinkingEnabled))
c.Request = c.Request.WithContext(service.WithThinkingEnabled(c.Request.Context(), parsedReq.ThinkingEnabled, h.metadataBridgeEnabled()))
// 验证 model 必填
if parsedReq.Model == "" {
@@ -1217,24 +1211,8 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
textDeltas = []string{"New", " Conversation"}
}
// Build message_start event with proper JSON marshaling
messageStart := map[string]any{
"type": "message_start",
"message": map[string]any{
"id": msgID,
"type": "message",
"role": "assistant",
"model": model,
"content": []any{},
"stop_reason": nil,
"stop_sequence": nil,
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": 0,
},
},
}
messageStartJSON, _ := json.Marshal(messageStart)
// Build message_start event with fixed schema.
messageStartJSON := `{"type":"message_start","message":{"id":` + strconv.Quote(msgID) + `,"type":"message","role":"assistant","model":` + strconv.Quote(model) + `,"content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":0}}}`
// Build events
events := []string{
@@ -1244,31 +1222,12 @@ func sendMockInterceptStream(c *gin.Context, model string, interceptType Interce
// Add text deltas
for _, text := range textDeltas {
delta := map[string]any{
"type": "content_block_delta",
"index": 0,
"delta": map[string]string{
"type": "text_delta",
"text": text,
},
}
deltaJSON, _ := json.Marshal(delta)
deltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":` + strconv.Quote(text) + `}}`
events = append(events, `event: content_block_delta`+"\n"+`data: `+string(deltaJSON))
}
// Add final events
messageDelta := map[string]any{
"type": "message_delta",
"delta": map[string]any{
"stop_reason": "end_turn",
"stop_sequence": nil,
},
"usage": map[string]int{
"input_tokens": 10,
"output_tokens": outputTokens,
},
}
messageDeltaJSON, _ := json.Marshal(messageDelta)
messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":10,"output_tokens":` + strconv.Itoa(outputTokens) + `}}`
events = append(events,
`event: content_block_stop`+"\n"+`data: {"index":0,"type":"content_block_stop"}`,
@@ -1366,6 +1325,30 @@ func billingErrorDetails(err error) (status int, code, message string) {
return http.StatusForbidden, "billing_error", msg
}
func (h *GatewayHandler) metadataBridgeEnabled() bool {
if h == nil || h.cfg == nil {
return true
}
return h.cfg.Gateway.OpenAIWS.MetadataBridgeEnabled
}
func (h *GatewayHandler) maybeLogCompatibilityFallbackMetrics(reqLog *zap.Logger) {
if reqLog == nil {
return
}
if gatewayCompatibilityMetricsLogCounter.Add(1)%gatewayCompatibilityMetricsLogInterval != 0 {
return
}
metrics := service.SnapshotOpenAICompatibilityFallbackMetrics()
reqLog.Info("gateway.compatibility_fallback_metrics",
zap.Int64("session_hash_legacy_read_fallback_total", metrics.SessionHashLegacyReadFallbackTotal),
zap.Int64("session_hash_legacy_read_fallback_hit", metrics.SessionHashLegacyReadFallbackHit),
zap.Int64("session_hash_legacy_dual_write_total", metrics.SessionHashLegacyDualWriteTotal),
zap.Float64("session_hash_legacy_read_hit_rate", metrics.SessionHashLegacyReadHitRate),
zap.Int64("metadata_legacy_fallback_total", metrics.MetadataLegacyFallbackTotal),
)
}
func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
if task == nil {
return
@@ -1377,5 +1360,13 @@ func (h *GatewayHandler) submitUsageRecordTask(task service.UsageRecordTask) {
// 回退路径worker 池未注入时同步执行,避免退回到无界 goroutine 模式。
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
defer func() {
if recovered := recover(); recovered != nil {
logger.L().With(
zap.String("component", "handler.gateway.messages"),
zap.Any("panic", recovered),
).Error("gateway.usage_record_task_panic_recovered")
}
}()
task(ctx)
}