From 06d483fa8d09ddc092fa1e69aa3a0170a6c5f445 Mon Sep 17 00:00:00 2001 From: IanShaw027 <131567472+IanShaw027@users.noreply.github.com> Date: Thu, 1 Jan 2026 04:22:39 +0800 Subject: [PATCH] feat(backend): implement gemini quota simulation and rate limiting - feat: add local quota tracking for gemini tiers (Legacy/Pro/Ultra) - feat: implement PreCheckUsage in RateLimitService - feat: align gemini daily reset window with PST - fix: sticky session fallback logic --- backend/cmd/server/wire_gen.go | 5 +- backend/internal/config/config.go | 13 + backend/internal/service/account.go | 31 ++ .../internal/service/account_usage_service.go | 71 ++++- backend/internal/service/domain_constants.go | 3 + .../service/gemini_messages_compat_service.go | 59 ++-- backend/internal/service/gemini_quota.go | 268 ++++++++++++++++++ backend/internal/service/ratelimit_service.go | 77 ++++- backend/internal/service/wire.go | 1 + deploy/.env.example | 8 + deploy/README.md | 1 + deploy/config.example.yaml | 16 ++ deploy/docker-compose-test.yml | 1 + deploy/docker-compose.override.yml | 21 ++ deploy/docker-compose.yml | 1 + 15 files changed, 537 insertions(+), 39 deletions(-) create mode 100644 backend/internal/service/gemini_quota.go create mode 100644 deploy/docker-compose.override.yml diff --git a/backend/cmd/server/wire_gen.go b/backend/cmd/server/wire_gen.go index 83cba823..eed203f1 100644 --- a/backend/cmd/server/wire_gen.go +++ b/backend/cmd/server/wire_gen.go @@ -87,9 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) { geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig) geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient() geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig) - rateLimitService := service.NewRateLimitService(accountRepository, configConfig) + geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository) 
+ rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService) claudeUsageFetcher := repository.NewClaudeUsageFetcher() - accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher) + accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService) geminiTokenCache := repository.NewGeminiTokenCache(redisClient) geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService) gatewayCache := repository.NewGatewayCache(redisClient) diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index d3674932..6601f753 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -43,6 +43,7 @@ type Config struct { type GeminiConfig struct { OAuth GeminiOAuthConfig `mapstructure:"oauth"` + Quota GeminiQuotaConfig `mapstructure:"quota"` } type GeminiOAuthConfig struct { @@ -51,6 +52,17 @@ type GeminiOAuthConfig struct { Scopes string `mapstructure:"scopes"` } +type GeminiQuotaConfig struct { + Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"` + Policy string `mapstructure:"policy"` +} + +type GeminiTierQuotaConfig struct { + ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"` + FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"` + CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"` +} + // TokenRefreshConfig OAuth token自动刷新配置 type TokenRefreshConfig struct { // 是否启用自动刷新 @@ -352,6 +364,7 @@ func setDefaults() { viper.SetDefault("gemini.oauth.client_id", "") viper.SetDefault("gemini.oauth.client_secret", "") viper.SetDefault("gemini.oauth.scopes", "") + viper.SetDefault("gemini.quota.policy", "") } func (c *Config) Validate() error { diff --git a/backend/internal/service/account.go b/backend/internal/service/account.go index 5d461b9c..dcc6c3c5 100644 
--- a/backend/internal/service/account.go +++ b/backend/internal/service/account.go @@ -3,6 +3,7 @@ package service import ( "encoding/json" "strconv" + "strings" "time" ) @@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool { return a.Platform == PlatformGemini } +func (a *Account) GeminiOAuthType() string { + if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth { + return "" + } + oauthType := strings.TrimSpace(a.GetCredential("oauth_type")) + if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" { + return "code_assist" + } + return oauthType +} + +func (a *Account) GeminiTierID() string { + tierID := strings.TrimSpace(a.GetCredential("tier_id")) + if tierID == "" { + return "" + } + return strings.ToUpper(tierID) +} + +func (a *Account) IsGeminiCodeAssist() bool { + if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth { + return false + } + oauthType := a.GeminiOAuthType() + if oauthType == "" { + return strings.TrimSpace(a.GetCredential("project_id")) != "" + } + return oauthType == "code_assist" +} + func (a *Account) CanGetUsage() bool { return a.Type == AccountTypeOAuth } diff --git a/backend/internal/service/account_usage_service.go b/backend/internal/service/account_usage_service.go index dba670b0..0040a643 100644 --- a/backend/internal/service/account_usage_service.go +++ b/backend/internal/service/account_usage_service.go @@ -97,6 +97,8 @@ type UsageInfo struct { FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口 SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口 SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口 + GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额 + GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额 } // ClaudeUsageResponse Anthropic API返回的usage结构 @@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface { // AccountUsageService 账号使用量查询服务 type AccountUsageService struct { - accountRepo 
AccountRepository - usageLogRepo UsageLogRepository - usageFetcher ClaudeUsageFetcher + accountRepo AccountRepository + usageLogRepo UsageLogRepository + usageFetcher ClaudeUsageFetcher + geminiQuotaService *GeminiQuotaService } // NewAccountUsageService 创建AccountUsageService实例 -func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService { +func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService { return &AccountUsageService{ - accountRepo: accountRepo, - usageLogRepo: usageLogRepo, - usageFetcher: usageFetcher, + accountRepo: accountRepo, + usageLogRepo: usageLogRepo, + usageFetcher: usageFetcher, + geminiQuotaService: geminiQuotaService, } } @@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("get account failed: %w", err) } + if account.Platform == PlatformGemini { + return s.getGeminiUsage(ctx, account) + } + // 只有oauth类型账号可以通过API获取usage(有profile scope) if account.CanGetUsage() { var apiResp *ClaudeUsageResponse @@ -192,6 +200,33 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U return nil, fmt.Errorf("account type %s does not support usage query", account.Type) } +func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) { + now := time.Now() + start := geminiDailyWindowStart(now) + + stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) + if err != nil { + return nil, fmt.Errorf("get gemini usage stats failed: %w", err) + } + + usage := &UsageInfo{ + UpdatedAt: &now, + } + + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + return usage, nil + } + + totals := geminiAggregateUsage(stats) + resetAt := geminiDailyResetTime(now) + + 
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now) + usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now) + + return usage, nil +} + // addWindowStats 为 usage 数据添加窗口期统计 // 使用独立缓存(1 分钟),与 API 缓存分离 func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) { @@ -388,3 +423,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn // Setup Token无法获取7d数据 return info } + +func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress { + if limit <= 0 { + return nil + } + utilization := (float64(used) / float64(limit)) * 100 + remainingSeconds := int(resetAt.Sub(now).Seconds()) + if remainingSeconds < 0 { + remainingSeconds = 0 + } + resetCopy := resetAt + return &UsageProgress{ + Utilization: utilization, + ResetsAt: &resetCopy, + RemainingSeconds: remainingSeconds, + WindowStats: &WindowStats{ + Requests: used, + Tokens: tokens, + Cost: cost, + }, + } +} diff --git a/backend/internal/service/domain_constants.go b/backend/internal/service/domain_constants.go index 2e879263..ca2c2c99 100644 --- a/backend/internal/service/domain_constants.go +++ b/backend/internal/service/domain_constants.go @@ -91,6 +91,9 @@ const ( // 管理员 API Key SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成) + + // Gemini 配额策略(JSON) + SettingKeyGeminiQuotaPolicy = "gemini_quota_policy" ) // Admin API Key prefix (distinct from user "sk-" keys) diff --git a/backend/internal/service/gemini_messages_compat_service.go b/backend/internal/service/gemini_messages_compat_service.go index 111ff462..163dfe1d 100644 --- a/backend/internal/service/gemini_messages_compat_service.go +++ b/backend/internal/service/gemini_messages_compat_service.go @@ -116,8 +116,20 @@ func (s 
*GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co valid = true } if valid { - _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) - return account, nil + usable := true + if s.rateLimitService != nil && requestedModel != "" { + ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel) + if err != nil { + log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err) + } + if !ok { + usable = false + } + } + if usable { + _ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL) + return account, nil + } } } } @@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) { continue } + if s.rateLimitService != nil && requestedModel != "" { + ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel) + if err != nil { + log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err) + } + if !ok { + continue + } + } if selected == nil { selected = acc continue @@ -1887,26 +1908,23 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont return } - // 获取账号的 oauth_type、tier_id 和 project_id - oauthType := strings.TrimSpace(account.GetCredential("oauth_type")) - tierID := strings.TrimSpace(account.GetCredential("tier_id")) + oauthType := account.GeminiOAuthType() + tierID := account.GeminiTierID() projectID := strings.TrimSpace(account.GetCredential("project_id")) - - // 判断是否为 Code Assist:以 project_id 是否存在为准(更可靠) - isCodeAssist := projectID != "" - // Legacy 兼容:oauth_type 为空但 project_id 存在时视为 code_assist - if oauthType == "" && isCodeAssist { - oauthType = "code_assist" - } + isCodeAssist := account.IsGeminiCodeAssist() resetAt := ParseGeminiRateLimitResetTime(body) if resetAt == nil { // 根据账号类型使用不同的默认重置时间 var ra time.Time if isCodeAssist { - // Code Assist: 5 分钟滚动窗口 - ra = time.Now().Add(5 * time.Minute) - log.Printf("[Gemini 
429] Account %d (Code Assist, tier=%s, project=%s) rate limited, reset in 5min", account.ID, tierID, projectID) + // Code Assist: fallback cooldown by tier + cooldown := geminiCooldownForTier(tierID) + if s.rateLimitService != nil { + cooldown = s.rateLimitService.GeminiCooldown(ctx, account) + } + ra = time.Now().Add(cooldown) + log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, ra.Sub(time.Now()).Truncate(time.Second)) } else { // API Key / AI Studio OAuth: PST 午夜 if ts := nextGeminiDailyResetUnix(); ts != nil { @@ -1982,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool { } func nextGeminiDailyResetUnix() *int64 { - loc, err := time.LoadLocation("America/Los_Angeles") - if err != nil { - // Fallback: PST without DST. - loc = time.FixedZone("PST", -8*3600) - } - now := time.Now().In(loc) - reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc) - if !reset.After(now) { - reset = reset.Add(24 * time.Hour) - } + reset := geminiDailyResetTime(time.Now()) ts := reset.Unix() return &ts } diff --git a/backend/internal/service/gemini_quota.go b/backend/internal/service/gemini_quota.go new file mode 100644 index 00000000..47ffbfe8 --- /dev/null +++ b/backend/internal/service/gemini_quota.go @@ -0,0 +1,268 @@ +package service + +import ( + "context" + "encoding/json" + "errors" + "log" + "strings" + "sync" + "time" + + "github.com/Wei-Shaw/sub2api/internal/config" + "github.com/Wei-Shaw/sub2api/internal/pkg/usagestats" +) + +type geminiModelClass string + +const ( + geminiModelPro geminiModelClass = "pro" + geminiModelFlash geminiModelClass = "flash" +) + +type GeminiDailyQuota struct { + ProRPD int64 + FlashRPD int64 +} + +type GeminiTierPolicy struct { + Quota GeminiDailyQuota + Cooldown time.Duration +} + +type GeminiQuotaPolicy struct { + tiers map[string]GeminiTierPolicy +} + +type GeminiUsageTotals struct { + ProRequests int64 + FlashRequests int64 + 
ProTokens int64 + FlashTokens int64 + ProCost float64 + FlashCost float64 +} + +const geminiQuotaCacheTTL = time.Minute + +type geminiQuotaOverrides struct { + Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"` +} + +type GeminiQuotaService struct { + cfg *config.Config + settingRepo SettingRepository + mu sync.Mutex + cachedAt time.Time + policy *GeminiQuotaPolicy +} + +func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService { + return &GeminiQuotaService{ + cfg: cfg, + settingRepo: settingRepo, + } +} + +func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy { + if s == nil { + return newGeminiQuotaPolicy() + } + + now := time.Now() + s.mu.Lock() + if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL { + policy := s.policy + s.mu.Unlock() + return policy + } + s.mu.Unlock() + + policy := newGeminiQuotaPolicy() + if s.cfg != nil { + policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers) + if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" { + var overrides geminiQuotaOverrides + if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil { + log.Printf("gemini quota: parse config policy failed: %v", err) + } else { + policy.ApplyOverrides(overrides.Tiers) + } + } + } + + if s.settingRepo != nil { + value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy) + if err != nil && !errors.Is(err, ErrSettingNotFound) { + log.Printf("gemini quota: load setting failed: %v", err) + } else if strings.TrimSpace(value) != "" { + var overrides geminiQuotaOverrides + if err := json.Unmarshal([]byte(value), &overrides); err != nil { + log.Printf("gemini quota: parse setting failed: %v", err) + } else { + policy.ApplyOverrides(overrides.Tiers) + } + } + } + + s.mu.Lock() + s.policy = policy + s.cachedAt = now + s.mu.Unlock() + + return policy +} + +func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) { + if 
account == nil || !account.IsGeminiCodeAssist() { + return GeminiDailyQuota{}, false + } + policy := s.Policy(ctx) + return policy.QuotaForTier(account.GeminiTierID()) +} + +func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration { + policy := s.Policy(ctx) + return policy.CooldownForTier(tierID) +} + +func newGeminiQuotaPolicy() *GeminiQuotaPolicy { + return &GeminiQuotaPolicy{ + tiers: map[string]GeminiTierPolicy{ + "LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute}, + "PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute}, + "ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute}, + }, + } +} + +func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) { + if p == nil || len(tiers) == 0 { + return + } + for rawID, override := range tiers { + tierID := normalizeGeminiTierID(rawID) + if tierID == "" { + continue + } + policy, ok := p.tiers[tierID] + if !ok { + policy = GeminiTierPolicy{Cooldown: 5 * time.Minute} + } + if override.ProRPD != nil { + policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD) + } + if override.FlashRPD != nil { + policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD) + } + if override.CooldownMinutes != nil { + minutes := clampGeminiQuotaInt(*override.CooldownMinutes) + policy.Cooldown = time.Duration(minutes) * time.Minute + } + p.tiers[tierID] = policy + } +} + +func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) { + policy, ok := p.policyForTier(tierID) + if !ok { + return GeminiDailyQuota{}, false + } + return policy.Quota, true +} + +func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration { + policy, ok := p.policyForTier(tierID) + if ok && policy.Cooldown > 0 { + return policy.Cooldown + } + return 5 * time.Minute +} + +func (p *GeminiQuotaPolicy) policyForTier(tierID string) 
(GeminiTierPolicy, bool) { + if p == nil { + return GeminiTierPolicy{}, false + } + normalized := normalizeGeminiTierID(tierID) + if normalized == "" { + normalized = "LEGACY" + } + if policy, ok := p.tiers[normalized]; ok { + return policy, true + } + policy, ok := p.tiers["LEGACY"] + return policy, ok +} + +func normalizeGeminiTierID(tierID string) string { + return strings.ToUpper(strings.TrimSpace(tierID)) +} + +func clampGeminiQuotaInt64(value int64) int64 { + if value < 0 { + return 0 + } + return value +} + +func clampGeminiQuotaInt(value int) int { + if value < 0 { + return 0 + } + return value +} + +func geminiCooldownForTier(tierID string) time.Duration { + policy := newGeminiQuotaPolicy() + return policy.CooldownForTier(tierID) +} + +func geminiModelClassFromName(model string) geminiModelClass { + name := strings.ToLower(strings.TrimSpace(model)) + if strings.Contains(name, "flash") || strings.Contains(name, "lite") { + return geminiModelFlash + } + return geminiModelPro +} + +func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals { + var totals GeminiUsageTotals + for _, stat := range stats { + switch geminiModelClassFromName(stat.Model) { + case geminiModelFlash: + totals.FlashRequests += stat.Requests + totals.FlashTokens += stat.TotalTokens + totals.FlashCost += stat.ActualCost + default: + totals.ProRequests += stat.Requests + totals.ProTokens += stat.TotalTokens + totals.ProCost += stat.ActualCost + } + } + return totals +} + +func geminiQuotaLocation() *time.Location { + loc, err := time.LoadLocation("America/Los_Angeles") + if err != nil { + return time.FixedZone("PST", -8*3600) + } + return loc +} + +func geminiDailyWindowStart(now time.Time) time.Time { + loc := geminiQuotaLocation() + localNow := now.In(loc) + return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc) +} + +func geminiDailyResetTime(now time.Time) time.Time { + loc := geminiQuotaLocation() + localNow := now.In(loc) + start := 
time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc) + reset := start.Add(24 * time.Hour) + if !reset.After(localNow) { + reset = reset.Add(24 * time.Hour) + } + return reset +} diff --git a/backend/internal/service/ratelimit_service.go b/backend/internal/service/ratelimit_service.go index 1474ae46..ea83723d 100644 --- a/backend/internal/service/ratelimit_service.go +++ b/backend/internal/service/ratelimit_service.go @@ -5,6 +5,7 @@ import ( "log" "net/http" "strconv" + "strings" "time" "github.com/Wei-Shaw/sub2api/internal/config" @@ -12,15 +13,19 @@ import ( // RateLimitService 处理限流和过载状态管理 type RateLimitService struct { - accountRepo AccountRepository - cfg *config.Config + accountRepo AccountRepository + usageRepo UsageLogRepository + cfg *config.Config + geminiQuotaService *GeminiQuotaService } // NewRateLimitService 创建RateLimitService实例 -func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService { +func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService { return &RateLimitService{ - accountRepo: accountRepo, - cfg: cfg, + accountRepo: accountRepo, + usageRepo: usageRepo, + cfg: cfg, + geminiQuotaService: geminiQuotaService, } } @@ -62,6 +67,68 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc } } +// PreCheckUsage proactively checks local quota before dispatching a request. +// Returns false when the account should be skipped. 
+func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) { + if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" { + return true, nil + } + if s.usageRepo == nil { + return true, nil + } + + quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account) + if !ok { + return true, nil + } + + var limit int64 + switch geminiModelClassFromName(requestedModel) { + case geminiModelFlash: + limit = quota.FlashRPD + default: + limit = quota.ProRPD + } + if limit <= 0 { + return true, nil + } + + now := time.Now() + start := geminiDailyWindowStart(now) + stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID) + if err != nil { + return true, err + } + totals := geminiAggregateUsage(stats) + + var used int64 + switch geminiModelClassFromName(requestedModel) { + case geminiModelFlash: + used = totals.FlashRequests + default: + used = totals.ProRequests + } + + if used >= limit { + resetAt := geminiDailyResetTime(now) + if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil { + log.Printf("SetRateLimited failed for account %d: %v", account.ID, err) + } + log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt) + return false, nil + } + + return true, nil +} + +// GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier. 
+func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration { + if account == nil { + return 5 * time.Minute + } + return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID()) +} + // handleAuthError 处理认证类错误(401/403),停止账号调度 func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) { if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil { diff --git a/backend/internal/service/wire.go b/backend/internal/service/wire.go index 81e01d47..9843bb91 100644 --- a/backend/internal/service/wire.go +++ b/backend/internal/service/wire.go @@ -94,6 +94,7 @@ var ProviderSet = wire.NewSet( NewOAuthService, NewOpenAIOAuthService, NewGeminiOAuthService, + NewGeminiQuotaService, NewAntigravityOAuthService, NewGeminiTokenProvider, NewGeminiMessagesCompatService, diff --git a/deploy/.env.example b/deploy/.env.example index 19fcc853..ffea8be4 100644 --- a/deploy/.env.example +++ b/deploy/.env.example @@ -86,3 +86,11 @@ GEMINI_OAUTH_CLIENT_ID= GEMINI_OAUTH_CLIENT_SECRET= # Optional; leave empty to auto-select scopes based on oauth_type GEMINI_OAUTH_SCOPES= + +# ----------------------------------------------------------------------------- +# Gemini Quota Policy (OPTIONAL, local simulation) +# ----------------------------------------------------------------------------- +# JSON overrides for local quota simulation (Code Assist only). +# Example: +# GEMINI_QUOTA_POLICY={"tiers":{"LEGACY":{"pro_rpd":50,"flash_rpd":1500,"cooldown_minutes":30},"PRO":{"pro_rpd":1500,"flash_rpd":4000,"cooldown_minutes":5},"ULTRA":{"pro_rpd":2000,"flash_rpd":0,"cooldown_minutes":5}}} +GEMINI_QUOTA_POLICY= diff --git a/deploy/README.md b/deploy/README.md index 24b6d067..f697247d 100644 --- a/deploy/README.md +++ b/deploy/README.md @@ -123,6 +123,7 @@ docker-compose down -v | `GEMINI_OAUTH_CLIENT_ID` | No | *(builtin)* | Google OAuth client ID (Gemini OAuth). Leave empty to use the built-in Gemini CLI client. 
| | `GEMINI_OAUTH_CLIENT_SECRET` | No | *(builtin)* | Google OAuth client secret (Gemini OAuth). Leave empty to use the built-in Gemini CLI client. | | `GEMINI_OAUTH_SCOPES` | No | *(default)* | OAuth scopes (Gemini OAuth) | +| `GEMINI_QUOTA_POLICY` | No | *(empty)* | JSON overrides for Gemini local quota simulation (Code Assist only). | See `.env.example` for all available options. diff --git a/deploy/config.example.yaml b/deploy/config.example.yaml index 5478d151..f07e893c 100644 --- a/deploy/config.example.yaml +++ b/deploy/config.example.yaml @@ -156,3 +156,19 @@ gemini: client_secret: "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" # Optional scopes (space-separated). Leave empty to auto-select based on oauth_type. scopes: "" + quota: + # Optional: local quota simulation for Gemini Code Assist (local billing). + # These values are used for UI progress + precheck scheduling, not official Google quotas. + tiers: + LEGACY: + pro_rpd: 50 + flash_rpd: 1500 + cooldown_minutes: 30 + PRO: + pro_rpd: 1500 + flash_rpd: 4000 + cooldown_minutes: 5 + ULTRA: + pro_rpd: 2000 + flash_rpd: 0 + cooldown_minutes: 5 diff --git a/deploy/docker-compose-test.yml b/deploy/docker-compose-test.yml index 35aa553b..b73d4a26 100644 --- a/deploy/docker-compose-test.yml +++ b/deploy/docker-compose-test.yml @@ -90,6 +90,7 @@ services: - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} + - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} depends_on: postgres: condition: service_healthy diff --git a/deploy/docker-compose.override.yml b/deploy/docker-compose.override.yml new file mode 100644 index 00000000..d877ff50 --- /dev/null +++ b/deploy/docker-compose.override.yml @@ -0,0 +1,21 @@ +# ============================================================================= +# Docker Compose Override for Local Development +# ============================================================================= +# 
This file is auto-loaded by Docker Compose only together with docker-compose.yml (i.e. when no -f flag is given) +# Usage with the test stack: docker-compose -f docker-compose-test.yml -f docker-compose.override.yml up -d +# ============================================================================= + +services: + # =========================================================================== + # PostgreSQL - 暴露端口用于本地开发 + # =========================================================================== + postgres: + ports: + - "127.0.0.1:5432:5432" + + # =========================================================================== + # Redis - 暴露端口用于本地开发 + # =========================================================================== + redis: + ports: + - "127.0.0.1:6379:6379" diff --git a/deploy/docker-compose.yml b/deploy/docker-compose.yml index 45b3796b..fbd79710 100644 --- a/deploy/docker-compose.yml +++ b/deploy/docker-compose.yml @@ -90,6 +90,7 @@ services: - GEMINI_OAUTH_CLIENT_ID=${GEMINI_OAUTH_CLIENT_ID:-} - GEMINI_OAUTH_CLIENT_SECRET=${GEMINI_OAUTH_CLIENT_SECRET:-} - GEMINI_OAUTH_SCOPES=${GEMINI_OAUTH_SCOPES:-} + - GEMINI_QUOTA_POLICY=${GEMINI_QUOTA_POLICY:-} depends_on: postgres: condition: service_healthy