181 lines
5.4 KiB
Go
181 lines
5.4 KiB
Go
package handler
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/model"
|
|
"github.com/Wei-Shaw/sub2api/internal/service"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
)
|
|
|
|
// maxConcurrencyWait bounds how long a request may wait for a concurrency slot.
const maxConcurrencyWait = 30 * time.Second

// pingInterval is the period between ping events sent while waiting for a slot.
const pingInterval = 15 * time.Second
|
|
|
|
// SSEPingFormat is the literal SSE payload written as a ping event; the value
// differs per platform.
type SSEPingFormat string

// SSEPingFormatClaude is the ping event in the Claude/Anthropic SSE format.
const SSEPingFormatClaude SSEPingFormat = "data: {\"type\": \"ping\"}\n\n"

// SSEPingFormatNone is the empty format, meaning no ping should be sent
// (e.g., OpenAI has no ping spec).
const SSEPingFormatNone SSEPingFormat = ""
|
|
|
|
// ConcurrencyError reports that a concurrency slot could not be obtained.
type ConcurrencyError struct {
	// SlotType names the kind of slot involved; it is embedded in the message.
	SlotType string
	// IsTimeout selects the "timeout waiting" message over "limit reached".
	IsTimeout bool
}

// Error implements the error interface. The message depends on IsTimeout and
// always includes SlotType.
func (e *ConcurrencyError) Error() string {
	switch {
	case e.IsTimeout:
		return fmt.Sprintf("timeout waiting for %s concurrency slot", e.SlotType)
	default:
		return fmt.Sprintf("%s concurrency limit reached", e.SlotType)
	}
}
|
|
|
|
// ConcurrencyHelper provides common concurrency slot management for gateway handlers
|
|
type ConcurrencyHelper struct {
|
|
concurrencyService *service.ConcurrencyService
|
|
pingFormat SSEPingFormat
|
|
}
|
|
|
|
// NewConcurrencyHelper creates a new ConcurrencyHelper
|
|
func NewConcurrencyHelper(concurrencyService *service.ConcurrencyService, pingFormat SSEPingFormat) *ConcurrencyHelper {
|
|
return &ConcurrencyHelper{
|
|
concurrencyService: concurrencyService,
|
|
pingFormat: pingFormat,
|
|
}
|
|
}
|
|
|
|
// IncrementWaitCount increments the wait count for a user
|
|
func (h *ConcurrencyHelper) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
|
return h.concurrencyService.IncrementWaitCount(ctx, userID, maxWait)
|
|
}
|
|
|
|
// DecrementWaitCount decrements the wait count for a user
|
|
func (h *ConcurrencyHelper) DecrementWaitCount(ctx context.Context, userID int64) {
|
|
h.concurrencyService.DecrementWaitCount(ctx, userID)
|
|
}
|
|
|
|
// AcquireUserSlotWithWait acquires a user concurrency slot, waiting if necessary.
|
|
// For streaming requests, sends ping events during the wait.
|
|
// streamStarted is updated if streaming response has begun.
|
|
func (h *ConcurrencyHelper) AcquireUserSlotWithWait(c *gin.Context, user *model.User, isStream bool, streamStarted *bool) (func(), error) {
|
|
ctx := c.Request.Context()
|
|
|
|
// Try to acquire immediately
|
|
result, err := h.concurrencyService.AcquireUserSlot(ctx, user.ID, user.Concurrency)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if result.Acquired {
|
|
return result.ReleaseFunc, nil
|
|
}
|
|
|
|
// Need to wait - handle streaming ping if needed
|
|
return h.waitForSlotWithPing(c, "user", user.ID, user.Concurrency, isStream, streamStarted)
|
|
}
|
|
|
|
// AcquireAccountSlotWithWait acquires an account concurrency slot, waiting if necessary.
|
|
// For streaming requests, sends ping events during the wait.
|
|
// streamStarted is updated if streaming response has begun.
|
|
func (h *ConcurrencyHelper) AcquireAccountSlotWithWait(c *gin.Context, account *model.Account, isStream bool, streamStarted *bool) (func(), error) {
|
|
ctx := c.Request.Context()
|
|
|
|
// Try to acquire immediately
|
|
result, err := h.concurrencyService.AcquireAccountSlot(ctx, account.ID, account.Concurrency)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if result.Acquired {
|
|
return result.ReleaseFunc, nil
|
|
}
|
|
|
|
// Need to wait - handle streaming ping if needed
|
|
return h.waitForSlotWithPing(c, "account", account.ID, account.Concurrency, isStream, streamStarted)
|
|
}
|
|
|
|
// waitForSlotWithPing waits for a concurrency slot, sending ping events for streaming requests.
|
|
// streamStarted pointer is updated when streaming begins (for proper error handling by caller).
|
|
func (h *ConcurrencyHelper) waitForSlotWithPing(c *gin.Context, slotType string, id int64, maxConcurrency int, isStream bool, streamStarted *bool) (func(), error) {
|
|
ctx, cancel := context.WithTimeout(c.Request.Context(), maxConcurrencyWait)
|
|
defer cancel()
|
|
|
|
// Determine if ping is needed (streaming + ping format defined)
|
|
needPing := isStream && h.pingFormat != ""
|
|
|
|
var flusher http.Flusher
|
|
if needPing {
|
|
var ok bool
|
|
flusher, ok = c.Writer.(http.Flusher)
|
|
if !ok {
|
|
return nil, fmt.Errorf("streaming not supported")
|
|
}
|
|
}
|
|
|
|
// Only create ping ticker if ping is needed
|
|
var pingCh <-chan time.Time
|
|
if needPing {
|
|
pingTicker := time.NewTicker(pingInterval)
|
|
defer pingTicker.Stop()
|
|
pingCh = pingTicker.C
|
|
}
|
|
|
|
pollTicker := time.NewTicker(100 * time.Millisecond)
|
|
defer pollTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, &ConcurrencyError{
|
|
SlotType: slotType,
|
|
IsTimeout: true,
|
|
}
|
|
|
|
case <-pingCh:
|
|
// Send ping to keep connection alive
|
|
if !*streamStarted {
|
|
c.Header("Content-Type", "text/event-stream")
|
|
c.Header("Cache-Control", "no-cache")
|
|
c.Header("Connection", "keep-alive")
|
|
c.Header("X-Accel-Buffering", "no")
|
|
*streamStarted = true
|
|
}
|
|
if _, err := fmt.Fprint(c.Writer, string(h.pingFormat)); err != nil {
|
|
return nil, err
|
|
}
|
|
flusher.Flush()
|
|
|
|
case <-pollTicker.C:
|
|
// Try to acquire slot
|
|
var result *service.AcquireResult
|
|
var err error
|
|
|
|
if slotType == "user" {
|
|
result, err = h.concurrencyService.AcquireUserSlot(ctx, id, maxConcurrency)
|
|
} else {
|
|
result, err = h.concurrencyService.AcquireAccountSlot(ctx, id, maxConcurrency)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if result.Acquired {
|
|
return result.ReleaseFunc, nil
|
|
}
|
|
}
|
|
}
|
|
}
|