diff --git a/backend/internal/handler/gateway_handler.go b/backend/internal/handler/gateway_handler.go index 7eb7007e..bbc9c181 100644 --- a/backend/internal/handler/gateway_handler.go +++ b/backend/internal/handler/gateway_handler.go @@ -586,8 +586,20 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e // Stream already started, send error as SSE event then close flusher, ok := c.Writer.(http.Flusher) if ok { - // Send error event in SSE format - errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message) + // Send error event in SSE format with proper JSON marshaling + errorData := map[string]any{ + "type": "error", + "error": map[string]string{ + "type": errType, + "message": message, + }, + } + jsonBytes, err := json.Marshal(errorData) + if err != nil { + _ = c.Error(err) + return + } + errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes)) if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil { _ = c.Error(err) } @@ -737,8 +749,27 @@ func sendMockWarmupStream(c *gin.Context, model string) { c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") + // Build message_start event with proper JSON marshaling + messageStart := map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": "msg_mock_warmup", + "type": "message", + "role": "assistant", + "model": model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]int{ + "input_tokens": 10, + "output_tokens": 0, + }, + }, + } + messageStartJSON, _ := json.Marshal(messageStart) + events := []string{ - `event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`, + `event: message_start` + "\n" + `data: ` + string(messageStartJSON), `event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`, `event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`, `event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`, diff --git a/backend/internal/handler/gateway_helper.go b/backend/internal/handler/gateway_helper.go index 4e049dbb..9d2e4a9d 100644 --- a/backend/internal/handler/gateway_helper.go +++ b/backend/internal/handler/gateway_helper.go @@ -144,6 +144,21 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType ctx, cancel := context.WithTimeout(c.Request.Context(), timeout) defer cancel() + // Try immediate acquire first (avoid unnecessary wait) + var result *service.AcquireResult + var err error + if slotType == "user" { + result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) + } else { + result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) + } + if err != nil { + return nil, err + } + if result.Acquired { + return result.ReleaseFunc, nil + } + // Determine if ping is needed (streaming + ping format defined) needPing := isStream && h.pingFormat != "" diff --git a/backend/internal/pkg/antigravity/request_transformer_test.go b/backend/internal/pkg/antigravity/request_transformer_test.go index ba07893f..56eebad0 100644 --- a/backend/internal/pkg/antigravity/request_transformer_test.go +++ b/backend/internal/pkg/antigravity/request_transformer_test.go @@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "mcp_tool", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "MCP tool description", InputSchema: map[string]any{ "type": "object", @@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "custom_tool", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "Custom tool", InputSchema: map[string]any{"type": "object"}, }, @@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) { { Type: "custom", Name: "invalid_custom", - Custom: &CustomToolSpec{ + Custom: &ClaudeCustomToolSpec{ Description: "Invalid", // InputSchema 为 nil }, diff --git a/backend/internal/repository/concurrency_cache.go b/backend/internal/repository/concurrency_cache.go index 35296497..95370f51 100644 --- a/backend/internal/repository/concurrency_cache.go +++ b/backend/internal/repository/concurrency_cache.go @@ -151,11 +151,17 @@ var ( return 1 `) - // getAccountsLoadBatchScript - batch load query (read-only) - // ARGV[1] = slot TTL (seconds, retained for compatibility) + // getAccountsLoadBatchScript - batch load query with expired slot cleanup + // ARGV[1] = slot TTL (seconds) // ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ... getAccountsLoadBatchScript = redis.NewScript(` local result = {} + local slotTTL = tonumber(ARGV[1]) + + -- Get current server time + local timeResult = redis.call('TIME') + local nowSeconds = tonumber(timeResult[1]) + local cutoffTime = nowSeconds - slotTTL local i = 2 while i <= #ARGV do @@ -163,6 +169,9 @@ var ( local maxConcurrency = tonumber(ARGV[i + 1]) local slotKey = 'concurrency:account:' .. accountID + + -- Clean up expired slots before counting + redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime) local currentConcurrency = redis.call('ZCARD', slotKey) local waitKey = 'wait:account:' .. accountID diff --git a/backend/internal/service/gateway_service.go b/backend/internal/service/gateway_service.go index 3932c35c..bd6f59f7 100644 --- a/backend/internal/service/gateway_service.go +++ b/backend/internal/service/gateway_service.go @@ -204,7 +204,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string { // BindStickySession sets session -> account binding with standard TTL. func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error { - if sessionHash == "" || accountID <= 0 { + if sessionHash == "" || accountID <= 0 || s.cache == nil { return nil } return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL) @@ -429,7 +429,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro } // ============ Layer 1: 粘性会话优先 ============ - if sessionHash != "" { + if sessionHash != "" && s.cache != nil { accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash) if err == nil && accountID > 0 && !isExcluded(accountID) { account, err := s.accountRepo.GetByID(ctx, accountID)