主要改动: - 后端:重构 Gemini 配额服务,支持多层级配额策略(GCP Standard/Free, Google One, AI Studio, Code Assist) - 后端:优化 OAuth 服务,增强 tier_id 识别和存储逻辑 - 后端:改进用量统计服务,支持不同平台的配额查询 - 后端:优化限流服务,增加临时解除调度状态管理 - 前端:统一四种授权方式的用量显示格式和徽标样式 - 前端:增强账户配额信息展示,支持多种配额类型 - 前端:改进创建和重新授权模态框的用户体验 - 国际化:完善中英文配额相关文案 - 移除 CHANGELOG.md 文件 测试:所有单元测试通过
449 lines
13 KiB
Go
449 lines
13 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"log"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/Wei-Shaw/sub2api/internal/config"
|
|
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
|
)
|
|
|
|
// geminiModelClass is the coarse model family used to bucket Gemini usage
// and per-model quotas.
type geminiModelClass string

// Model families recognised by the quota logic.
const (
	// geminiModelPro is the "pro" model family.
	geminiModelPro = geminiModelClass("pro")
	// geminiModelFlash is the "flash" model family.
	geminiModelFlash = geminiModelClass("flash")
)
|
|
|
|
// GeminiQuota describes the request limits of one Gemini tier.
//
// A tier uses either the shared pools (SharedRPD / SharedRPM) or the
// per-model limits (Pro* / Flash*). All fields are omitted from JSON when zero.
type GeminiQuota struct {
	// Shared pools, counted across all models.
	//
	// SharedRPD is the shared requests-per-day pool. When SharedRPD > 0,
	// callers should treat ProRPD/FlashRPD as not applicable for daily
	// quota checks.
	SharedRPD int64 `json:"shared_rpd,omitempty"`
	// SharedRPM is the shared requests-per-minute pool. When SharedRPM > 0,
	// callers should treat ProRPM/FlashRPM as not applicable for minute
	// quota checks.
	SharedRPM int64 `json:"shared_rpm,omitempty"`

	// Per-model quotas (AI Studio / API key).
	// A value of -1 means "unlimited" (pay-as-you-go).
	ProRPD   int64 `json:"pro_rpd,omitempty"`
	ProRPM   int64 `json:"pro_rpm,omitempty"`
	FlashRPD int64 `json:"flash_rpd,omitempty"`
	FlashRPM int64 `json:"flash_rpm,omitempty"`
}
|
|
|
|
type GeminiTierPolicy struct {
|
|
Quota GeminiQuota
|
|
Cooldown time.Duration
|
|
}
|
|
|
|
type GeminiQuotaPolicy struct {
|
|
tiers map[string]GeminiTierPolicy
|
|
}
|
|
|
|
// GeminiUsageTotals holds usage aggregated per model family (pro / flash).
type GeminiUsageTotals struct {
	// Request counts.
	ProRequests   int64
	FlashRequests int64

	// Token counts.
	ProTokens   int64
	FlashTokens int64

	// Accumulated cost.
	ProCost   float64
	FlashCost float64
}
|
|
|
|
// geminiQuotaCacheTTL is how long a resolved quota policy is reused before
// it is rebuilt from configuration and settings.
const geminiQuotaCacheTTL = 60 * time.Second
|
|
|
|
type geminiQuotaOverridesV1 struct {
|
|
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
|
|
}
|
|
|
|
type geminiQuotaOverridesV2 struct {
|
|
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
|
|
}
|
|
|
|
type geminiQuotaRuleOverride struct {
|
|
SharedRPD *int64 `json:"shared_rpd,omitempty"`
|
|
SharedRPM *int64 `json:"rpm,omitempty"`
|
|
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
|
|
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
|
|
Desc *string `json:"desc,omitempty"`
|
|
}
|
|
|
|
// geminiModelQuotaOverride carries optional per-model limits of a V2 rule.
// A nil field leaves the corresponding limit untouched.
type geminiModelQuotaOverride struct {
	// RPD is the requests-per-day override.
	RPD *int64 `json:"rpd,omitempty"`
	// RPM is the requests-per-minute override.
	RPM *int64 `json:"rpm,omitempty"`
}
|
|
|
|
type GeminiQuotaService struct {
|
|
cfg *config.Config
|
|
settingRepo SettingRepository
|
|
mu sync.Mutex
|
|
cachedAt time.Time
|
|
policy *GeminiQuotaPolicy
|
|
}
|
|
|
|
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
|
|
return &GeminiQuotaService{
|
|
cfg: cfg,
|
|
settingRepo: settingRepo,
|
|
}
|
|
}
|
|
|
|
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
|
if s == nil {
|
|
return newGeminiQuotaPolicy()
|
|
}
|
|
|
|
now := time.Now()
|
|
s.mu.Lock()
|
|
if s.policy != nil && now.Sub(s.cachedAt) < geminiQuotaCacheTTL {
|
|
policy := s.policy
|
|
s.mu.Unlock()
|
|
return policy
|
|
}
|
|
s.mu.Unlock()
|
|
|
|
policy := newGeminiQuotaPolicy()
|
|
if s.cfg != nil {
|
|
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
|
|
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
|
|
raw := []byte(s.cfg.Gemini.Quota.Policy)
|
|
var overridesV2 geminiQuotaOverridesV2
|
|
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
|
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
|
} else {
|
|
var overridesV1 geminiQuotaOverridesV1
|
|
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
|
log.Printf("gemini quota: parse config policy failed: %v", err)
|
|
} else {
|
|
policy.ApplyOverrides(overridesV1.Tiers)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if s.settingRepo != nil {
|
|
value, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
|
|
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
|
log.Printf("gemini quota: load setting failed: %v", err)
|
|
} else if strings.TrimSpace(value) != "" {
|
|
raw := []byte(value)
|
|
var overridesV2 geminiQuotaOverridesV2
|
|
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
|
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
|
} else {
|
|
var overridesV1 geminiQuotaOverridesV1
|
|
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
|
log.Printf("gemini quota: parse setting failed: %v", err)
|
|
} else {
|
|
policy.ApplyOverrides(overridesV1.Tiers)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
s.mu.Lock()
|
|
s.policy = policy
|
|
s.cachedAt = now
|
|
s.mu.Unlock()
|
|
|
|
return policy
|
|
}
|
|
|
|
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
|
|
if account == nil || account.Platform != PlatformGemini {
|
|
return GeminiQuota{}, false
|
|
}
|
|
|
|
// Map (oauth_type + tier_id) to a canonical policy tier key.
|
|
// This keeps the policy table stable even if upstream tier_id strings vary.
|
|
tierKey := geminiQuotaTierKeyForAccount(account)
|
|
if tierKey == "" {
|
|
return GeminiQuota{}, false
|
|
}
|
|
|
|
policy := s.Policy(ctx)
|
|
return policy.QuotaForTier(tierKey)
|
|
}
|
|
|
|
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
|
|
policy := s.Policy(ctx)
|
|
return policy.CooldownForTier(tierID)
|
|
}
|
|
|
|
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
|
|
if s == nil || account == nil || account.Platform != PlatformGemini {
|
|
return 5 * time.Minute
|
|
}
|
|
tierKey := geminiQuotaTierKeyForAccount(account)
|
|
if strings.TrimSpace(tierKey) == "" {
|
|
return 5 * time.Minute
|
|
}
|
|
return s.CooldownForTier(ctx, tierKey)
|
|
}
|
|
|
|
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
|
|
return &GeminiQuotaPolicy{
|
|
tiers: map[string]GeminiTierPolicy{
|
|
// --- AI Studio / API Key (per-model) ---
|
|
// aistudio_free:
|
|
// - gemini_pro: 50 RPD / 2 RPM
|
|
// - gemini_flash: 1500 RPD / 15 RPM
|
|
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
|
|
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
|
|
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
|
|
|
|
// --- Google One (shared pool) ---
|
|
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
|
|
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
|
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
|
|
|
// --- GCP Code Assist (shared pool) ---
|
|
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
|
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
|
},
|
|
}
|
|
}
|
|
|
|
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
|
|
if p == nil || len(tiers) == 0 {
|
|
return
|
|
}
|
|
for rawID, override := range tiers {
|
|
tierID := normalizeGeminiTierID(rawID)
|
|
if tierID == "" {
|
|
continue
|
|
}
|
|
policy, ok := p.tiers[tierID]
|
|
if !ok {
|
|
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
|
}
|
|
// Backward-compatible overrides:
|
|
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
|
|
// - Otherwise apply per-model overrides.
|
|
if override.ProRPD != nil {
|
|
if policy.Quota.SharedRPD > 0 {
|
|
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
|
} else {
|
|
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
|
}
|
|
}
|
|
if override.FlashRPD != nil {
|
|
if policy.Quota.SharedRPD > 0 {
|
|
// No separate flash RPD for shared tiers.
|
|
} else {
|
|
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
|
|
}
|
|
}
|
|
if override.CooldownMinutes != nil {
|
|
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
|
|
policy.Cooldown = time.Duration(minutes) * time.Minute
|
|
}
|
|
p.tiers[tierID] = policy
|
|
}
|
|
}
|
|
|
|
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
|
|
if p == nil || len(rules) == 0 {
|
|
return
|
|
}
|
|
for rawID, override := range rules {
|
|
tierID := normalizeGeminiTierID(rawID)
|
|
if tierID == "" {
|
|
continue
|
|
}
|
|
policy, ok := p.tiers[tierID]
|
|
if !ok {
|
|
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
|
}
|
|
|
|
if override.SharedRPD != nil {
|
|
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
|
|
}
|
|
if override.SharedRPM != nil {
|
|
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
|
|
}
|
|
if override.GeminiPro != nil {
|
|
if override.GeminiPro.RPD != nil {
|
|
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
|
|
}
|
|
if override.GeminiPro.RPM != nil {
|
|
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
|
|
}
|
|
}
|
|
if override.GeminiFlash != nil {
|
|
if override.GeminiFlash.RPD != nil {
|
|
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
|
|
}
|
|
if override.GeminiFlash.RPM != nil {
|
|
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
|
|
}
|
|
}
|
|
|
|
p.tiers[tierID] = policy
|
|
}
|
|
}
|
|
|
|
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
|
|
policy, ok := p.policyForTier(tierID)
|
|
if !ok {
|
|
return GeminiQuota{}, false
|
|
}
|
|
return policy.Quota, true
|
|
}
|
|
|
|
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
|
|
policy, ok := p.policyForTier(tierID)
|
|
if ok && policy.Cooldown > 0 {
|
|
return policy.Cooldown
|
|
}
|
|
return 5 * time.Minute
|
|
}
|
|
|
|
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
|
|
if p == nil {
|
|
return GeminiTierPolicy{}, false
|
|
}
|
|
normalized := normalizeGeminiTierID(tierID)
|
|
if policy, ok := p.tiers[normalized]; ok {
|
|
return policy, true
|
|
}
|
|
return GeminiTierPolicy{}, false
|
|
}
|
|
|
|
func normalizeGeminiTierID(tierID string) string {
|
|
tierID = strings.TrimSpace(tierID)
|
|
if tierID == "" {
|
|
return ""
|
|
}
|
|
// Prefer canonical mapping (handles legacy tier strings).
|
|
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
|
|
return canonical
|
|
}
|
|
// Accept older policy keys that used uppercase names.
|
|
switch strings.ToUpper(tierID) {
|
|
case "AISTUDIO_FREE":
|
|
return GeminiTierAIStudioFree
|
|
case "AISTUDIO_PAID":
|
|
return GeminiTierAIStudioPaid
|
|
case "GOOGLE_ONE_FREE":
|
|
return GeminiTierGoogleOneFree
|
|
case "GOOGLE_AI_PRO":
|
|
return GeminiTierGoogleAIPro
|
|
case "GOOGLE_AI_ULTRA":
|
|
return GeminiTierGoogleAIUltra
|
|
case "GCP_STANDARD":
|
|
return GeminiTierGCPStandard
|
|
case "GCP_ENTERPRISE":
|
|
return GeminiTierGCPEnterprise
|
|
}
|
|
return strings.ToLower(tierID)
|
|
}
|
|
|
|
// clampGeminiQuotaInt64WithUnlimited sanitizes a daily quota value.
// Values of -1 ("unlimited") and above are returned unchanged; anything
// below -1 is replaced by 0.
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
	if value >= -1 {
		return value
	}
	return 0
}
|
|
|
|
// clampGeminiQuotaInt returns value unchanged when it is non-negative and 0
// otherwise.
func clampGeminiQuotaInt(value int) int {
	if value >= 0 {
		return value
	}
	return 0
}
|
|
|
|
// clampGeminiQuotaRPM sanitizes a requests-per-minute value: negative inputs
// become 0, everything else is returned unchanged.
func clampGeminiQuotaRPM(value int64) int64 {
	if value >= 0 {
		return value
	}
	return 0
}
|
|
|
|
func geminiCooldownForTier(tierID string) time.Duration {
|
|
policy := newGeminiQuotaPolicy()
|
|
return policy.CooldownForTier(tierID)
|
|
}
|
|
|
|
func geminiQuotaTierKeyForAccount(account *Account) string {
|
|
if account == nil || account.Platform != PlatformGemini {
|
|
return ""
|
|
}
|
|
|
|
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
|
|
oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
|
|
rawTier := strings.TrimSpace(account.GeminiTierID())
|
|
|
|
// Prefer the canonical tier stored in credentials.
|
|
if tierID := canonicalGeminiTierIDForOAuthType(oauthType, rawTier); tierID != "" && tierID != GeminiTierGoogleOneUnknown {
|
|
return tierID
|
|
}
|
|
|
|
// Fallback defaults when tier_id is missing or unknown.
|
|
switch oauthType {
|
|
case "google_one":
|
|
return GeminiTierGoogleOneFree
|
|
case "code_assist":
|
|
return GeminiTierGCPStandard
|
|
case "ai_studio":
|
|
return GeminiTierAIStudioFree
|
|
default:
|
|
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
|
|
return GeminiTierAIStudioFree
|
|
}
|
|
}
|
|
|
|
func geminiModelClassFromName(model string) geminiModelClass {
|
|
name := strings.ToLower(strings.TrimSpace(model))
|
|
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
|
|
return geminiModelFlash
|
|
}
|
|
return geminiModelPro
|
|
}
|
|
|
|
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
|
|
var totals GeminiUsageTotals
|
|
for _, stat := range stats {
|
|
switch geminiModelClassFromName(stat.Model) {
|
|
case geminiModelFlash:
|
|
totals.FlashRequests += stat.Requests
|
|
totals.FlashTokens += stat.TotalTokens
|
|
totals.FlashCost += stat.ActualCost
|
|
default:
|
|
totals.ProRequests += stat.Requests
|
|
totals.ProTokens += stat.TotalTokens
|
|
totals.ProCost += stat.ActualCost
|
|
}
|
|
}
|
|
return totals
|
|
}
|
|
|
|
// geminiQuotaLoc caches the location resolved by geminiQuotaLocation so the
// time zone database is read at most once per process.
var (
	geminiQuotaLocOnce sync.Once
	geminiQuotaLoc     *time.Location
)

// geminiQuotaLocation returns the time zone used for daily quota windows:
// "America/Los_Angeles", or a fixed UTC-8 zone named "PST" when the zone
// cannot be loaded. The result of the first call (including the fallback)
// is cached and returned by all later calls.
func geminiQuotaLocation() *time.Location {
	geminiQuotaLocOnce.Do(func() {
		loc, err := time.LoadLocation("America/Los_Angeles")
		if err != nil {
			loc = time.FixedZone("PST", -8*3600)
		}
		geminiQuotaLoc = loc
	})
	return geminiQuotaLoc
}

// geminiDailyWindowStart returns local midnight (in the quota time zone) of
// the calendar day that contains now.
func geminiDailyWindowStart(now time.Time) time.Time {
	loc := geminiQuotaLocation()
	year, month, day := now.In(loc).Date()
	return time.Date(year, month, day, 0, 0, 0, 0, loc)
}

// geminiDailyResetTime returns the next local midnight (in the quota time
// zone) strictly after now, i.e. the start of the following calendar day.
//
// The next day is computed with calendar arithmetic rather than by adding
// 24 hours: on daylight-saving transition days the local day is 23 or 25
// hours long, so "midnight + 24h" does not land on the next midnight.
func geminiDailyResetTime(now time.Time) time.Time {
	loc := geminiQuotaLocation()
	year, month, day := now.In(loc).Date()
	// time.Date normalizes an out-of-range day (e.g. the 32nd) into the
	// following month or year.
	return time.Date(year, month, day+1, 0, 0, 0, 0, loc)
}
|