Merge branch 'feature/gemini-quota' (PR #113)

feat: Gemini 配额模拟和限流功能

主要变更:
- 新增 GeminiQuotaService 实现基于 Tier 的配额管理
- RateLimitService 增加 PreCheckUsage 预检查功能
- gemini_oauth_service 改进 tier_id 处理逻辑(向后兼容)
- 前端新增配额可视化组件 (AccountQuotaInfo.vue)
- 数据库迁移: 为现有 Code Assist 账号添加默认 tier_id

技术细节:
- 支持 LEGACY/PRO/ULTRA 三种配额等级
- 配额策略可通过配置文件或数据库设置覆盖
- fetchProjectID 返回值保留 tierID(即使 projectID 获取失败)
- 删除冗余类型别名 ClaudeCustomToolSpec
This commit is contained in:
shaw
2026-01-01 10:58:11 +08:00
30 changed files with 1802 additions and 105 deletions

View File

@@ -87,9 +87,10 @@ func initializeApplication(buildInfo handler.BuildInfo) (*Application, error) {
geminiOAuthClient := repository.NewGeminiOAuthClient(configConfig)
geminiCliCodeAssistClient := repository.NewGeminiCliCodeAssistClient()
geminiOAuthService := service.NewGeminiOAuthService(proxyRepository, geminiOAuthClient, geminiCliCodeAssistClient, configConfig)
rateLimitService := service.NewRateLimitService(accountRepository, configConfig)
geminiQuotaService := service.NewGeminiQuotaService(configConfig, settingRepository)
rateLimitService := service.NewRateLimitService(accountRepository, usageLogRepository, configConfig, geminiQuotaService)
claudeUsageFetcher := repository.NewClaudeUsageFetcher()
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher)
accountUsageService := service.NewAccountUsageService(accountRepository, usageLogRepository, claudeUsageFetcher, geminiQuotaService)
geminiTokenCache := repository.NewGeminiTokenCache(redisClient)
geminiTokenProvider := service.NewGeminiTokenProvider(accountRepository, geminiTokenCache, geminiOAuthService)
gatewayCache := repository.NewGatewayCache(redisClient)

View File

@@ -43,6 +43,7 @@ type Config struct {
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
}
type GeminiOAuthConfig struct {
@@ -51,6 +52,17 @@ type GeminiOAuthConfig struct {
Scopes string `mapstructure:"scopes"`
}
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// TokenRefreshConfig OAuth token自动刷新配置
type TokenRefreshConfig struct {
// 是否启用自动刷新
@@ -352,6 +364,7 @@ func setDefaults() {
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
}
func (c *Config) Validate() error {

View File

@@ -96,7 +96,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
{
Type: "custom",
Name: "mcp_tool",
Custom: &ClaudeCustomToolSpec{
Custom: &CustomToolSpec{
Description: "MCP tool description",
InputSchema: map[string]any{
"type": "object",
@@ -121,7 +121,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
{
Type: "custom",
Name: "custom_tool",
Custom: &ClaudeCustomToolSpec{
Custom: &CustomToolSpec{
Description: "Custom tool",
InputSchema: map[string]any{"type": "object"},
},
@@ -148,7 +148,7 @@ func TestBuildTools_CustomTypeTools(t *testing.T) {
{
Type: "custom",
Name: "invalid_custom",
Custom: &ClaudeCustomToolSpec{
Custom: &CustomToolSpec{
Description: "Invalid",
// InputSchema 为 nil
},

View File

@@ -127,7 +127,15 @@ func applyMigrationsFS(ctx context.Context, db *sql.DB, fsys fs.FS) error {
if existing != checksum {
// 校验和不匹配意味着迁移文件在应用后被修改,这是危险的。
// 正确的做法是创建新的迁移文件来进行变更。
return fmt.Errorf("migration %s checksum mismatch (db=%s file=%s)", name, existing, checksum)
return fmt.Errorf(
"migration %s checksum mismatch (db=%s file=%s)\n"+
"This means the migration file was modified after being applied to the database.\n"+
"Solutions:\n"+
" 1. Revert to original: git log --oneline -- migrations/%s && git checkout <commit> -- migrations/%s\n"+
" 2. For new changes, create a new migration file instead of modifying existing ones\n"+
"Note: Modifying applied migrations breaks the immutability principle and can cause inconsistencies across environments",
name, existing, checksum, name, name,
)
}
continue // 迁移已应用且校验和匹配,跳过
}

View File

@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
func (a *Account) GeminiOAuthType() string {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return ""
}
oauthType := strings.TrimSpace(a.GetCredential("oauth_type"))
if oauthType == "" && strings.TrimSpace(a.GetCredential("project_id")) != "" {
return "code_assist"
}
return oauthType
}
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
}
func (a *Account) IsGeminiCodeAssist() bool {
if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
return false
}
oauthType := a.GeminiOAuthType()
if oauthType == "" {
return strings.TrimSpace(a.GetCredential("project_id")) != ""
}
return oauthType == "code_assist"
}
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}

View File

@@ -93,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
}
}
@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage有profile scope
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
now := time.Now()
usage := &UsageInfo{
UpdatedAt: &now,
}
if s.geminiQuotaService == nil || s.usageLogRepo == nil {
return usage, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
return usage, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return info
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
if limit <= 0 {
return nil
}
utilization := (float64(used) / float64(limit)) * 100
remainingSeconds := int(resetAt.Sub(now).Seconds())
if remainingSeconds < 0 {
remainingSeconds = 0
}
resetCopy := resetAt
return &UsageProgress{
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,
Cost: cost,
},
}
}

View File

@@ -91,6 +91,9 @@ const (
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)

View File

@@ -116,8 +116,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
valid = true
}
if valid {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
@@ -157,6 +169,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
@@ -1886,13 +1907,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if statusCode != 429 {
return
}
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
ra := time.Now().Add(5 * time.Minute)
// 根据账号类型使用不同的默认重置时间
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
@@ -1948,16 +2000,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
}
func nextGeminiDailyResetUnix() *int64 {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
// Fallback: PST without DST.
loc = time.FixedZone("PST", -8*3600)
}
now := time.Now().In(loc)
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
if !reset.After(now) {
reset = reset.Add(24 * time.Hour)
}
reset := geminiDailyResetTime(time.Now())
ts := reset.Unix()
return &ts
}

View File

@@ -259,8 +259,15 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
projectID := sessionProjectID
var tierID string
@@ -275,10 +282,22 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
} else {
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
} else {
tierID = fetchedTierID
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// tierID 缺失时使用默认值
if tierID == "" {
tierID = "LEGACY"
}
}
return &GeminiTokenInfo{
@@ -308,8 +327,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil {
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -396,19 +422,39 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
tokenInfo.ProjectID = existingProjectID
}
// 尝试从账号凭证获取 tierID向后兼容
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
if oauthType == "code_assist" {
// 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
// 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
} else {
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
}
}
}
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
tokenInfo.TierID = tierID
}
return tokenInfo, nil
@@ -466,9 +512,6 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
// (tierID already extracted above, reuse it)
req := &geminicli.OnboardUserRequest{
TierID: tierID,
Metadata: geminicli.LoadCodeAssistMetadata{
@@ -487,7 +530,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), tierID, nil
}
return "", "", err
return "", tierID, err
}
if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
@@ -505,7 +548,7 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), tierID, nil
}
return "", "", errors.New("onboardUser completed but no project_id returned")
return "", tierID, errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
@@ -515,9 +558,9 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
return strings.TrimSpace(fallback), tierID, nil
}
if loadErr != nil {
return "", "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
type geminiModelClass string
const (
geminiModelPro geminiModelClass = "pro"
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Cooldown time.Duration
}
type GeminiQuotaPolicy struct {
tiers map[string]GeminiTierPolicy
}
type GeminiUsageTotals struct {
ProRequests int64
FlashRequests int64
ProTokens int64
FlashTokens int64
ProCost float64
FlashCost float64
}
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
mu sync.Mutex
cachedAt time.Time
policy *GeminiQuotaPolicy
}
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
return &GeminiQuotaService{
cfg: cfg,
settingRepo: settingRepo,
}
}
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s == nil {
return newGeminiQuotaPolicy()
}
now := time.Now()
s.mu.Lock()
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
policy := s.policy
s.mu.Unlock()
return policy
}
s.mu.Unlock()
policy := newGeminiQuotaPolicy()
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
if s.settingRepo != nil {
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overrides.Tiers)
}
}
}
s.mu.Lock()
s.policy = policy
s.cachedAt = now
s.mu.Unlock()
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
policy := s.Policy(ctx)
return policy.CooldownForTier(tierID)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
},
}
}
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
if p == nil || len(tiers) == 0 {
return
}
for rawID, override := range tiers {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
policy.Cooldown = time.Duration(minutes) * time.Minute
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
}
return policy.Quota, true
}
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
policy, ok := p.policyForTier(tierID)
if ok && policy.Cooldown > 0 {
return policy.Cooldown
}
return 5 * time.Minute
}
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
if p == nil {
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
return 0
}
return value
}
func clampGeminiQuotaInt(value int) int {
if value < 0 {
return 0
}
return value
}
func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID)
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
return geminiModelFlash
}
return geminiModelPro
}
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
var totals GeminiUsageTotals
for _, stat := range stats {
switch geminiModelClassFromName(stat.Model) {
case geminiModelFlash:
totals.FlashRequests += stat.Requests
totals.FlashTokens += stat.TotalTokens
totals.FlashCost += stat.ActualCost
default:
totals.ProRequests += stat.Requests
totals.ProTokens += stat.TotalTokens
totals.ProCost += stat.ActualCost
}
}
return totals
}
func geminiQuotaLocation() *time.Location {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
return time.FixedZone("PST", -8*3600)
}
return loc
}
func geminiDailyWindowStart(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
}
func geminiDailyResetTime(now time.Time) time.Time {
loc := geminiQuotaLocation()
localNow := now.In(loc)
start := time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
reset := start.Add(24 * time.Hour)
if !reset.After(localNow) {
reset = reset.Add(24 * time.Hour)
}
return reset
}

View File

@@ -118,6 +118,7 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
return accessToken, nil
}
detected = strings.TrimSpace(detected)
tierID = strings.TrimSpace(tierID)
if detected != "" {
if account.Credentials == nil {
account.Credentials = make(map[string]any)

View File

@@ -5,6 +5,8 @@ import (
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -12,15 +14,30 @@ import (
// RateLimitService 处理限流和过载状态管理
type RateLimitService struct {
accountRepo AccountRepository
cfg *config.Config
accountRepo AccountRepository
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
usageCacheMu sync.Mutex
usageCache map[int64]*geminiUsageCacheEntry
}
type geminiUsageCacheEntry struct {
windowStart time.Time
cachedAt time.Time
totals GeminiUsageTotals
}
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService {
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
return &RateLimitService{
accountRepo: accountRepo,
cfg: cfg,
accountRepo: accountRepo,
usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
usageCache: make(map[int64]*geminiUsageCacheEntry),
}
}
@@ -62,6 +79,106 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
return true, nil
}
if s.usageRepo == nil || s.geminiQuotaService == nil {
return true, nil
}
quota, ok := s.geminiQuotaService.QuotaForAccount(ctx, account)
if !ok {
return true, nil
}
var limit int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
if limit <= 0 {
return true, nil
}
now := time.Now()
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
}
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
return false, nil
}
return true, nil
}
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
s.usageCacheMu.Lock()
defer s.usageCacheMu.Unlock()
if s.usageCache == nil {
return GeminiUsageTotals{}, false
}
entry, ok := s.usageCache[accountID]
if !ok || entry == nil {
return GeminiUsageTotals{}, false
}
if !entry.windowStart.Equal(windowStart) {
return GeminiUsageTotals{}, false
}
if now.Sub(entry.cachedAt) >= geminiPrecheckCacheTTL {
return GeminiUsageTotals{}, false
}
return entry.totals, true
}
func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) {
s.usageCacheMu.Lock()
defer s.usageCacheMu.Unlock()
if s.usageCache == nil {
s.usageCache = make(map[int64]*geminiUsageCacheEntry)
}
s.usageCache[accountID] = &geminiUsageCacheEntry{
windowStart: windowStart,
cachedAt: now,
totals: totals,
}
}
// GeminiCooldown returns the fallback cooldown duration for Gemini 429s based on tier.
func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration {
if account == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {

View File

@@ -94,6 +94,7 @@ var ProviderSet = wire.NewSet(
NewOAuthService,
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,

View File

@@ -0,0 +1,30 @@
-- +goose Up
-- +goose StatementBegin
-- 为 Gemini Code Assist OAuth 账号添加默认 tier_id
-- 包括显式标记为 code_assist 的账号,以及 legacy 账号oauth_type 为空但 project_id 存在)
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{tier_id}',
'"LEGACY"',
true
)
WHERE platform = 'gemini'
AND type = 'oauth'
AND jsonb_typeof(credentials) = 'object'
AND credentials->>'tier_id' IS NULL
AND (
credentials->>'oauth_type' = 'code_assist'
OR (credentials->>'oauth_type' IS NULL AND credentials->>'project_id' IS NOT NULL)
);
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- 回滚:删除 tier_id 字段
UPDATE accounts
SET credentials = credentials - 'tier_id'
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'oauth_type' = 'code_assist';
-- +goose StatementEnd

View File

@@ -0,0 +1,178 @@
# Database Migrations
## Overview
This directory contains SQL migration files for database schema changes. The migration system uses SHA256 checksums to ensure migration immutability and consistency across environments.
## Migration File Naming
Format: `NNN_description.sql`
- `NNN`: Sequential number (e.g., 001, 002, 003)
- `description`: Brief description in snake_case
Example: `017_add_gemini_tier_id.sql`
## Migration File Structure
```sql
-- +goose Up
-- +goose StatementBegin
-- Your forward migration SQL here
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Your rollback migration SQL here
-- +goose StatementEnd
```
## Important Rules
### ⚠️ Immutability Principle
**Once a migration is applied to ANY environment (dev, staging, production), it MUST NOT be modified.**
Why?
- Each migration has a SHA256 checksum stored in the `schema_migrations` table
- Modifying an applied migration causes checksum mismatch errors
- Different environments would have inconsistent database states
- Breaks audit trail and reproducibility
### ✅ Correct Workflow
1. **Create new migration**
```bash
# Create new file with next sequential number
touch migrations/018_your_change.sql
```
2. **Write Up and Down migrations**
- Up: Apply the change
- Down: Revert the change (should be symmetric with Up)
3. **Test locally**
```bash
# Apply migration
make migrate-up
# Test rollback
make migrate-down
```
4. **Commit and deploy**
```bash
git add migrations/018_your_change.sql
git commit -m "feat(db): add your change"
```
### ❌ What NOT to Do
- ❌ Modify an already-applied migration file
- ❌ Delete migration files
- ❌ Change migration file names
- ❌ Reorder migration numbers
### 🔧 If You Accidentally Modified an Applied Migration
**Error message:**
```
migration 017_add_gemini_tier_id.sql checksum mismatch (db=abc123... file=def456...)
```
**Solution:**
```bash
# 1. Find the original version
git log --oneline -- migrations/017_add_gemini_tier_id.sql
# 2. Revert to the commit when it was first applied
git checkout <commit-hash> -- migrations/017_add_gemini_tier_id.sql
# 3. Create a NEW migration for your changes
touch migrations/018_your_new_change.sql
```
## Migration System Details
- **Checksum Algorithm**: SHA256 of trimmed file content
- **Tracking Table**: `schema_migrations` (filename, checksum, applied_at)
- **Runner**: `internal/repository/migrations_runner.go`
- **Auto-run**: Migrations run automatically on service startup
## Best Practices
1. **Keep migrations small and focused**
- One logical change per migration
- Easier to review and rollback
2. **Write reversible migrations**
- Always provide a working Down migration
- Test rollback before committing
3. **Use transactions**
- Wrap DDL statements in transactions when possible
- Ensures atomicity
4. **Add comments**
- Explain WHY the change is needed
- Document any special considerations
5. **Test in development first**
- Apply migration locally
- Verify data integrity
- Test rollback
## Example Migration
```sql
-- +goose Up
-- +goose StatementBegin
-- Add tier_id field to Gemini OAuth accounts for quota tracking
UPDATE accounts
SET credentials = jsonb_set(
credentials,
'{tier_id}',
'"LEGACY"',
true
)
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'tier_id' IS NULL;
-- +goose StatementEnd
-- +goose Down
-- +goose StatementBegin
-- Remove tier_id field
UPDATE accounts
SET credentials = credentials - 'tier_id'
WHERE platform = 'gemini'
AND type = 'oauth'
AND credentials->>'tier_id' = 'LEGACY';
-- +goose StatementEnd
```
## Troubleshooting
### Checksum Mismatch
See "If You Accidentally Modified an Applied Migration" above.
### Migration Failed
```bash
# Check migration status
psql -d sub2api -c "SELECT * FROM schema_migrations ORDER BY applied_at DESC;"
# Manually rollback if needed (use with caution)
# Better to fix the migration and create a new one
```
### Need to Skip a Migration (Emergency Only)
```sql
-- DANGEROUS: Only use in development or with extreme caution
INSERT INTO schema_migrations (filename, checksum, applied_at)
VALUES ('NNN_migration.sql', 'calculated_checksum', NOW());
```
## References
- Migration runner: `internal/repository/migrations_runner.go`
- Goose syntax: https://github.com/pressly/goose
- PostgreSQL docs: https://www.postgresql.org/docs/