Merge branch 'feature/atomic-scheduling-v2'
This commit is contained in:
@@ -586,8 +586,20 @@ func (h *GatewayHandler) handleStreamingAwareError(c *gin.Context, status int, e
|
|||||||
// Stream already started, send error as SSE event then close
|
// Stream already started, send error as SSE event then close
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if ok {
|
if ok {
|
||||||
// Send error event in SSE format
|
// Send error event in SSE format with proper JSON marshaling
|
||||||
errorEvent := fmt.Sprintf(`data: {"type": "error", "error": {"type": "%s", "message": "%s"}}`+"\n\n", errType, message)
|
errorData := map[string]any{
|
||||||
|
"type": "error",
|
||||||
|
"error": map[string]string{
|
||||||
|
"type": errType,
|
||||||
|
"message": message,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
jsonBytes, err := json.Marshal(errorData)
|
||||||
|
if err != nil {
|
||||||
|
_ = c.Error(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
errorEvent := fmt.Sprintf("data: %s\n\n", string(jsonBytes))
|
||||||
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
if _, err := fmt.Fprint(c.Writer, errorEvent); err != nil {
|
||||||
_ = c.Error(err)
|
_ = c.Error(err)
|
||||||
}
|
}
|
||||||
@@ -737,8 +749,27 @@ func sendMockWarmupStream(c *gin.Context, model string) {
|
|||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("X-Accel-Buffering", "no")
|
c.Header("X-Accel-Buffering", "no")
|
||||||
|
|
||||||
|
// Build message_start event with proper JSON marshaling
|
||||||
|
messageStart := map[string]any{
|
||||||
|
"type": "message_start",
|
||||||
|
"message": map[string]any{
|
||||||
|
"id": "msg_mock_warmup",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": model,
|
||||||
|
"content": []any{},
|
||||||
|
"stop_reason": nil,
|
||||||
|
"stop_sequence": nil,
|
||||||
|
"usage": map[string]int{
|
||||||
|
"input_tokens": 10,
|
||||||
|
"output_tokens": 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
messageStartJSON, _ := json.Marshal(messageStart)
|
||||||
|
|
||||||
events := []string{
|
events := []string{
|
||||||
`event: message_start` + "\n" + `data: {"message":{"content":[],"id":"msg_mock_warmup","model":"` + model + `","role":"assistant","stop_reason":null,"stop_sequence":null,"type":"message","usage":{"input_tokens":10,"output_tokens":0}},"type":"message_start"}`,
|
`event: message_start` + "\n" + `data: ` + string(messageStartJSON),
|
||||||
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
`event: content_block_start` + "\n" + `data: {"content_block":{"text":"","type":"text"},"index":0,"type":"content_block_start"}`,
|
||||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
`event: content_block_delta` + "\n" + `data: {"delta":{"text":"New","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||||
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
`event: content_block_delta` + "\n" + `data: {"delta":{"text":" Conversation","type":"text_delta"},"index":0,"type":"content_block_delta"}`,
|
||||||
|
|||||||
@@ -144,6 +144,21 @@ func (h *ConcurrencyHelper) waitForSlotWithPingTimeout(c *gin.Context, slotType
|
|||||||
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
ctx, cancel := context.WithTimeout(c.Request.Context(), timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
// Try immediate acquire first (avoid unnecessary wait)
|
||||||
|
var result *service.AcquireResult
|
||||||
|
var err error
|
||||||
|
if slotType == "user" {
|
||||||
|
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
||||||
|
} else {
|
||||||
|
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if result.Acquired {
|
||||||
|
return result.ReleaseFunc, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Determine if ping is needed (streaming + ping format defined)
|
// Determine if ping is needed (streaming + ping format defined)
|
||||||
needPing := isStream && h.pingFormat != ""
|
needPing := isStream && h.pingFormat != ""
|
||||||
|
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Type: "custom",
|
Type: "custom",
|
||||||
Name: "mcp_tool",
|
Name: "mcp_tool",
|
||||||
Custom: &CustomToolSpec{
|
Custom: &ClaudeCustomToolSpec{
|
||||||
Description: "MCP tool description",
|
Description: "MCP tool description",
|
||||||
InputSchema: map[string]any{
|
InputSchema: map[string]any{
|
||||||
"type": "object",
|
"type": "object",
|
||||||
@@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Type: "custom",
|
Type: "custom",
|
||||||
Name: "custom_tool",
|
Name: "custom_tool",
|
||||||
Custom: &CustomToolSpec{
|
Custom: &ClaudeCustomToolSpec{
|
||||||
Description: "Custom tool",
|
Description: "Custom tool",
|
||||||
InputSchema: map[string]any{"type": "object"},
|
InputSchema: map[string]any{"type": "object"},
|
||||||
},
|
},
|
||||||
@@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
|
|||||||
{
|
{
|
||||||
Type: "custom",
|
Type: "custom",
|
||||||
Name: "invalid_custom",
|
Name: "invalid_custom",
|
||||||
Custom: &CustomToolSpec{
|
Custom: &ClaudeCustomToolSpec{
|
||||||
Description: "Invalid",
|
Description: "Invalid",
|
||||||
// InputSchema 为 nil
|
// InputSchema 为 nil
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -151,11 +151,17 @@ var (
|
|||||||
return 1
|
return 1
|
||||||
`)
|
`)
|
||||||
|
|
||||||
// getAccountsLoadBatchScript - batch load query (read-only)
|
// getAccountsLoadBatchScript - batch load query with expired slot cleanup
|
||||||
// ARGV[1] = slot TTL (seconds, retained for compatibility)
|
// ARGV[1] = slot TTL (seconds)
|
||||||
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
// ARGV[2..n] = accountID1, maxConcurrency1, accountID2, maxConcurrency2, ...
|
||||||
getAccountsLoadBatchScript = redis.NewScript(`
|
getAccountsLoadBatchScript = redis.NewScript(`
|
||||||
local result = {}
|
local result = {}
|
||||||
|
local slotTTL = tonumber(ARGV[1])
|
||||||
|
|
||||||
|
-- Get current server time
|
||||||
|
local timeResult = redis.call('TIME')
|
||||||
|
local nowSeconds = tonumber(timeResult[1])
|
||||||
|
local cutoffTime = nowSeconds - slotTTL
|
||||||
|
|
||||||
local i = 2
|
local i = 2
|
||||||
while i <= #ARGV do
|
while i <= #ARGV do
|
||||||
@@ -163,6 +169,9 @@ var (
|
|||||||
local maxConcurrency = tonumber(ARGV[i + 1])
|
local maxConcurrency = tonumber(ARGV[i + 1])
|
||||||
|
|
||||||
local slotKey = 'concurrency:account:' .. accountID
|
local slotKey = 'concurrency:account:' .. accountID
|
||||||
|
|
||||||
|
-- Clean up expired slots before counting
|
||||||
|
redis.call('ZREMRANGEBYSCORE', slotKey, '-inf', cutoffTime)
|
||||||
local currentConcurrency = redis.call('ZCARD', slotKey)
|
local currentConcurrency = redis.call('ZCARD', slotKey)
|
||||||
|
|
||||||
local waitKey = 'wait:account:' .. accountID
|
local waitKey = 'wait:account:' .. accountID
|
||||||
|
|||||||
@@ -204,7 +204,7 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
|
|||||||
|
|
||||||
// BindStickySession sets session -> account binding with standard TTL.
|
// BindStickySession sets session -> account binding with standard TTL.
|
||||||
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
|
||||||
if sessionHash == "" || accountID <= 0 {
|
if sessionHash == "" || accountID <= 0 || s.cache == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
|
||||||
@@ -429,7 +429,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ============ Layer 1: 粘性会话优先 ============
|
// ============ Layer 1: 粘性会话优先 ============
|
||||||
if sessionHash != "" {
|
if sessionHash != "" && s.cache != nil {
|
||||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||||
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
if err == nil && accountID > 0 && !isExcluded(accountID) {
|
||||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||||
|
|||||||
Reference in New Issue
Block a user