This commit is contained in:
yangjianbo
2026-01-12 14:50:53 +08:00
20 changed files with 756 additions and 12 deletions

View File

@@ -97,7 +97,8 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
antigravityOAuthService := service.NewAntigravityOAuthService(proxyRepository)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
tempUnschedCache := repository.NewTempUnschedCache(redisClient)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache)
timeoutCounterCache := repository.NewTimeoutCounterCache(redisClient)
rateLimitService := service.ProvideRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService, tempUnschedCache, timeoutCounterCache, settingService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
antigravityQuotaFetcher := service.NewAntigravityQuotaFetcher(proxyRepository)
usageCache := service.NewUsageCache()

View File

@@ -654,3 +654,68 @@ func (h *SettingHandler) DeleteAdminAPIKey(c *gin.Context) {
response.Success(c, gin.H{"message": "Admin API key deleted"})
}
// GetStreamTimeoutSettings 获取流超时处理配置
// GET /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) GetStreamTimeoutSettings(c *gin.Context) {
settings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.StreamTimeoutSettings{
Enabled: settings.Enabled,
Action: settings.Action,
TempUnschedMinutes: settings.TempUnschedMinutes,
ThresholdCount: settings.ThresholdCount,
ThresholdWindowMinutes: settings.ThresholdWindowMinutes,
})
}
// UpdateStreamTimeoutSettingsRequest 更新流超时配置请求
type UpdateStreamTimeoutSettingsRequest struct {
Enabled bool `json:"enabled"`
Action string `json:"action"`
TempUnschedMinutes int `json:"temp_unsched_minutes"`
ThresholdCount int `json:"threshold_count"`
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}
// UpdateStreamTimeoutSettings 更新流超时处理配置
// PUT /api/v1/admin/settings/stream-timeout
func (h *SettingHandler) UpdateStreamTimeoutSettings(c *gin.Context) {
var req UpdateStreamTimeoutSettingsRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "Invalid request: "+err.Error())
return
}
settings := &service.StreamTimeoutSettings{
Enabled: req.Enabled,
Action: req.Action,
TempUnschedMinutes: req.TempUnschedMinutes,
ThresholdCount: req.ThresholdCount,
ThresholdWindowMinutes: req.ThresholdWindowMinutes,
}
if err := h.settingService.SetStreamTimeoutSettings(c.Request.Context(), settings); err != nil {
response.BadRequest(c, err.Error())
return
}
// 重新获取设置返回
updatedSettings, err := h.settingService.GetStreamTimeoutSettings(c.Request.Context())
if err != nil {
response.ErrorFrom(c, err)
return
}
response.Success(c, dto.StreamTimeoutSettings{
Enabled: updatedSettings.Enabled,
Action: updatedSettings.Action,
TempUnschedMinutes: updatedSettings.TempUnschedMinutes,
ThresholdCount: updatedSettings.ThresholdCount,
ThresholdWindowMinutes: updatedSettings.ThresholdWindowMinutes,
})
}

View File

@@ -66,3 +66,12 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool `json:"linuxdo_oauth_enabled"`
Version string `json:"version"`
}
// StreamTimeoutSettings 流超时处理配置 DTO
type StreamTimeoutSettings struct {
Enabled bool `json:"enabled"`
Action string `json:"action"`
TempUnschedMinutes int `json:"temp_unsched_minutes"`
ThresholdCount int `json:"threshold_count"`
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}

View File

@@ -0,0 +1,80 @@
package repository
import (
"context"
"fmt"
"time"
"github.com/Wei-Shaw/sub2api/internal/service"
"github.com/redis/go-redis/v9"
)
const timeoutCounterPrefix = "timeout_count:account:"
// timeoutCounterIncrScript 使用 Lua 脚本原子性地增加计数并返回当前值
// 如果 key 不存在,则创建并设置过期时间
var timeoutCounterIncrScript = redis.NewScript(`
local key = KEYS[1]
local ttl = tonumber(ARGV[1])
local count = redis.call('INCR', key)
if count == 1 then
redis.call('EXPIRE', key, ttl)
end
return count
`)
type timeoutCounterCache struct {
rdb *redis.Client
}
// NewTimeoutCounterCache 创建超时计数器缓存实例
func NewTimeoutCounterCache(rdb *redis.Client) service.TimeoutCounterCache {
return &timeoutCounterCache{rdb: rdb}
}
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
func (c *timeoutCounterCache) IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
ttlSeconds := windowMinutes * 60
if ttlSeconds < 60 {
ttlSeconds = 60 // 最小1分钟
}
result, err := timeoutCounterIncrScript.Run(ctx, c.rdb, []string{key}, ttlSeconds).Int64()
if err != nil {
return 0, fmt.Errorf("increment timeout count: %w", err)
}
return result, nil
}
// GetTimeoutCount 获取账户当前的超时计数
func (c *timeoutCounterCache) GetTimeoutCount(ctx context.Context, accountID int64) (int64, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
val, err := c.rdb.Get(ctx, key).Int64()
if err == redis.Nil {
return 0, nil
}
if err != nil {
return 0, fmt.Errorf("get timeout count: %w", err)
}
return val, nil
}
// ResetTimeoutCount 重置账户的超时计数
func (c *timeoutCounterCache) ResetTimeoutCount(ctx context.Context, accountID int64) error {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
return c.rdb.Del(ctx, key).Err()
}
// GetTimeoutCountTTL 获取计数器剩余过期时间
func (c *timeoutCounterCache) GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error) {
key := fmt.Sprintf("%s%d", timeoutCounterPrefix, accountID)
return c.rdb.TTL(ctx, key).Result()
}

View File

@@ -59,6 +59,7 @@ var ProviderSet = wire.NewSet(
NewBillingCache,
NewAPIKeyCache,
NewTempUnschedCache,
NewTimeoutCounterCache,
ProvideConcurrencyCache,
NewDashboardCache,
NewEmailCache,

View File

@@ -283,6 +283,9 @@ func registerSettingsRoutes(admin *gin.RouterGroup, h *handler.Handlers) {
adminSettings.GET("/admin-api-key", h.Admin.Setting.GetAdminAPIKey)
adminSettings.POST("/admin-api-key/regenerate", h.Admin.Setting.RegenerateAdminAPIKey)
adminSettings.DELETE("/admin-api-key", h.Admin.Setting.DeleteAdminAPIKey)
// 流超时处理配置
adminSettings.GET("/stream-timeout", h.Admin.Setting.GetStreamTimeoutSettings)
adminSettings.PUT("/stream-timeout", h.Admin.Setting.UpdateStreamTimeoutSettings)
}
}

View File

@@ -1717,6 +1717,7 @@ func (s *AntigravityGatewayService) handleGeminiStreamingResponse(c *gin.Context
continue
}
log.Printf("Stream data interval timeout (antigravity)")
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
@@ -2271,6 +2272,7 @@ func (s *AntigravityGatewayService) handleClaudeStreamingResponse(c *gin.Context
continue
}
log.Printf("Stream data interval timeout (antigravity)")
// 注意:此函数没有 account 上下文,无法调用 HandleStreamTimeout
sendErrorEvent("stream_timeout")
return &antigravityStreamResult{usage: convertUsage(nil), firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}

View File

@@ -146,6 +146,13 @@ const (
// SettingKeyOpsAdvancedSettings stores JSON config for ops advanced settings (data retention, aggregation).
SettingKeyOpsAdvancedSettings = "ops_advanced_settings"
// =========================
// Stream Timeout Handling
// =========================
// SettingKeyStreamTimeoutSettings stores JSON config for stream timeout handling.
SettingKeyStreamTimeoutSettings = "stream_timeout_settings"
)
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).

View File

@@ -2341,6 +2341,10 @@ func (s *GatewayService) handleStreamingResponse(ctx context.Context, resp *http
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}

View File

@@ -1047,6 +1047,10 @@ func (s *OpenAIGatewayService) handleStreamingResponse(ctx context.Context, resp
continue
}
log.Printf("Stream data interval timeout: account=%d model=%s interval=%s", account.ID, originalModel, streamInterval)
// 处理流超时,可能标记账户为临时不可调度或错误状态
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, originalModel)
}
sendErrorEvent("stream_timeout")
return &openaiStreamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")

View File

@@ -15,13 +15,15 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
timeoutCounterCache TimeoutCounterCache
settingService *SettingService
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
@@ -44,6 +46,16 @@ func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogReposi
}
}
// SetTimeoutCounterCache 设置超时计数器缓存(可选依赖)
func (s *RateLimitService) SetTimeoutCounterCache(cache TimeoutCounterCache) {
s.timeoutCounterCache = cache
}
// SetSettingService 设置系统设置服务(可选依赖)
func (s *RateLimitService) SetSettingService(settingService *SettingService) {
s.settingService = settingService
}
// HandleUpstreamError 处理上游错误响应,标记账号状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Account, statusCode int, headers http.Header, responseBody []byte) (shouldDisable bool) {
@@ -555,3 +567,125 @@ func truncateTempUnschedMessage(body []byte, maxBytes int) string {
}
return strings.TrimSpace(string(body))
}
// HandleStreamTimeout 处理流数据超时
// 根据系统设置决定是否标记账户为临时不可调度或错误状态
// 返回是否应该停止该账号的调度
func (s *RateLimitService) HandleStreamTimeout(ctx context.Context, account *Account, model string) bool {
if account == nil {
return false
}
// 获取系统设置
if s.settingService == nil {
log.Printf("[StreamTimeout] settingService not configured, skipping timeout handling for account %d", account.ID)
return false
}
settings, err := s.settingService.GetStreamTimeoutSettings(ctx)
if err != nil {
log.Printf("[StreamTimeout] Failed to get settings: %v", err)
return false
}
if !settings.Enabled {
return false
}
if settings.Action == StreamTimeoutActionNone {
return false
}
// 增加超时计数
var count int64 = 1
if s.timeoutCounterCache != nil {
count, err = s.timeoutCounterCache.IncrementTimeoutCount(ctx, account.ID, settings.ThresholdWindowMinutes)
if err != nil {
log.Printf("[StreamTimeout] Failed to increment timeout count for account %d: %v", account.ID, err)
// 继续处理,使用 count=1
count = 1
}
}
log.Printf("[StreamTimeout] Account %d timeout count: %d/%d (window: %d min, model: %s)",
account.ID, count, settings.ThresholdCount, settings.ThresholdWindowMinutes, model)
// 检查是否达到阈值
if count < int64(settings.ThresholdCount) {
return false
}
// 达到阈值,执行相应操作
switch settings.Action {
case StreamTimeoutActionTempUnsched:
return s.triggerStreamTimeoutTempUnsched(ctx, account, settings, model)
case StreamTimeoutActionError:
return s.triggerStreamTimeoutError(ctx, account, model)
default:
return false
}
}
// triggerStreamTimeoutTempUnsched 触发流超时临时不可调度
func (s *RateLimitService) triggerStreamTimeoutTempUnsched(ctx context.Context, account *Account, settings *StreamTimeoutSettings, model string) bool {
now := time.Now()
until := now.Add(time.Duration(settings.TempUnschedMinutes) * time.Minute)
state := &TempUnschedState{
UntilUnix: until.Unix(),
TriggeredAtUnix: now.Unix(),
StatusCode: 0, // 超时没有状态码
MatchedKeyword: "stream_timeout",
RuleIndex: -1, // 表示系统级规则
ErrorMessage: "Stream data interval timeout for model: " + model,
}
reason := ""
if raw, err := json.Marshal(state); err == nil {
reason = string(raw)
}
if reason == "" {
reason = state.ErrorMessage
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("[StreamTimeout] SetTempUnschedulable failed for account %d: %v", account.ID, err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("[StreamTimeout] SetTempUnsched cache failed for account %d: %v", account.ID, err)
}
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
}
}
log.Printf("[StreamTimeout] Account %d marked as temp unschedulable until %v (model: %s)", account.ID, until, model)
return true
}
// triggerStreamTimeoutError 触发流超时错误状态
func (s *RateLimitService) triggerStreamTimeoutError(ctx context.Context, account *Account, model string) bool {
errorMsg := "Stream data interval timeout (repeated failures) for model: " + model
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {
log.Printf("[StreamTimeout] SetError failed for account %d: %v", account.ID, err)
return false
}
// 重置超时计数
if s.timeoutCounterCache != nil {
if err := s.timeoutCounterCache.ResetTimeoutCount(ctx, account.ID); err != nil {
log.Printf("[StreamTimeout] ResetTimeoutCount failed for account %d: %v", account.ID, err)
}
}
log.Printf("[StreamTimeout] Account %d marked as error (model: %s)", account.ID, model)
return true
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"crypto/rand"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strconv"
@@ -675,3 +676,84 @@ func (s *SettingService) GetLinuxDoConnectOAuthConfig(ctx context.Context) (conf
return effective, nil
}
// GetStreamTimeoutSettings 获取流超时处理配置
func (s *SettingService) GetStreamTimeoutSettings(ctx context.Context) (*StreamTimeoutSettings, error) {
value, err := s.settingRepo.GetValue(ctx, SettingKeyStreamTimeoutSettings)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return DefaultStreamTimeoutSettings(), nil
}
return nil, fmt.Errorf("get stream timeout settings: %w", err)
}
if value == "" {
return DefaultStreamTimeoutSettings(), nil
}
var settings StreamTimeoutSettings
if err := json.Unmarshal([]byte(value), &settings); err != nil {
return DefaultStreamTimeoutSettings(), nil
}
// 验证并修正配置值
if settings.TempUnschedMinutes < 1 {
settings.TempUnschedMinutes = 1
}
if settings.TempUnschedMinutes > 60 {
settings.TempUnschedMinutes = 60
}
if settings.ThresholdCount < 1 {
settings.ThresholdCount = 1
}
if settings.ThresholdCount > 10 {
settings.ThresholdCount = 10
}
if settings.ThresholdWindowMinutes < 1 {
settings.ThresholdWindowMinutes = 1
}
if settings.ThresholdWindowMinutes > 60 {
settings.ThresholdWindowMinutes = 60
}
// 验证 action
switch settings.Action {
case StreamTimeoutActionTempUnsched, StreamTimeoutActionError, StreamTimeoutActionNone:
// valid
default:
settings.Action = StreamTimeoutActionTempUnsched
}
return &settings, nil
}
// SetStreamTimeoutSettings 设置流超时处理配置
func (s *SettingService) SetStreamTimeoutSettings(ctx context.Context, settings *StreamTimeoutSettings) error {
if settings == nil {
return fmt.Errorf("settings cannot be nil")
}
// 验证配置值
if settings.TempUnschedMinutes < 1 || settings.TempUnschedMinutes > 60 {
return fmt.Errorf("temp_unsched_minutes must be between 1-60")
}
if settings.ThresholdCount < 1 || settings.ThresholdCount > 10 {
return fmt.Errorf("threshold_count must be between 1-10")
}
if settings.ThresholdWindowMinutes < 1 || settings.ThresholdWindowMinutes > 60 {
return fmt.Errorf("threshold_window_minutes must be between 1-60")
}
switch settings.Action {
case StreamTimeoutActionTempUnsched, StreamTimeoutActionError, StreamTimeoutActionNone:
// valid
default:
return fmt.Errorf("invalid action: %s", settings.Action)
}
data, err := json.Marshal(settings)
if err != nil {
return fmt.Errorf("marshal stream timeout settings: %w", err)
}
return s.settingRepo.Set(ctx, SettingKeyStreamTimeoutSettings, string(data))
}

View File

@@ -69,3 +69,35 @@ type PublicSettings struct {
LinuxDoOAuthEnabled bool
Version string
}
// StreamTimeoutSettings 流超时处理配置(仅控制超时后的处理方式,超时判定由网关配置控制)
type StreamTimeoutSettings struct {
// Enabled 是否启用流超时处理
Enabled bool `json:"enabled"`
// Action 超时后的处理方式: "temp_unsched" | "error" | "none"
Action string `json:"action"`
// TempUnschedMinutes 临时不可调度持续时间(分钟)
TempUnschedMinutes int `json:"temp_unsched_minutes"`
// ThresholdCount 触发阈值次数(累计多少次超时才触发)
ThresholdCount int `json:"threshold_count"`
// ThresholdWindowMinutes 阈值窗口时间(分钟)
ThresholdWindowMinutes int `json:"threshold_window_minutes"`
}
// StreamTimeoutAction 流超时处理方式常量
const (
StreamTimeoutActionTempUnsched = "temp_unsched" // 临时不可调度
StreamTimeoutActionError = "error" // 标记为错误状态
StreamTimeoutActionNone = "none" // 不处理
)
// DefaultStreamTimeoutSettings 返回默认的流超时配置
func DefaultStreamTimeoutSettings() *StreamTimeoutSettings {
return &StreamTimeoutSettings{
Enabled: false,
Action: StreamTimeoutActionTempUnsched,
TempUnschedMinutes: 5,
ThresholdCount: 3,
ThresholdWindowMinutes: 10,
}
}

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"time"
)
// TempUnschedState 临时不可调度状态
@@ -20,3 +21,16 @@ type TempUnschedCache interface {
GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
DeleteTempUnsched(ctx context.Context, accountID int64) error
}
// TimeoutCounterCache 超时计数器缓存接口
type TimeoutCounterCache interface {
// IncrementTimeoutCount 增加账户的超时计数,返回当前计数值
// windowMinutes 是计数窗口时间(分钟),超过此时间计数器会自动重置
IncrementTimeoutCount(ctx context.Context, accountID int64, windowMinutes int) (int64, error)
// GetTimeoutCount 获取账户当前的超时计数
GetTimeoutCount(ctx context.Context, accountID int64) (int64, error)
// ResetTimeoutCount 重置账户的超时计数
ResetTimeoutCount(ctx context.Context, accountID int64) error
// GetTimeoutCountTTL 获取计数器剩余过期时间
GetTimeoutCountTTL(ctx context.Context, accountID int64) (time.Duration, error)
}

View File

@@ -99,6 +99,22 @@ func ProvideSchedulerSnapshotService(
return svc
}
// ProvideRateLimitService creates RateLimitService with optional dependencies.
func ProvideRateLimitService(
accountRepo AccountRepository,
usageRepo UsageLogRepository,
cfg *config.Config,
geminiQuotaService *GeminiQuotaService,
tempUnschedCache TempUnschedCache,
timeoutCounterCache TimeoutCounterCache,
settingService *SettingService,
) *RateLimitService {
svc := NewRateLimitService(accountRepo, usageRepo, cfg, geminiQuotaService, tempUnschedCache)
svc.SetTimeoutCounterCache(timeoutCounterCache)
svc.SetSettingService(settingService)
return svc
}
// ProvideOpsMetricsCollector creates and starts OpsMetricsCollector.
func ProvideOpsMetricsCollector(
opsRepo OpsRepository,
@@ -199,7 +215,7 @@ var ProviderSet = wire.NewSet(
NewGeminiMessagesCompatService,
NewAntigravityTokenProvider,
NewAntigravityGatewayService,
NewRateLimitService,
ProvideRateLimitService,
NewAccountUsageService,
NewAccountTestService,
NewSettingService,