package handler import ( "context" "fmt" "net/http" "time" "github.com/Wei-Shaw/sub2api/internal/model" "github.com/Wei-Shaw/sub2api/internal/service" "github.com/gin-gonic/gin" ) const ( // maxConcurrencyWait is the maximum time to wait for a concurrency slot maxConcurrencyWait = 30 * time.Second // pingInterval is the interval for sending ping events during slot wait pingInterval = 15 * time.Second ) // SSEPingFormat defines the format of SSE ping events for different platforms type SSEPingFormat string const ( // SSEPingFormatClaude is the Claude/Anthropic SSE ping format SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n" // SSEPingFormatNone indicates no ping should be sent (e.g., OpenAI has no ping spec) SSEPingFormatNone SSEPingFormat = "" ) // ConcurrencyError represents a concurrency limit error with context type ConcurrencyError struct { SlotType string IsTimeout bool } func (e *ConcurrencyError) Error() string { if e.IsTimeout { return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType) } return fmt.Sprintf("%s concurrency limit reached", e.SlotType) } // ConcurrencyHelper provides common concurrency slot management for gateway handlers type ConcurrencyHelper struct { concurrencyService *service.ConcurrencyService pingFormat SSEPingFormat } // NewConcurrencyHelper creates a new ConcurrencyHelper func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper { return &ConcurrencyHelper{ concurrencyService: concurrencyService, pingFormat: pingFormat, } } // IncrementWaitCount increments the wait count for a user func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) { return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait) } // DecrementWaitCount decrements the wait count for a user func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) { h.concurrencyService.DecrementWaitCount(ctx, userID) } // AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) { ctx := c.Request.Context() // Try to acquire immediately result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency) if err != nil { return nil, err } if result.Acquired { return result.ReleaseFunc, nil } // Need to wait - handle streaming ping if needed return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted) } // AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary. // For streaming requests, sends ping events during the wait. // streamStarted is updated if streaming response has begun. func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) { ctx := c.Request.Context() // Try to acquire immediately result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency) if err != nil { return nil, err } if result.Acquired { return result.ReleaseFunc, nil } // Need to wait - handle streaming ping if needed return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted) } // waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests. // streamStarted pointer is updated when streaming begins (for proper error handling by caller). func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) { ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait) defer cancel() // Determine if ping is needed (streaming + ping format defined) needPing := isStream && h.pingFormat != "" var flusher http.Flusher if needPing { var ok bool flusher, ok = c.Writer.(http.Flusher) if !ok { return nil, fmt.Errorf("streaming not supported") } } // Only create ping ticker if ping is needed var pingCh <-chan time.Time if needPing { pingTicker := time.NewTicker(pingInterval) defer pingTicker.Stop() pingCh = pingTicker.C } pollTicker := time.NewTicker(100 * time.Millisecond) defer pollTicker.Stop() for { select { case <-ctx.Done(): return nil, &ConcurrencyError{ SlotType: slotType, IsTimeout: true, } case <-pingCh: // Send ping to keep connection alive if !*streamStarted { c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") c.Header("X-Accel-Buffering", "no") *streamStarted = true } if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil { return nil, err } flusher.Flush() case <-pollTicker.C: // Try to acquire slot var result *service.AcquireResult var err error if slotType == "user" { result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency) } else { result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency) } if err != nil { return nil, err } if result.Acquired { return result.ReleaseFunc, nil } } } }