Merge PR #142: feat: 多平台网关优化与用户系统增强
主要功能: - Gemini OAuth 配额系统优化:Google One tier 自动推断 - Antigravity 网关增强:Thinking Block 重试、Claude 模型 signature 透传 - 账号调度改进:临时不可调度功能、负载感知调度优化 - 前端用户体验:账号管理界面优化、使用教程改进 冲突解决:保留 handleUpstreamError 的 prefix 参数和日志记录能力, 同时合并 PR 的 thinking block 重试和 model fallback 功能。
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
// Package service provides business logic and domain services for the application.
|
||||
package service
|
||||
|
||||
import (
|
||||
@@ -29,6 +30,9 @@ type Account struct {
|
||||
RateLimitResetAt *time.Time
|
||||
OverloadUntil *time.Time
|
||||
|
||||
TempUnschedulableUntil *time.Time
|
||||
TempUnschedulableReason string
|
||||
|
||||
SessionWindowStart *time.Time
|
||||
SessionWindowEnd *time.Time
|
||||
SessionWindowStatus string
|
||||
@@ -39,6 +43,13 @@ type Account struct {
|
||||
Groups []*Group
|
||||
}
|
||||
|
||||
type TempUnschedulableRule struct {
|
||||
ErrorCode int `json:"error_code"`
|
||||
Keywords []string `json:"keywords"`
|
||||
DurationMinutes int `json:"duration_minutes"`
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
func (a *Account) IsActive() bool {
|
||||
return a.Status == StatusActive
|
||||
}
|
||||
@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool {
|
||||
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
|
||||
return false
|
||||
}
|
||||
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string {
|
||||
|
||||
func (a *Account) GeminiTierID() string {
|
||||
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
|
||||
if tierID == "" {
|
||||
return ""
|
||||
}
|
||||
return strings.ToUpper(tierID)
|
||||
return tierID
|
||||
}
|
||||
|
||||
func (a *Account) IsGeminiCodeAssist() bool {
|
||||
@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (a *Account) IsTempUnschedulableEnabled() bool {
|
||||
if a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_enabled"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
enabled, ok := raw.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
|
||||
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
}
|
||||
raw, ok := a.Credentials["temp_unschedulable_rules"]
|
||||
if !ok || raw == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
arr, ok := raw.([]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
rules := make([]TempUnschedulableRule, 0, len(arr))
|
||||
for _, item := range arr {
|
||||
entry, ok := item.(map[string]any)
|
||||
if !ok || entry == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
rule := TempUnschedulableRule{
|
||||
ErrorCode: parseTempUnschedInt(entry["error_code"]),
|
||||
Keywords: parseTempUnschedStrings(entry["keywords"]),
|
||||
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
|
||||
Description: parseTempUnschedString(entry["description"]),
|
||||
}
|
||||
|
||||
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rules = append(rules, rule)
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func parseTempUnschedString(value any) string {
|
||||
s, ok := value.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(s)
|
||||
}
|
||||
|
||||
func parseTempUnschedStrings(value any) []string {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var raw []string
|
||||
switch v := value.(type) {
|
||||
case []string:
|
||||
raw = v
|
||||
case []any:
|
||||
raw = make([]string, 0, len(v))
|
||||
for _, item := range v {
|
||||
if s, ok := item.(string); ok {
|
||||
raw = append(raw, s)
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
out := make([]string, 0, len(raw))
|
||||
for _, item := range raw {
|
||||
s := strings.TrimSpace(item)
|
||||
if s != "" {
|
||||
out = append(out, s)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func parseTempUnschedInt(value any) int {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return v
|
||||
case int64:
|
||||
return int(v)
|
||||
case float64:
|
||||
return int(v)
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
return int(i)
|
||||
}
|
||||
case string:
|
||||
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (a *Account) GetModelMapping() map[string]string {
|
||||
if a.Credentials == nil {
|
||||
return nil
|
||||
@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
|
||||
}
|
||||
|
||||
func (a *Account) GetBaseURL() string {
|
||||
if a.Type != AccountTypeApiKey {
|
||||
if a.Type != AccountTypeAPIKey {
|
||||
return ""
|
||||
}
|
||||
baseURL := a.GetCredential("base_url")
|
||||
@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
|
||||
}
|
||||
|
||||
func (a *Account) IsCustomErrorCodesEnabled() bool {
|
||||
if a.Type != AccountTypeApiKey || a.Credentials == nil {
|
||||
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
|
||||
return false
|
||||
}
|
||||
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
|
||||
@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
|
||||
}
|
||||
|
||||
func (a *Account) IsOpenAIApiKey() bool {
|
||||
return a.IsOpenAI() && a.Type == AccountTypeApiKey
|
||||
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
|
||||
}
|
||||
|
||||
func (a *Account) GetOpenAIBaseURL() string {
|
||||
if !a.IsOpenAI() {
|
||||
return ""
|
||||
}
|
||||
if a.Type == AccountTypeApiKey {
|
||||
if a.Type == AccountTypeAPIKey {
|
||||
baseURL := a.GetCredential("base_url")
|
||||
if baseURL != "" {
|
||||
return baseURL
|
||||
|
||||
@@ -49,6 +49,8 @@ type AccountRepository interface {
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
|
||||
@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
panic("unexpected SetTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearRateLimit call")
|
||||
}
|
||||
|
||||
@@ -369,7 +369,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
}
|
||||
|
||||
// For API Key accounts with model mapping, map the model
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mapping := account.GetModelMapping()
|
||||
if len(mapping) > 0 {
|
||||
if mappedModel, exists := mapping[testModelID]; exists {
|
||||
@@ -393,7 +393,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
|
||||
var err error
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
|
||||
case AccountTypeOAuth:
|
||||
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)
|
||||
|
||||
@@ -12,16 +12,18 @@ import (
|
||||
)
|
||||
|
||||
type UsageLogRepository interface {
|
||||
Create(ctx context.Context, log *UsageLog) error
|
||||
// Create creates a usage log and returns whether it was actually inserted.
|
||||
// inserted is false when the insert was skipped due to conflict (idempotent retries).
|
||||
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
|
||||
GetByID(ctx context.Context, id int64) (*UsageLog, error)
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
|
||||
|
||||
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
|
||||
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
|
||||
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
|
||||
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
|
||||
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
|
||||
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
|
||||
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
|
||||
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
|
||||
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
|
||||
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
|
||||
|
||||
// User dashboard stats
|
||||
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
|
||||
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
|
||||
|
||||
// Aggregated stats (optimized)
|
||||
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
|
||||
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
|
||||
@@ -105,6 +107,8 @@ type UsageProgress struct {
|
||||
ResetsAt *time.Time `json:"resets_at"` // 重置时间
|
||||
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
|
||||
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
|
||||
UsedRequests int64 `json:"used_requests,omitempty"`
|
||||
LimitRequests int64 `json:"limit_requests,omitempty"`
|
||||
}
|
||||
|
||||
// AntigravityModelQuota Antigravity 单个模型的配额信息
|
||||
@@ -115,12 +119,16 @@ type AntigravityModelQuota struct {
|
||||
|
||||
// UsageInfo 账号使用量信息
|
||||
type UsageInfo struct {
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
|
||||
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
|
||||
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
|
||||
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
|
||||
GeminiSharedDaily *UsageProgress `json:"gemini_shared_daily,omitempty"` // Gemini shared pool RPD (Google One / Code Assist)
|
||||
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
|
||||
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
|
||||
GeminiSharedMinute *UsageProgress `json:"gemini_shared_minute,omitempty"` // Gemini shared pool RPM (Google One / Code Assist)
|
||||
GeminiProMinute *UsageProgress `json:"gemini_pro_minute,omitempty"` // Gemini Pro RPM
|
||||
GeminiFlashMinute *UsageProgress `json:"gemini_flash_minute,omitempty"` // Gemini Flash RPM
|
||||
|
||||
// Antigravity 多模型配额
|
||||
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
|
||||
@@ -256,17 +264,44 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
|
||||
return usage, nil
|
||||
}
|
||||
|
||||
start := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
dayStart := geminiDailyWindowStart(now)
|
||||
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
|
||||
}
|
||||
|
||||
totals := geminiAggregateUsage(stats)
|
||||
resetAt := geminiDailyResetTime(now)
|
||||
dayTotals := geminiAggregateUsage(stats)
|
||||
dailyResetAt := geminiDailyResetTime(now)
|
||||
|
||||
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
|
||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
|
||||
// Daily window (RPD)
|
||||
if quota.SharedRPD > 0 {
|
||||
totalReq := dayTotals.ProRequests + dayTotals.FlashRequests
|
||||
totalTokens := dayTotals.ProTokens + dayTotals.FlashTokens
|
||||
totalCost := dayTotals.ProCost + dayTotals.FlashCost
|
||||
usage.GeminiSharedDaily = buildGeminiUsageProgress(totalReq, quota.SharedRPD, dailyResetAt, totalTokens, totalCost, now)
|
||||
} else {
|
||||
usage.GeminiProDaily = buildGeminiUsageProgress(dayTotals.ProRequests, quota.ProRPD, dailyResetAt, dayTotals.ProTokens, dayTotals.ProCost, now)
|
||||
usage.GeminiFlashDaily = buildGeminiUsageProgress(dayTotals.FlashRequests, quota.FlashRPD, dailyResetAt, dayTotals.FlashTokens, dayTotals.FlashCost, now)
|
||||
}
|
||||
|
||||
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
|
||||
minuteStart := now.Truncate(time.Minute)
|
||||
minuteResetAt := minuteStart.Add(time.Minute)
|
||||
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
|
||||
}
|
||||
minuteTotals := geminiAggregateUsage(minuteStats)
|
||||
|
||||
if quota.SharedRPM > 0 {
|
||||
totalReq := minuteTotals.ProRequests + minuteTotals.FlashRequests
|
||||
totalTokens := minuteTotals.ProTokens + minuteTotals.FlashTokens
|
||||
totalCost := minuteTotals.ProCost + minuteTotals.FlashCost
|
||||
usage.GeminiSharedMinute = buildGeminiUsageProgress(totalReq, quota.SharedRPM, minuteResetAt, totalTokens, totalCost, now)
|
||||
} else {
|
||||
usage.GeminiProMinute = buildGeminiUsageProgress(minuteTotals.ProRequests, quota.ProRPM, minuteResetAt, minuteTotals.ProTokens, minuteTotals.ProCost, now)
|
||||
usage.GeminiFlashMinute = buildGeminiUsageProgress(minuteTotals.FlashRequests, quota.FlashRPM, minuteResetAt, minuteTotals.FlashTokens, minuteTotals.FlashCost, now)
|
||||
}
|
||||
|
||||
return usage, nil
|
||||
}
|
||||
@@ -506,6 +541,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
|
||||
}
|
||||
|
||||
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
|
||||
// limit <= 0 means "no local quota window" (unknown or unlimited).
|
||||
if limit <= 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -519,6 +555,8 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
|
||||
Utilization: utilization,
|
||||
ResetsAt: &resetCopy,
|
||||
RemainingSeconds: remainingSeconds,
|
||||
UsedRequests: used,
|
||||
LimitRequests: limit,
|
||||
WindowStats: &WindowStats{
|
||||
Requests: used,
|
||||
Tokens: tokens,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
@@ -19,7 +20,7 @@ type AdminService interface {
|
||||
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
|
||||
DeleteUser(ctx context.Context, id int64) error
|
||||
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
|
||||
|
||||
// Group management
|
||||
@@ -30,7 +31,7 @@ type AdminService interface {
|
||||
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
|
||||
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
|
||||
DeleteGroup(ctx context.Context, id int64) error
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
|
||||
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
|
||||
|
||||
// Account management
|
||||
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
|
||||
@@ -65,7 +66,7 @@ type AdminService interface {
|
||||
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
|
||||
}
|
||||
|
||||
// Input types for admin operations
|
||||
// CreateUserInput represents input for creating a new user via admin operations.
|
||||
type CreateUserInput struct {
|
||||
Email string
|
||||
Password string
|
||||
@@ -122,18 +123,22 @@ type CreateAccountInput struct {
|
||||
Concurrency int
|
||||
Priority int
|
||||
GroupIDs []int64
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
type UpdateAccountInput struct {
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
Name string
|
||||
Type string // Account type: oauth, setup-token, apikey
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
ProxyID *int64
|
||||
Concurrency *int // 使用指针区分"未提供"和"设置为0"
|
||||
Priority *int // 使用指针区分"未提供"和"设置为0"
|
||||
Status string
|
||||
GroupIDs *[]int64
|
||||
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
|
||||
}
|
||||
|
||||
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
|
||||
@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct {
|
||||
GroupIDs *[]int64
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
|
||||
// This should only be set when the caller has explicitly confirmed the risk.
|
||||
SkipMixedChannelCheck bool
|
||||
}
|
||||
|
||||
// BulkUpdateAccountResult captures the result for a single account update.
|
||||
@@ -220,7 +228,7 @@ type adminServiceImpl struct {
|
||||
groupRepo GroupRepository
|
||||
accountRepo AccountRepository
|
||||
proxyRepo ProxyRepository
|
||||
apiKeyRepo ApiKeyRepository
|
||||
apiKeyRepo APIKeyRepository
|
||||
redeemCodeRepo RedeemCodeRepository
|
||||
billingCacheService *BillingCacheService
|
||||
proxyProber ProxyExitInfoProber
|
||||
@@ -232,7 +240,7 @@ func NewAdminService(
|
||||
groupRepo GroupRepository,
|
||||
accountRepo AccountRepository,
|
||||
proxyRepo ProxyRepository,
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
apiKeyRepo APIKeyRepository,
|
||||
redeemCodeRepo RedeemCodeRepository,
|
||||
billingCacheService *BillingCacheService,
|
||||
proxyProber ProxyExitInfoProber,
|
||||
@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
|
||||
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
|
||||
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
|
||||
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
|
||||
if err != nil {
|
||||
@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([
|
||||
}
|
||||
|
||||
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
|
||||
// 绑定分组
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
if g.Name == defaultGroupName {
|
||||
groupIDs = []int64{g.ID}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
|
||||
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
Name: input.Name,
|
||||
Platform: input.Platform,
|
||||
@@ -637,22 +668,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
groupIDs := input.GroupIDs
|
||||
// 如果没有指定分组,自动绑定对应平台的默认分组
|
||||
if len(groupIDs) == 0 {
|
||||
defaultGroupName := input.Platform + "-default"
|
||||
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
if g.Name == defaultGroupName {
|
||||
groupIDs = []int64{g.ID}
|
||||
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
|
||||
return nil, err
|
||||
@@ -703,6 +718,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if !input.SkipMixedChannelCheck {
|
||||
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
@@ -731,6 +753,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Preload account platforms for mixed channel risk checks if group bindings are requested.
|
||||
platformByID := map[int64]string{}
|
||||
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
|
||||
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, account := range accounts {
|
||||
if account != nil {
|
||||
platformByID[account.ID] = account.Platform
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Prepare bulk updates for columns and JSONB fields.
|
||||
repoUpdates := AccountBulkUpdate{
|
||||
Credentials: input.Credentials,
|
||||
@@ -762,6 +798,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
|
||||
entry := BulkUpdateAccountResult{AccountID: accountID}
|
||||
|
||||
if input.GroupIDs != nil {
|
||||
// 检查混合渠道风险(除非用户已确认)
|
||||
if !input.SkipMixedChannelCheck {
|
||||
platform := platformByID[accountID]
|
||||
if platform == "" {
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
platform = account.Platform
|
||||
}
|
||||
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
result.Failed++
|
||||
result.Results = append(result.Results, entry)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
|
||||
entry.Success = false
|
||||
entry.Error = err.Error()
|
||||
@@ -1006,3 +1065,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
|
||||
Country: exitInfo.Country,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkMixedChannelRisk 检查分组中是否存在混合渠道(Antigravity + Anthropic)
|
||||
// 如果存在混合,返回错误提示用户确认
|
||||
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
|
||||
// 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段)
|
||||
currentPlatform := getAccountPlatform(currentAccountPlatform)
|
||||
if currentPlatform == "" {
|
||||
// 不是 Antigravity 或 Anthropic,无需检查
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查每个分组中的其他账号
|
||||
for _, groupID := range groupIDs {
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get accounts in group %d: %w", groupID, err)
|
||||
}
|
||||
|
||||
// 检查是否存在不同渠道的账号
|
||||
for _, account := range accounts {
|
||||
if currentAccountID > 0 && account.ID == currentAccountID {
|
||||
continue // 跳过当前账号
|
||||
}
|
||||
|
||||
otherPlatform := getAccountPlatform(account.Platform)
|
||||
if otherPlatform == "" {
|
||||
continue // 不是 Antigravity 或 Anthropic,跳过
|
||||
}
|
||||
|
||||
// 检测混合渠道
|
||||
if currentPlatform != otherPlatform {
|
||||
group, _ := s.groupRepo.GetByID(ctx, groupID)
|
||||
groupName := fmt.Sprintf("Group %d", groupID)
|
||||
if group != nil {
|
||||
groupName = group.Name
|
||||
}
|
||||
|
||||
return &MixedChannelError{
|
||||
GroupID: groupID,
|
||||
GroupName: groupName,
|
||||
CurrentPlatform: currentPlatform,
|
||||
OtherPlatform: otherPlatform,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
|
||||
func getAccountPlatform(accountPlatform string) string {
|
||||
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
|
||||
case PlatformAntigravity:
|
||||
return "Antigravity"
|
||||
case PlatformAnthropic, "claude":
|
||||
return "Anthropic"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// MixedChannelError 混合渠道错误
|
||||
type MixedChannelError struct {
|
||||
GroupID int64
|
||||
GroupName string
|
||||
CurrentPlatform string
|
||||
OtherPlatform string
|
||||
}
|
||||
|
||||
func (e *MixedChannelError) Error() string {
|
||||
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
|
||||
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
|
||||
}
|
||||
|
||||
@@ -81,6 +81,7 @@ type AntigravityGatewayService struct {
|
||||
tokenProvider *AntigravityTokenProvider
|
||||
rateLimitService *RateLimitService
|
||||
httpUpstream HTTPUpstream
|
||||
settingService *SettingService
|
||||
}
|
||||
|
||||
func NewAntigravityGatewayService(
|
||||
@@ -89,12 +90,14 @@ func NewAntigravityGatewayService(
|
||||
tokenProvider *AntigravityTokenProvider,
|
||||
rateLimitService *RateLimitService,
|
||||
httpUpstream HTTPUpstream,
|
||||
settingService *SettingService,
|
||||
) *AntigravityGatewayService {
|
||||
return &AntigravityGatewayService{
|
||||
accountRepo: accountRepo,
|
||||
tokenProvider: tokenProvider,
|
||||
rateLimitService: rateLimitService,
|
||||
httpUpstream: httpUpstream,
|
||||
settingService: settingService,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -324,6 +327,22 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// isModelNotFoundError 检测是否为模型不存在的 404 错误
|
||||
func isModelNotFoundError(statusCode int, body []byte) bool {
|
||||
if statusCode != 404 {
|
||||
return false
|
||||
}
|
||||
|
||||
bodyStr := strings.ToLower(string(body))
|
||||
keywords := []string{"model not found", "unknown model", "not found"}
|
||||
for _, keyword := range keywords {
|
||||
if strings.Contains(bodyStr, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return true // 404 without specific message also treated as model not found
|
||||
}
|
||||
|
||||
// Forward 转发 Claude 协议请求(Claude → Gemini 转换)
|
||||
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -417,16 +436,56 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
// 优先检测 thinking block 的 signature 相关错误(400)并重试一次:
|
||||
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
|
||||
// 当历史消息携带的 signature 不合法时会直接 400;去除 thinking 后可继续完成请求。
|
||||
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
|
||||
retryClaudeReq := claudeReq
|
||||
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
|
||||
|
||||
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
|
||||
if stripErr == nil && stripped {
|
||||
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
|
||||
|
||||
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
|
||||
if txErr == nil {
|
||||
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// Retry success: continue normal success flow with the new response.
|
||||
if retryResp.StatusCode < 400 {
|
||||
_ = resp.Body.Close()
|
||||
resp = retryResp
|
||||
respBody = nil
|
||||
} else {
|
||||
// Retry still errored: replace error context with retry response.
|
||||
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
|
||||
_ = retryResp.Body.Close()
|
||||
respBody = retryBody
|
||||
resp = retryResp
|
||||
}
|
||||
} else {
|
||||
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
// 处理错误响应(重试后仍失败或不触发重试)
|
||||
if resp.StatusCode >= 400 {
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
|
||||
}
|
||||
}
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
@@ -461,6 +520,122 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func isSignatureRelatedError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
// Fallback: best-effort scan of the raw payload.
|
||||
msg = strings.ToLower(string(respBody))
|
||||
}
|
||||
|
||||
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
|
||||
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
|
||||
}
|
||||
|
||||
func extractAntigravityErrorMessage(body []byte) string {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Google-style: {"error": {"message": "..."}}
|
||||
if errObj, ok := payload["error"].(map[string]any); ok {
|
||||
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
return msg
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: top-level message
|
||||
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
|
||||
return msg
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
|
||||
// This preserves the thinking content while avoiding signature validation errors.
|
||||
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
|
||||
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
|
||||
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
|
||||
if req == nil {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
changed := false
|
||||
if req.Thinking != nil {
|
||||
req.Thinking = nil
|
||||
changed = true
|
||||
}
|
||||
|
||||
for i := range req.Messages {
|
||||
raw := req.Messages[i].Content
|
||||
if len(raw) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// If content is a string, nothing to strip.
|
||||
var str string
|
||||
if json.Unmarshal(raw, &str) == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise treat as an array of blocks and convert thinking blocks to text.
|
||||
var blocks []map[string]any
|
||||
if err := json.Unmarshal(raw, &blocks); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered := make([]map[string]any, 0, len(blocks))
|
||||
modifiedAny := false
|
||||
for _, block := range blocks {
|
||||
t, _ := block["type"].(string)
|
||||
switch t {
|
||||
case "thinking":
|
||||
// Convert thinking to text, skip if empty
|
||||
thinkingText, _ := block["thinking"].(string)
|
||||
if thinkingText != "" {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
}
|
||||
modifiedAny = true
|
||||
case "redacted_thinking":
|
||||
// Remove redacted_thinking (cannot convert encrypted content)
|
||||
modifiedAny = true
|
||||
case "":
|
||||
// Handle untyped block with "thinking" field
|
||||
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
|
||||
if thinkingText != "" {
|
||||
filtered = append(filtered, map[string]any{
|
||||
"type": "text",
|
||||
"text": thinkingText,
|
||||
})
|
||||
}
|
||||
modifiedAny = true
|
||||
} else {
|
||||
filtered = append(filtered, block)
|
||||
}
|
||||
default:
|
||||
filtered = append(filtered, block)
|
||||
}
|
||||
}
|
||||
|
||||
if !modifiedAny {
|
||||
continue
|
||||
}
|
||||
|
||||
newRaw, err := json.Marshal(filtered)
|
||||
if err != nil {
|
||||
return changed, err
|
||||
}
|
||||
req.Messages[i].Content = newRaw
|
||||
changed = true
|
||||
}
|
||||
|
||||
return changed, nil
|
||||
}
|
||||
|
||||
// ForwardGemini 转发 Gemini 协议请求
|
||||
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
|
||||
startTime := time.Now()
|
||||
@@ -574,14 +749,40 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
|
||||
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
|
||||
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
|
||||
isModelNotFoundError(resp.StatusCode, respBody) {
|
||||
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
|
||||
if fallbackModel != "" && fallbackModel != mappedModel {
|
||||
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
|
||||
|
||||
// 关闭原始响应,释放连接(respBody 已读取到内存)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
|
||||
if err == nil {
|
||||
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
|
||||
if err == nil {
|
||||
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
|
||||
if err == nil && fallbackResp.StatusCode < 400 {
|
||||
resp = fallbackResp
|
||||
} else if fallbackResp != nil {
|
||||
_ = fallbackResp.Body.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// fallback 成功:继续按正常响应处理
|
||||
if resp.StatusCode < 400 {
|
||||
goto handleSuccess
|
||||
}
|
||||
|
||||
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
if s.shouldFailoverUpstreamError(resp.StatusCode) {
|
||||
@@ -589,6 +790,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
}
|
||||
|
||||
// 解包并返回错误
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
if contentType == "" {
|
||||
@@ -598,6 +803,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
|
||||
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
handleSuccess:
|
||||
requestID := resp.Header.Get("x-request-id")
|
||||
if requestID != "" {
|
||||
c.Header("x-request-id", requestID)
|
||||
}
|
||||
|
||||
var usage *ClaudeUsage
|
||||
var firstTokenMs *int
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ package service
|
||||
|
||||
import "time"
|
||||
|
||||
type ApiKey struct {
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
@@ -15,6 +15,6 @@ type ApiKey struct {
|
||||
Group *Group
|
||||
}
|
||||
|
||||
func (k *ApiKey) IsActive() bool {
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
@@ -14,39 +14,39 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
)
|
||||
|
||||
type ApiKeyRepository interface {
|
||||
Create(ctx context.Context, key *ApiKey) error
|
||||
GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
// GetOwnerID 仅获取 API Key 的所有者 ID,用于删除前的轻量级权限验证
|
||||
GetOwnerID(ctx context.Context, id int64) (int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*ApiKey, error)
|
||||
Update(ctx context.Context, key *ApiKey) error
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
|
||||
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
}
|
||||
|
||||
// ApiKeyCache defines cache operations for API key service
|
||||
type ApiKeyCache interface {
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
}
|
||||
|
||||
// CreateApiKeyRequest 创建API Key请求
|
||||
type CreateApiKeyRequest struct {
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
}
|
||||
|
||||
// UpdateApiKeyRequest 更新API Key请求
|
||||
type UpdateApiKeyRequest struct {
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
}
|
||||
|
||||
// ApiKeyService API Key服务
|
||||
type ApiKeyService struct {
|
||||
apiKeyRepo ApiKeyRepository
|
||||
// APIKeyService API Key服务
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
cache ApiKeyCache
|
||||
cache APIKeyCache
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewApiKeyService 创建API Key服务实例
|
||||
func NewApiKeyService(
|
||||
apiKeyRepo ApiKeyRepository,
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
func NewAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
cache ApiKeyCache,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *ApiKeyService {
|
||||
return &ApiKeyService{
|
||||
) *APIKeyService {
|
||||
return &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
@@ -99,7 +99,7 @@ func NewApiKeyService(
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.ApiKeyPrefix
|
||||
prefix := s.cfg.Default.APIKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrApiKeyTooShort
|
||||
return ErrAPIKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrApiKeyInvalidChars
|
||||
return ErrAPIKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrApiKeyRateLimited
|
||||
return ErrAPIKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
|
||||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
|
||||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
|
||||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementApiKeyErrorCount(ctx, userID)
|
||||
return nil, ErrApiKeyExists
|
||||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||||
return nil, ErrAPIKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &ApiKey{
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
// 尝试从Redis缓存获取
|
||||
cacheKey := fmt.Sprintf("apikey:%s", key)
|
||||
|
||||
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
|
||||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
|
||||
|
||||
// Delete 删除API Key
|
||||
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
|
||||
// 避免加载完整 ApiKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 避免加载完整 APIKey 对象及其关联数据(User、Group),提升删除操作的性能
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
|
||||
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
|
||||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
|
||||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
//go:build unit
|
||||
|
||||
// API Key 服务删除方法的单元测试
|
||||
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
|
||||
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
|
||||
// 包括权限验证、缓存清理和错误处理
|
||||
|
||||
package service
|
||||
@@ -16,12 +16,12 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
|
||||
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound)
|
||||
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound)
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
|
||||
return s.ownerID, s.ownerErr
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
|
||||
panic("unexpected SearchApiKeys call")
|
||||
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 1}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerID: 7}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
|
||||
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
|
||||
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 99, 1)
|
||||
require.ErrorIs(t, err, ErrApiKeyNotFound)
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
}
|
||||
@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
|
||||
require.Error(t, err)
|
||||
|
||||
@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
|
||||
@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformAnthropic
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformOpenAI
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
account := &Account{
|
||||
Name: defaultName(src.Name, src.ID),
|
||||
Platform: PlatformGemini,
|
||||
Type: AccountTypeApiKey,
|
||||
Type: AccountTypeAPIKey,
|
||||
Credentials: credentials,
|
||||
Extra: extra,
|
||||
ProxyID: proxyID,
|
||||
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
|
||||
existing.Extra = mergeMap(existing.Extra, extra)
|
||||
existing.Name = defaultName(src.Name, src.ID)
|
||||
existing.Platform = PlatformGemini
|
||||
existing.Type = AccountTypeApiKey
|
||||
existing.Type = AccountTypeAPIKey
|
||||
existing.Credentials = mergeMap(existing.Credentials, credentials)
|
||||
if proxyID != nil {
|
||||
existing.ProxyID = proxyID
|
||||
|
||||
@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const (
|
||||
const (
|
||||
AccountTypeOAuth = "oauth" // OAuth类型账号(full scope: profile + inference)
|
||||
AccountTypeSetupToken = "setup-token" // Setup Token类型账号(inference only scope)
|
||||
AccountTypeApiKey = "apikey" // API Key类型账号
|
||||
AccountTypeAPIKey = "apikey" // API Key类型账号
|
||||
)
|
||||
|
||||
// Redeem type constants
|
||||
@@ -64,13 +64,13 @@ const (
|
||||
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
|
||||
|
||||
// 邮件服务设置
|
||||
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySmtpPort = "smtp_port" // SMTP端口
|
||||
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySmtpPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySmtpFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
|
||||
SettingKeySMTPPort = "smtp_port" // SMTP端口
|
||||
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
|
||||
SettingKeySMTPPassword = "smtp_password" // SMTP密码(加密存储)
|
||||
SettingKeySMTPFrom = "smtp_from" // 发件人地址
|
||||
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
|
||||
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
|
||||
|
||||
// Cloudflare Turnstile 设置
|
||||
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
|
||||
@@ -81,20 +81,27 @@ const (
|
||||
SettingKeySiteName = "site_name" // 网站名称
|
||||
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
|
||||
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
|
||||
SettingKeyApiBaseUrl = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyAPIBaseURL = "api_base_url" // API端点地址(用于客户端配置和导入)
|
||||
SettingKeyContactInfo = "contact_info" // 客服联系方式
|
||||
SettingKeyDocUrl = "doc_url" // 文档链接
|
||||
SettingKeyDocURL = "doc_url" // 文档链接
|
||||
|
||||
// 默认配置
|
||||
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
|
||||
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
|
||||
|
||||
// 管理员 API Key
|
||||
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key(用于外部系统集成)
|
||||
|
||||
// Gemini 配额策略(JSON)
|
||||
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
|
||||
|
||||
// Model fallback settings
|
||||
SettingKeyEnableModelFallback = "enable_model_fallback"
|
||||
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
|
||||
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
|
||||
SettingKeyFallbackModelGemini = "fallback_model_gemini"
|
||||
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
|
||||
)
|
||||
|
||||
// Admin API Key prefix (distinct from user "sk-" keys)
|
||||
const AdminApiKeyPrefix = "admin-"
|
||||
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
|
||||
const AdminAPIKeyPrefix = "admin-"
|
||||
|
||||
@@ -40,8 +40,8 @@ const (
|
||||
maxVerifyCodeAttempts = 5
|
||||
)
|
||||
|
||||
// SmtpConfig SMTP配置
|
||||
type SmtpConfig struct {
|
||||
// SMTPConfig SMTP配置
|
||||
type SMTPConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
|
||||
}
|
||||
}
|
||||
|
||||
// GetSmtpConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
// GetSMTPConfig 从数据库获取SMTP配置
|
||||
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
|
||||
keys := []string{
|
||||
SettingKeySmtpHost,
|
||||
SettingKeySmtpPort,
|
||||
SettingKeySmtpUsername,
|
||||
SettingKeySmtpPassword,
|
||||
SettingKeySmtpFrom,
|
||||
SettingKeySmtpFromName,
|
||||
SettingKeySmtpUseTLS,
|
||||
SettingKeySMTPHost,
|
||||
SettingKeySMTPPort,
|
||||
SettingKeySMTPUsername,
|
||||
SettingKeySMTPPassword,
|
||||
SettingKeySMTPFrom,
|
||||
SettingKeySMTPFromName,
|
||||
SettingKeySMTPUseTLS,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
|
||||
return nil, fmt.Errorf("get smtp settings: %w", err)
|
||||
}
|
||||
|
||||
host := settings[SettingKeySmtpHost]
|
||||
host := settings[SettingKeySMTPHost]
|
||||
if host == "" {
|
||||
return nil, ErrEmailNotConfigured
|
||||
}
|
||||
|
||||
port := 587 // 默认端口
|
||||
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
|
||||
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
|
||||
if p, err := strconv.Atoi(portStr); err == nil {
|
||||
port = p
|
||||
}
|
||||
}
|
||||
|
||||
useTLS := settings[SettingKeySmtpUseTLS] == "true"
|
||||
useTLS := settings[SettingKeySMTPUseTLS] == "true"
|
||||
|
||||
return &SmtpConfig{
|
||||
return &SMTPConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: settings[SettingKeySmtpUsername],
|
||||
Password: settings[SettingKeySmtpPassword],
|
||||
From: settings[SettingKeySmtpFrom],
|
||||
FromName: settings[SettingKeySmtpFromName],
|
||||
Username: settings[SettingKeySMTPUsername],
|
||||
Password: settings[SettingKeySMTPPassword],
|
||||
From: settings[SettingKeySMTPFrom],
|
||||
FromName: settings[SettingKeySMTPFromName],
|
||||
UseTLS: useTLS,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SendEmail 发送邮件(使用数据库中保存的配置)
|
||||
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
|
||||
config, err := s.GetSmtpConfig(ctx)
|
||||
config, err := s.GetSMTPConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
|
||||
}
|
||||
|
||||
// SendEmailWithConfig 使用指定配置发送邮件
|
||||
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
|
||||
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
|
||||
from := config.From
|
||||
if config.FromName != "" {
|
||||
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
|
||||
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
|
||||
`, siteName, code)
|
||||
}
|
||||
|
||||
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
|
||||
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
|
||||
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
|
||||
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
|
||||
|
||||
if config.UseTLS {
|
||||
|
||||
@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
|
||||
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
|
||||
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
|
||||
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
|
||||
repo := &mockAccountRepoForPlatform{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
|
||||
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
|
||||
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
)
|
||||
@@ -70,3 +71,224 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
|
||||
|
||||
return parsed, nil
|
||||
}
|
||||
|
||||
// FilterThinkingBlocks removes thinking blocks from request body
|
||||
// Returns filtered body or original body if filtering fails (fail-safe)
|
||||
// This prevents 400 errors from invalid thinking block signatures
|
||||
//
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
|
||||
func FilterThinkingBlocks(body []byte) []byte {
|
||||
return filterThinkingBlocksInternal(body, false)
|
||||
}
|
||||
|
||||
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
|
||||
// This is used when upstream returns signature-related 400 errors.
|
||||
//
|
||||
// Key insight:
|
||||
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
|
||||
// - Only HISTORICAL assistant messages have thinking blocks with signatures
|
||||
// - These signatures may be invalid when switching accounts/platforms
|
||||
// - New responses will generate fresh thinking blocks without signature issues
|
||||
//
|
||||
// Strategy:
|
||||
// - Keep thinking.type = "enabled" (preserve user intent)
|
||||
// - Remove thinking/redacted_thinking blocks from historical assistant messages
|
||||
// - Ensure no message has empty content after filtering
|
||||
func FilterThinkingBlocksForRetry(body []byte) []byte {
|
||||
// Fast path: check for presence of thinking-related keys in messages
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
|
||||
// The issue is with historical message signatures, not the thinking mode itself
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
modified := false
|
||||
newMessages := make([]any, 0, len(messages))
|
||||
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
// String content or other format - keep as is
|
||||
newMessages = append(newMessages, msg)
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
modifiedThisMsg := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
// Remove thinking/redacted_thinking blocks from historical messages
|
||||
// These have signatures that may be invalid across different accounts
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
modifiedThisMsg = true
|
||||
continue
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if modifiedThisMsg {
|
||||
modified = true
|
||||
// Handle empty content after filtering
|
||||
if len(newContent) == 0 {
|
||||
// For assistant messages, skip entirely (remove from conversation)
|
||||
// For user messages, add placeholder to avoid empty content error
|
||||
if role == "user" {
|
||||
newContent = append(newContent, map[string]any{
|
||||
"type": "text",
|
||||
"text": "(content removed)",
|
||||
})
|
||||
msgMap["content"] = newContent
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
// Skip assistant messages with empty content (don't append)
|
||||
continue
|
||||
}
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
newMessages = append(newMessages, msgMap)
|
||||
}
|
||||
|
||||
if modified {
|
||||
req["messages"] = newMessages
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
// filterThinkingBlocksInternal removes invalid thinking blocks from request
|
||||
// Strategy:
|
||||
// - When thinking.type != "enabled": Remove all thinking blocks
|
||||
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
|
||||
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
|
||||
// Fast path: if body doesn't contain "thinking", skip parsing
|
||||
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking":`)) &&
|
||||
!bytes.Contains(body, []byte(`"thinking" :`)) {
|
||||
return body
|
||||
}
|
||||
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Check if thinking is enabled
|
||||
thinkingEnabled := false
|
||||
if thinking, ok := req["thinking"].(map[string]any); ok {
|
||||
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
|
||||
thinkingEnabled = true
|
||||
}
|
||||
}
|
||||
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
filtered := false
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
role, _ := msgMap["role"].(string)
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
newContent := make([]any, 0, len(content))
|
||||
filteredThisMessage := false
|
||||
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
|
||||
if blockType == "thinking" || blockType == "redacted_thinking" {
|
||||
// When thinking is enabled and this is an assistant message,
|
||||
// only keep thinking blocks with valid signatures
|
||||
if thinkingEnabled && role == "assistant" {
|
||||
signature, _ := blockMap["signature"].(string)
|
||||
if signature != "" && signature != "skip_thought_signature_validator" {
|
||||
newContent = append(newContent, block)
|
||||
continue
|
||||
}
|
||||
}
|
||||
filtered = true
|
||||
filteredThisMessage = true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle blocks without type discriminator but with "thinking" key
|
||||
if blockType == "" {
|
||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
filtered = true
|
||||
filteredThisMessage = true
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
newContent = append(newContent, block)
|
||||
}
|
||||
|
||||
if filteredThisMessage {
|
||||
msgMap["content"] = newContent
|
||||
}
|
||||
}
|
||||
|
||||
if !filtered {
|
||||
return body
|
||||
}
|
||||
|
||||
newBody, err := json.Marshal(req)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -38,3 +39,115 @@ func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
|
||||
_, err := ParseGatewayRequest(body)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFilterThinkingBlocks(t *testing.T) {
|
||||
containsThinkingBlock := func(body []byte) bool {
|
||||
var req map[string]any
|
||||
if err := json.Unmarshal(body, &req); err != nil {
|
||||
return false
|
||||
}
|
||||
messages, ok := req["messages"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
for _, msg := range messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
content, ok := msgMap["content"].([]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for _, block := range content {
|
||||
blockMap, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
blockType, _ := blockMap["type"].(string)
|
||||
if blockType == "thinking" {
|
||||
return true
|
||||
}
|
||||
if blockType == "" {
|
||||
if _, hasThinking := blockMap["thinking"]; hasThinking {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
shouldFilter bool
|
||||
expectError bool
|
||||
}{
|
||||
{
|
||||
name: "filters thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "handles no thinking blocks",
|
||||
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles invalid JSON gracefully",
|
||||
input: `{invalid json`,
|
||||
shouldFilter: false,
|
||||
expectError: true,
|
||||
},
|
||||
{
|
||||
name: "handles multiple messages with thinking blocks",
|
||||
input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "filters thinking blocks without type discriminator",
|
||||
input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`,
|
||||
shouldFilter: true,
|
||||
},
|
||||
{
|
||||
name: "does not filter tool_use input fields named thinking",
|
||||
input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles empty messages array",
|
||||
input: `{"messages":[]}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
{
|
||||
name: "handles missing messages field",
|
||||
input: `{"model":"claude-3"}`,
|
||||
shouldFilter: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := FilterThinkingBlocks([]byte(tt.input))
|
||||
|
||||
if tt.expectError {
|
||||
// For invalid JSON, should return original
|
||||
require.Equal(t, tt.input, string(result))
|
||||
return
|
||||
}
|
||||
|
||||
if tt.shouldFilter {
|
||||
require.False(t, containsThinkingBlock(result))
|
||||
} else {
|
||||
// Ensure we don't rewrite JSON when no filtering is needed.
|
||||
require.Equal(t, tt.input, string(result))
|
||||
}
|
||||
|
||||
// Verify valid JSON returned (unless input was invalid)
|
||||
var parsed map[string]any
|
||||
err := json.Unmarshal(result, &parsed)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -543,7 +543,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
|
||||
for _, item := range available {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
@@ -579,7 +579,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
|
||||
for _, acc := range ordered {
|
||||
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
|
||||
if err == nil && result.Acquired {
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
|
||||
}
|
||||
return &AccountSelectionResult{
|
||||
@@ -710,7 +710,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
|
||||
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
|
||||
preferOAuth := platform == PlatformGemini
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
@@ -783,7 +783,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
@@ -799,7 +799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
preferOAuth := nativePlatform == PlatformGemini
|
||||
|
||||
// 1. 查询粘性会话
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
|
||||
if err == nil && accountID > 0 {
|
||||
if _, excluded := excludedIDs[accountID]; !excluded {
|
||||
@@ -875,7 +875,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
|
||||
}
|
||||
|
||||
// 4. 建立粘性绑定
|
||||
if sessionHash != "" {
|
||||
if sessionHash != "" && s.cache != nil {
|
||||
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
|
||||
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
|
||||
}
|
||||
@@ -907,7 +907,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
|
||||
case AccountTypeOAuth, AccountTypeSetupToken:
|
||||
// Both oauth and setup-token use OAuth token flow
|
||||
return s.getOAuthToken(ctx, account)
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -1045,7 +1045,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
|
||||
// 应用模型映射(仅对apikey类型账号)
|
||||
originalModel := reqModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
// 替换请求体中的模型名
|
||||
@@ -1082,8 +1082,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
return nil, fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否需要重试
|
||||
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
// 优先检测thinking block签名错误(400)并重试一次
|
||||
if resp.StatusCode == 400 {
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr == nil {
|
||||
_ = resp.Body.Close()
|
||||
|
||||
if s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
// 过滤thinking blocks并重试(使用更激进的过滤)
|
||||
filteredBody := FilterThinkingBlocksForRetry(body)
|
||||
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
// 使用重试后的响应,继续后续处理
|
||||
if retryResp.StatusCode < 400 {
|
||||
log.Printf("Account %d: signature error retry succeeded", account.ID)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
|
||||
}
|
||||
resp = retryResp
|
||||
break
|
||||
}
|
||||
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
|
||||
} else {
|
||||
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
|
||||
}
|
||||
// 重试失败,恢复原始响应体继续处理
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
break
|
||||
}
|
||||
// 不是thinking签名错误,恢复响应体
|
||||
resp.Body = io.NopCloser(bytes.NewReader(respBody))
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否需要通用重试(排除400,因为400已经在上面特殊处理过了)
|
||||
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
|
||||
if attempt < maxRetries {
|
||||
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
|
||||
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
|
||||
@@ -1096,6 +1133,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
}
|
||||
|
||||
// 不需要重试(成功或不可重试的错误),跳出循环
|
||||
// DEBUG: 输出响应 headers(用于检测 rate limit 信息)
|
||||
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
|
||||
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
|
||||
for k, v := range resp.Header {
|
||||
log.Printf("[DEBUG] %s: %v", k, v)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
@@ -1119,7 +1163,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
if resp.StatusCode >= 400 {
|
||||
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
|
||||
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
if readErr != nil {
|
||||
// ReadAll failed, fall back to normal error handling without consuming the stream
|
||||
return s.handleErrorResponse(ctx, resp, c, account)
|
||||
@@ -1179,7 +1223,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
|
||||
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标URL
|
||||
targetURL := claudeAPIURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages"
|
||||
}
|
||||
@@ -1243,10 +1287,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
|
||||
// 处理anthropic-beta header(OAuth账号需要特殊处理)
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
@@ -1313,12 +1357,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func defaultApiKeyBetaHeader(body []byte) string {
|
||||
func defaultAPIKeyBetaHeader(body []byte) string {
|
||||
modelID := gjson.GetBytes(body, "model").String()
|
||||
if strings.Contains(strings.ToLower(modelID), "haiku") {
|
||||
return claude.ApiKeyHaikuBetaHeader
|
||||
return claude.APIKeyHaikuBetaHeader
|
||||
}
|
||||
return claude.ApiKeyBetaHeader
|
||||
return claude.APIKeyBetaHeader
|
||||
}
|
||||
|
||||
func truncateForLog(b []byte, maxBytes int) string {
|
||||
@@ -1335,6 +1379,41 @@ func truncateForLog(b []byte, maxBytes int) string {
|
||||
return s
|
||||
}
|
||||
|
||||
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
|
||||
// 这类错误可以通过过滤thinking blocks并重试来解决
|
||||
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
|
||||
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
|
||||
if msg == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Log for debugging
|
||||
log.Printf("[SignatureCheck] Checking error message: %s", msg)
|
||||
|
||||
// 检测signature相关的错误(更宽松的匹配)
|
||||
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
|
||||
if strings.Contains(msg, "signature") {
|
||||
log.Printf("[SignatureCheck] Detected signature error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测 thinking block 顺序/类型错误
|
||||
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
|
||||
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
|
||||
log.Printf("[SignatureCheck] Detected thinking block type error")
|
||||
return true
|
||||
}
|
||||
|
||||
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
|
||||
// 例如: "all messages must have non-empty content"
|
||||
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
|
||||
log.Printf("[SignatureCheck] Detected empty content error")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
|
||||
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
|
||||
// 默认保守:无法识别则不切换。
|
||||
@@ -1383,7 +1462,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// 处理上游错误,标记账号状态
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
|
||||
var errType, errMsg string
|
||||
@@ -1695,7 +1780,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
|
||||
// RecordUsageInput 记录使用量的输入参数
|
||||
type RecordUsageInput struct {
|
||||
Result *ForwardResult
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription // 可选:订阅信息
|
||||
@@ -1704,7 +1789,7 @@ type RecordUsageInput struct {
|
||||
// RecordUsage 记录使用量并扣费(或更新订阅用量)
|
||||
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -1741,7 +1826,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
@@ -1771,7 +1856,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if err != nil {
|
||||
log.Printf("Create usage log failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -1781,10 +1867,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// 根据计费类型执行扣费
|
||||
if isSubscriptionBilling {
|
||||
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
|
||||
if cost.TotalCost > 0 {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
|
||||
log.Printf("Increment subscription usage failed: %v", err)
|
||||
}
|
||||
@@ -1793,7 +1881,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
|
||||
}
|
||||
} else {
|
||||
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
|
||||
if cost.ActualCost > 0 {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
|
||||
log.Printf("Deduct balance failed: %v", err)
|
||||
}
|
||||
@@ -1826,7 +1914,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
}
|
||||
|
||||
// 应用模型映射(仅对 apikey 类型账号)
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
if reqModel != "" {
|
||||
mappedModel := account.GetMappedModel(reqModel)
|
||||
if mappedModel != reqModel {
|
||||
@@ -1863,17 +1951,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
|
||||
return fmt.Errorf("upstream request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
// 读取响应体
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
|
||||
// 检测 thinking block 签名错误(400)并重试一次(过滤 thinking blocks)
|
||||
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
|
||||
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
|
||||
|
||||
filteredBody := FilterThinkingBlocks(body)
|
||||
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
|
||||
if buildErr == nil {
|
||||
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
|
||||
if retryErr == nil {
|
||||
resp = retryResp
|
||||
respBody, err = io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理错误响应
|
||||
if resp.StatusCode >= 400 {
|
||||
// 标记账号状态(429/529等)
|
||||
@@ -1912,7 +2018,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
|
||||
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
|
||||
// 确定目标 URL
|
||||
targetURL := claudeAPICountTokensURL
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
baseURL := account.GetBaseURL()
|
||||
targetURL = baseURL + "/v1/messages/count_tokens"
|
||||
}
|
||||
@@ -1971,10 +2077,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
|
||||
// OAuth 账号:处理 anthropic-beta header
|
||||
if tokenType == "oauth" {
|
||||
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
|
||||
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
|
||||
// API-key:与 messages 同步的按需 beta 注入(默认关闭)
|
||||
if requestNeedsBetaFeatures(body) {
|
||||
if beta := defaultApiKeyBetaHeader(body); beta != "" {
|
||||
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
|
||||
req.Header.Set("anthropic-beta", beta)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
|
||||
return 999
|
||||
}
|
||||
switch a.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
|
||||
return 0
|
||||
}
|
||||
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
originalModel := req.Model
|
||||
mappedModel := req.Model
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(req.Model)
|
||||
}
|
||||
|
||||
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -539,7 +539,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
if tempMatched {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -614,7 +621,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}
|
||||
|
||||
mappedModel := originalModel
|
||||
if account.Type == AccountTypeApiKey {
|
||||
if account.Type == AccountTypeAPIKey {
|
||||
mappedModel = account.GetMappedModel(originalModel)
|
||||
}
|
||||
|
||||
@@ -636,7 +643,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
var buildReq func(ctx context.Context) (*http.Request, string, error)
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
buildReq = func(ctx context.Context) (*http.Request, string, error) {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if strings.TrimSpace(apiKey) == "" {
|
||||
@@ -825,6 +832,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
|
||||
tempMatched := false
|
||||
if s.rateLimitService != nil {
|
||||
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
|
||||
}
|
||||
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
|
||||
|
||||
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
|
||||
@@ -842,6 +853,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
|
||||
}, nil
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
@@ -1614,6 +1628,15 @@ type UpstreamHTTPResult struct {
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
|
||||
// Log response headers for debugging
|
||||
log.Printf("[GeminiAPI] ========== Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiAPI] ========================================")
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1644,6 +1667,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
|
||||
}
|
||||
|
||||
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
|
||||
// Log response headers for debugging
|
||||
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
|
||||
for key, values := range resp.Header {
|
||||
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
|
||||
log.Printf("[GeminiAPI] %s: %v", key, values)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiAPI] ====================================================")
|
||||
|
||||
c.Status(resp.StatusCode)
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
@@ -1758,7 +1790,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
|
||||
}
|
||||
|
||||
switch account.Type {
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
|
||||
if apiKey == "" {
|
||||
return nil, errors.New("gemini api_key not configured")
|
||||
@@ -2177,10 +2209,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
parts := make([]any, 0)
|
||||
switch content := mm["content"].(type) {
|
||||
case string:
|
||||
if strings.TrimSpace(content) != "" {
|
||||
parts = append(parts, map[string]any{"text": content})
|
||||
}
|
||||
// 字符串形式的 content,保留所有内容(包括空白)
|
||||
parts = append(parts, map[string]any{"text": content})
|
||||
case []any:
|
||||
// 如果只有一个 block,不过滤空白(让上游 API 报错)
|
||||
singleBlock := len(content) == 1
|
||||
|
||||
for _, block := range content {
|
||||
bm, ok := block.(map[string]any)
|
||||
if !ok {
|
||||
@@ -2189,8 +2223,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
|
||||
bt, _ := bm["type"].(string)
|
||||
switch bt {
|
||||
case "text":
|
||||
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
if text, ok := bm["text"].(string); ok {
|
||||
// 单个 block 时保留所有内容(包括空白)
|
||||
// 多个 blocks 时过滤掉空白
|
||||
if singleBlock || strings.TrimSpace(text) != "" {
|
||||
parts = append(parts, map[string]any{"text": text})
|
||||
}
|
||||
}
|
||||
case "tool_use":
|
||||
id, _ := bm["id"].(string)
|
||||
|
||||
@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
|
||||
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
|
||||
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
return nil
|
||||
@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
|
||||
|
||||
repo := &mockAccountRepoForGemini{
|
||||
accounts: []Account{
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
|
||||
},
|
||||
accountsByID: map[int64]*Account{},
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -18,12 +19,23 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
TierAIPremium = "AI_PREMIUM"
|
||||
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
TierFree = "FREE"
|
||||
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
// Canonical tier IDs used by sub2api (2026-aligned).
|
||||
GeminiTierGoogleOneFree = "google_one_free"
|
||||
GeminiTierGoogleAIPro = "google_ai_pro"
|
||||
GeminiTierGoogleAIUltra = "google_ai_ultra"
|
||||
GeminiTierGCPStandard = "gcp_standard"
|
||||
GeminiTierGCPEnterprise = "gcp_enterprise"
|
||||
GeminiTierAIStudioFree = "aistudio_free"
|
||||
GeminiTierAIStudioPaid = "aistudio_paid"
|
||||
GeminiTierGoogleOneUnknown = "google_one_unknown"
|
||||
|
||||
// Legacy/compat tier IDs that may exist in historical data or upstream responses.
|
||||
legacyTierAIPremium = "AI_PREMIUM"
|
||||
legacyTierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
|
||||
legacyTierGoogleOneBasic = "GOOGLE_ONE_BASIC"
|
||||
legacyTierFree = "FREE"
|
||||
legacyTierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
|
||||
legacyTierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct {
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
|
||||
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType, tierID string) (*GeminiAuthURLResult, error) {
|
||||
state, err := geminicli.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate state: %w", err)
|
||||
@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
|
||||
// OAuth client selection:
|
||||
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
|
||||
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
|
||||
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
|
||||
// - ai_studio: requires a user-provided OAuth client.
|
||||
oauthCfg := geminicli.OAuthConfig{
|
||||
ClientID: s.cfg.Gemini.OAuth.ClientID,
|
||||
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
|
||||
Scopes: s.cfg.Gemini.OAuth.Scopes,
|
||||
}
|
||||
if oauthType == "code_assist" || oauthType == "google_one" {
|
||||
if oauthType == "code_assist" {
|
||||
oauthCfg.ClientID = ""
|
||||
oauthCfg.ClientSecret = ""
|
||||
}
|
||||
@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
ProxyURL: proxyURL,
|
||||
RedirectURI: redirectURI,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
TierID: canonicalGeminiTierIDForOAuthType(oauthType, tierID),
|
||||
OAuthType: oauthType,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
|
||||
}
|
||||
|
||||
// Redirect URI strategy:
|
||||
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
|
||||
// - ai_studio: use localhost callback for manual copy/paste flow
|
||||
if oauthType == "code_assist" {
|
||||
// - built-in Gemini CLI OAuth client: use upstream redirect URI (codeassist.google.com/authcode)
|
||||
// - custom OAuth client: use localhost callback for manual copy/paste flow
|
||||
if isBuiltinClient {
|
||||
redirectURI = geminicli.GeminiCLIRedirectURI
|
||||
} else {
|
||||
redirectURI = geminicli.AIStudioOAuthRedirectURI
|
||||
@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct {
|
||||
Code string
|
||||
ProxyID *int64
|
||||
OAuthType string // "code_assist" 或 "ai_studio"
|
||||
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
|
||||
// If empty, the service will fall back to the tier stored in the OAuth session (if any).
|
||||
TierID string
|
||||
}
|
||||
|
||||
type GeminiTokenInfo struct {
|
||||
@@ -185,7 +201,7 @@ type GeminiTokenInfo struct {
|
||||
Scope string `json:"scope,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
|
||||
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
|
||||
TierID string `json:"tier_id,omitempty"` // Canonical tier id (e.g. google_one_free, gcp_standard, aistudio_free)
|
||||
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
|
||||
}
|
||||
|
||||
@@ -204,6 +220,90 @@ func validateTierID(tierID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func canonicalGeminiTierID(raw string) string {
|
||||
raw = strings.TrimSpace(raw)
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
lower := strings.ToLower(raw)
|
||||
switch lower {
|
||||
case GeminiTierGoogleOneFree,
|
||||
GeminiTierGoogleAIPro,
|
||||
GeminiTierGoogleAIUltra,
|
||||
GeminiTierGCPStandard,
|
||||
GeminiTierGCPEnterprise,
|
||||
GeminiTierAIStudioFree,
|
||||
GeminiTierAIStudioPaid,
|
||||
GeminiTierGoogleOneUnknown:
|
||||
return lower
|
||||
}
|
||||
|
||||
upper := strings.ToUpper(raw)
|
||||
switch upper {
|
||||
// Google One legacy tiers
|
||||
case legacyTierAIPremium:
|
||||
return GeminiTierGoogleAIPro
|
||||
case legacyTierGoogleOneUnlimited:
|
||||
return GeminiTierGoogleAIUltra
|
||||
case legacyTierFree, legacyTierGoogleOneBasic, legacyTierGoogleOneStandard:
|
||||
return GeminiTierGoogleOneFree
|
||||
case legacyTierGoogleOneUnknown:
|
||||
return GeminiTierGoogleOneUnknown
|
||||
|
||||
// Code Assist legacy tiers
|
||||
case "STANDARD", "PRO", "LEGACY":
|
||||
return GeminiTierGCPStandard
|
||||
case "ENTERPRISE", "ULTRA":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
|
||||
// Some Code Assist responses use kebab-case tier identifiers.
|
||||
switch lower {
|
||||
case "standard-tier", "pro-tier":
|
||||
return GeminiTierGCPStandard
|
||||
case "ultra-tier":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func canonicalGeminiTierIDForOAuthType(oauthType, tierID string) string {
|
||||
oauthType = strings.ToLower(strings.TrimSpace(oauthType))
|
||||
canonical := canonicalGeminiTierID(tierID)
|
||||
if canonical == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch oauthType {
|
||||
case "google_one":
|
||||
switch canonical {
|
||||
case GeminiTierGoogleOneFree, GeminiTierGoogleAIPro, GeminiTierGoogleAIUltra:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
case "code_assist":
|
||||
switch canonical {
|
||||
case GeminiTierGCPStandard, GeminiTierGCPEnterprise:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
case "ai_studio":
|
||||
switch canonical {
|
||||
case GeminiTierAIStudioFree, GeminiTierAIStudioPaid:
|
||||
return canonical
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
default:
|
||||
// Unknown oauth type: accept canonical tier.
|
||||
return canonical
|
||||
}
|
||||
}
|
||||
|
||||
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
|
||||
// Prioritizes IsDefault tier, falls back to first non-empty tier
|
||||
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
|
||||
@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
|
||||
|
||||
// inferGoogleOneTier infers Google One tier from Drive storage limit
|
||||
func inferGoogleOneTier(storageBytes int64) string {
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
|
||||
|
||||
if storageBytes <= 0 {
|
||||
return TierGoogleOneUnknown
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
if storageBytes > StorageTierUnlimited {
|
||||
return TierGoogleOneUnlimited
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
|
||||
return GeminiTierGoogleAIUltra
|
||||
}
|
||||
if storageBytes >= StorageTierAIPremium {
|
||||
return TierAIPremium
|
||||
}
|
||||
if storageBytes >= StorageTierStandard {
|
||||
return TierGoogleOneStandard
|
||||
}
|
||||
if storageBytes >= StorageTierBasic {
|
||||
return TierGoogleOneBasic
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
|
||||
return GeminiTierGoogleAIPro
|
||||
}
|
||||
if storageBytes >= StorageTierFree {
|
||||
return TierFree
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
|
||||
return GeminiTierGoogleOneFree
|
||||
}
|
||||
return TierGoogleOneUnknown
|
||||
|
||||
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
|
||||
return GeminiTierGoogleOneUnknown
|
||||
}
|
||||
|
||||
// fetchGoogleOneTier fetches Google One tier from Drive API
|
||||
// FetchGoogleOneTier fetches Google One tier from Drive API.
|
||||
// Note: LoadCodeAssist API is NOT called for Google One accounts because:
|
||||
// 1. It's designed for GCP IAM (enterprise), not personal Google accounts
|
||||
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
|
||||
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
|
||||
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
|
||||
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
|
||||
|
||||
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
|
||||
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
|
||||
driveClient := geminicli.NewDriveClient()
|
||||
|
||||
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Check if it's a 403 (scope not granted)
|
||||
if strings.Contains(err.Error(), "status 403") {
|
||||
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
// Other errors
|
||||
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
|
||||
return TierGoogleOneUnknown, nil, err
|
||||
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
|
||||
return GeminiTierGoogleOneUnknown, nil, err
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
|
||||
tierID := inferGoogleOneTier(storageInfo.Limit)
|
||||
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
|
||||
|
||||
return tierID, storageInfo, nil
|
||||
}
|
||||
|
||||
@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
|
||||
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
|
||||
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
|
||||
return nil, fmt.Errorf("session not found or expired")
|
||||
}
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
log.Printf("[GeminiOAuth] ERROR: Invalid state")
|
||||
return nil, fmt.Errorf("invalid state")
|
||||
}
|
||||
|
||||
@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
|
||||
|
||||
redirectURI := session.RedirectURI
|
||||
|
||||
@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
if oauthType == "" {
|
||||
oauthType = "code_assist"
|
||||
}
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
|
||||
|
||||
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
|
||||
if oauthType == "ai_studio" {
|
||||
@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
|
||||
if err != nil {
|
||||
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
|
||||
return nil, fmt.Errorf("failed to exchange code: %w", err)
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Token exchange successful")
|
||||
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
|
||||
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
|
||||
|
||||
sessionProjectID := strings.TrimSpace(session.ProjectID)
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
|
||||
projectID := sessionProjectID
|
||||
var tierID string
|
||||
fallbackTierID := canonicalGeminiTierIDForOAuthType(oauthType, input.TierID)
|
||||
if fallbackTierID == "" {
|
||||
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
|
||||
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
|
||||
|
||||
// 对于 code_assist 模式,project_id 是必需的,需要调用 Code Assist API
|
||||
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id,配额由 Google 网关自动识别
|
||||
// 对于 ai_studio 模式,project_id 是可选的(不影响使用 AI Studio API)
|
||||
switch oauthType {
|
||||
case "code_assist":
|
||||
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
|
||||
if projectID == "" {
|
||||
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
|
||||
var err error
|
||||
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// 记录警告但不阻断流程,允许后续补充 project_id
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
|
||||
}
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
|
||||
// 用户手动填了 project_id,仍需调用 LoadCodeAssist 获取 tierID
|
||||
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
|
||||
} else {
|
||||
tierID = fetchedTierID
|
||||
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
|
||||
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
|
||||
}
|
||||
// tierID 缺失时使用默认值
|
||||
// Prefer auto-detected tier; fall back to user-selected tier.
|
||||
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
|
||||
if tierID == "" {
|
||||
tierID = "LEGACY"
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGCPStandard
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
|
||||
|
||||
case "google_one":
|
||||
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
|
||||
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
|
||||
// Attempt to fetch Drive storage tier
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
var storageInfo *geminicli.DriveStorageInfo
|
||||
var err error
|
||||
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
// Log warning but don't block - use fallback
|
||||
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
|
||||
tierID = TierGoogleOneUnknown
|
||||
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
|
||||
tierID = ""
|
||||
} else {
|
||||
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
|
||||
if storageInfo != nil {
|
||||
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
|
||||
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
|
||||
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
|
||||
}
|
||||
}
|
||||
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
|
||||
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
|
||||
} else {
|
||||
tierID = GeminiTierGoogleOneFree
|
||||
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
|
||||
}
|
||||
}
|
||||
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
|
||||
|
||||
// Store Drive info in extra field for caching
|
||||
if storageInfo != nil {
|
||||
@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
},
|
||||
}
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
|
||||
return tokenInfo, nil
|
||||
}
|
||||
}
|
||||
// ai_studio 模式不设置 tierID,保持为空
|
||||
|
||||
return &GeminiTokenInfo{
|
||||
case "ai_studio":
|
||||
// No automatic tier detection for AI Studio OAuth; rely on user selection.
|
||||
if fallbackTierID != "" {
|
||||
tierID = fallbackTierID
|
||||
} else {
|
||||
tierID = GeminiTierAIStudioFree
|
||||
}
|
||||
|
||||
default:
|
||||
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
|
||||
}
|
||||
|
||||
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
|
||||
|
||||
result := &GeminiTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
|
||||
ProjectID: projectID,
|
||||
TierID: tierID,
|
||||
OAuthType: oauthType,
|
||||
}, nil
|
||||
}
|
||||
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
|
||||
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
|
||||
@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
// Backward compatibility for google_one:
|
||||
// - New behavior: when a custom OAuth client is configured, google_one will use it.
|
||||
// - Old behavior: google_one always used the built-in Gemini CLI OAuth client.
|
||||
// If an existing account was authorized with the built-in client, refreshing with the custom client
|
||||
// will fail with "unauthorized_client". Retry with the built-in client (code_assist path forces it).
|
||||
if err != nil && oauthType == "google_one" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
|
||||
if alt, altErr := s.RefreshToken(ctx, "code_assist", refreshToken, proxyURL); altErr == nil {
|
||||
tokenInfo = alt
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
// Provide a more actionable error for common OAuth client mismatch issues.
|
||||
if strings.Contains(err.Error(), "unauthorized_client") {
|
||||
@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
case "code_assist":
|
||||
// 先设置默认值或保留旧值,确保 tier_id 始终有值
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = "LEGACY" // 默认值
|
||||
tokenInfo.TierID = canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
|
||||
}
|
||||
if tokenInfo.TierID == "" {
|
||||
tokenInfo.TierID = GeminiTierGCPStandard
|
||||
}
|
||||
|
||||
// 尝试自动探测 project_id 和 tier_id
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
|
||||
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || tokenInfo.TierID == ""
|
||||
if needDetect {
|
||||
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err != nil {
|
||||
@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
// 只有当原来没有 tier_id 且探测成功时才更新
|
||||
if existingTierID == "" && tierID != "" {
|
||||
tokenInfo.TierID = tierID
|
||||
if tierID != "" {
|
||||
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" {
|
||||
tokenInfo.TierID = canonical
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
|
||||
}
|
||||
case "google_one":
|
||||
canonicalExistingTier := canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
|
||||
// Check if tier cache is stale (> 24 hours)
|
||||
needsRefresh := true
|
||||
if account.Extra != nil {
|
||||
@@ -617,30 +824,37 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
|
||||
if time.Since(updatedAt) <= 24*time.Hour {
|
||||
needsRefresh = false
|
||||
// Use cached tier
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
}
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenInfo.TierID == "" {
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
}
|
||||
|
||||
if needsRefresh {
|
||||
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
|
||||
if err == nil && storageInfo != nil {
|
||||
tokenInfo.TierID = tierID
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
if err == nil {
|
||||
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" && canonical != GeminiTierGoogleOneUnknown {
|
||||
tokenInfo.TierID = canonical
|
||||
}
|
||||
if storageInfo != nil {
|
||||
tokenInfo.Extra = map[string]any{
|
||||
"drive_storage_limit": storageInfo.Limit,
|
||||
"drive_storage_usage": storageInfo.Usage,
|
||||
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tokenInfo.TierID == "" || tokenInfo.TierID == GeminiTierGoogleOneUnknown {
|
||||
if canonicalExistingTier != "" {
|
||||
tokenInfo.TierID = canonicalExistingTier
|
||||
} else {
|
||||
// Fallback to cached or unknown
|
||||
if existingTierID != "" {
|
||||
tokenInfo.TierID = existingTierID
|
||||
} else {
|
||||
tokenInfo.TierID = TierGoogleOneUnknown
|
||||
}
|
||||
tokenInfo.TierID = GeminiTierGoogleOneFree
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
|
||||
// Validate tier_id before storing
|
||||
if err := validateTierID(tokenInfo.TierID); err == nil {
|
||||
creds["tier_id"] = tokenInfo.TierID
|
||||
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
|
||||
} else {
|
||||
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
|
||||
}
|
||||
// Silently skip invalid tier_id (don't block account creation)
|
||||
}
|
||||
@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
|
||||
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
|
||||
tierID := "LEGACY"
|
||||
if loadResp != nil {
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
// First try to get tier from currentTier/paidTier fields
|
||||
if tier := loadResp.GetTier(); tier != "" {
|
||||
tierID = tier
|
||||
} else {
|
||||
// Fallback to extracting from allowedTiers
|
||||
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
|
||||
}
|
||||
}
|
||||
|
||||
// If LoadCodeAssist returned a project, use it
|
||||
|
||||
@@ -1,50 +1,129 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
func TestInferGoogleOneTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
storageBytes int64
|
||||
expectedTier string
|
||||
}{
|
||||
{"Negative storage", -1, TierGoogleOneUnknown},
|
||||
{"Zero storage", 0, TierGoogleOneUnknown},
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
|
||||
)
|
||||
|
||||
// Free tier boundary (15GB)
|
||||
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
|
||||
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
|
||||
{"Free tier (15GB)", StorageTierFree, TierFree},
|
||||
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Basic tier boundary (100GB)
|
||||
{"Between free and basic", 50 * GB, TierFree},
|
||||
{"Just below basic tier", StorageTierBasic - 1, TierFree},
|
||||
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
|
||||
type testCase struct {
|
||||
name string
|
||||
cfg *config.Config
|
||||
oauthType string
|
||||
projectID string
|
||||
wantClientID string
|
||||
wantRedirect string
|
||||
wantScope string
|
||||
wantProjectID string
|
||||
wantErrSubstr string
|
||||
}
|
||||
|
||||
// Standard tier boundary (200GB)
|
||||
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
|
||||
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
|
||||
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
|
||||
|
||||
// AI Premium tier boundary (2TB)
|
||||
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
|
||||
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
|
||||
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
|
||||
|
||||
// Unlimited tier boundary (> 100TB)
|
||||
{"Between premium and unlimited", 50 * TB, TierAIPremium},
|
||||
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
|
||||
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
|
||||
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
|
||||
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "google_one uses built-in client when not configured and redirects to upstream",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{},
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "google_one uses custom client when configured and redirects to localhost",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
oauthType: "google_one",
|
||||
wantClientID: "custom-client-id",
|
||||
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
|
||||
wantScope: geminicli.DefaultGoogleOneScopes,
|
||||
wantProjectID: "",
|
||||
},
|
||||
{
|
||||
name: "code_assist always forces built-in client even when custom client configured",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{
|
||||
ClientID: "custom-client-id",
|
||||
ClientSecret: "custom-client-secret",
|
||||
},
|
||||
},
|
||||
},
|
||||
oauthType: "code_assist",
|
||||
projectID: "my-gcp-project",
|
||||
wantClientID: geminicli.GeminiCLIOAuthClientID,
|
||||
wantRedirect: geminicli.GeminiCLIRedirectURI,
|
||||
wantScope: geminicli.DefaultCodeAssistScopes,
|
||||
wantProjectID: "my-gcp-project",
|
||||
},
|
||||
{
|
||||
name: "ai_studio requires custom client",
|
||||
cfg: &config.Config{
|
||||
Gemini: config.GeminiConfig{
|
||||
OAuth: config.GeminiOAuthConfig{},
|
||||
},
|
||||
},
|
||||
oauthType: "ai_studio",
|
||||
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := inferGoogleOneTier(tt.storageBytes)
|
||||
if result != tt.expectedTier {
|
||||
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
|
||||
tt.storageBytes, result, tt.expectedTier)
|
||||
t.Parallel()
|
||||
|
||||
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
|
||||
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
|
||||
if tt.wantErrSubstr != "" {
|
||||
if err == nil {
|
||||
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
|
||||
}
|
||||
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
|
||||
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("GenerateAuthURL returned error: %v", err)
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(got.AuthURL)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse auth_url: %v", err)
|
||||
}
|
||||
q := parsed.Query()
|
||||
|
||||
if gotState := q.Get("state"); gotState != got.State {
|
||||
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
|
||||
}
|
||||
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
|
||||
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
|
||||
}
|
||||
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
|
||||
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
|
||||
}
|
||||
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
|
||||
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
|
||||
}
|
||||
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
|
||||
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,13 +20,24 @@ const (
|
||||
geminiModelFlash geminiModelClass = "flash"
|
||||
)
|
||||
|
||||
type GeminiDailyQuota struct {
|
||||
ProRPD int64
|
||||
FlashRPD int64
|
||||
type GeminiQuota struct {
|
||||
// SharedRPD is a shared requests-per-day pool across models.
|
||||
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
|
||||
SharedRPD int64 `json:"shared_rpd,omitempty"`
|
||||
// SharedRPM is a shared requests-per-minute pool across models.
|
||||
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
|
||||
SharedRPM int64 `json:"shared_rpm,omitempty"`
|
||||
|
||||
// Per-model quotas (AI Studio / API key).
|
||||
// A value of -1 means "unlimited" (pay-as-you-go).
|
||||
ProRPD int64 `json:"pro_rpd,omitempty"`
|
||||
ProRPM int64 `json:"pro_rpm,omitempty"`
|
||||
FlashRPD int64 `json:"flash_rpd,omitempty"`
|
||||
FlashRPM int64 `json:"flash_rpm,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiTierPolicy struct {
|
||||
Quota GeminiDailyQuota
|
||||
Quota GeminiQuota
|
||||
Cooldown time.Duration
|
||||
}
|
||||
|
||||
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
|
||||
|
||||
const geminiQuotaCacheTTL = time.Minute
|
||||
|
||||
type geminiQuotaOverrides struct {
|
||||
type geminiQuotaOverridesV1 struct {
|
||||
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
|
||||
}
|
||||
|
||||
type geminiQuotaOverridesV2 struct {
|
||||
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
|
||||
}
|
||||
|
||||
type geminiQuotaRuleOverride struct {
|
||||
SharedRPD *int64 `json:"shared_rpd,omitempty"`
|
||||
SharedRPM *int64 `json:"rpm,omitempty"`
|
||||
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
|
||||
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
|
||||
Desc *string `json:"desc,omitempty"`
|
||||
}
|
||||
|
||||
type geminiModelQuotaOverride struct {
|
||||
RPD *int64 `json:"rpd,omitempty"`
|
||||
RPM *int64 `json:"rpm,omitempty"`
|
||||
}
|
||||
|
||||
type GeminiQuotaService struct {
|
||||
cfg *config.Config
|
||||
settingRepo SettingRepository
|
||||
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
if s.cfg != nil {
|
||||
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
|
||||
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
|
||||
var overrides geminiQuotaOverrides
|
||||
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
|
||||
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||
raw := []byte(s.cfg.Gemini.Quota.Policy)
|
||||
var overridesV2 geminiQuotaOverridesV2
|
||||
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
||||
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
||||
} else {
|
||||
policy.ApplyOverrides(overrides.Tiers)
|
||||
var overridesV1 geminiQuotaOverridesV1
|
||||
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
||||
log.Printf("gemini quota: parse config policy failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overridesV1.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
if err != nil && !errors.Is(err, ErrSettingNotFound) {
|
||||
log.Printf("gemini quota: load setting failed: %v", err)
|
||||
} else if strings.TrimSpace(value) != "" {
|
||||
var overrides geminiQuotaOverrides
|
||||
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
|
||||
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||
raw := []byte(value)
|
||||
var overridesV2 geminiQuotaOverridesV2
|
||||
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
|
||||
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
|
||||
} else {
|
||||
policy.ApplyOverrides(overrides.Tiers)
|
||||
var overridesV1 geminiQuotaOverridesV1
|
||||
if err := json.Unmarshal(raw, &overridesV1); err != nil {
|
||||
log.Printf("gemini quota: parse setting failed: %v", err)
|
||||
} else {
|
||||
policy.ApplyOverrides(overridesV1.Tiers)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
|
||||
return policy
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
|
||||
if account == nil || !account.IsGeminiCodeAssist() {
|
||||
return GeminiDailyQuota{}, false
|
||||
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
|
||||
// Map (oauth_type + tier_id) to a canonical policy tier key.
|
||||
// This keeps the policy table stable even if upstream tier_id strings vary.
|
||||
tierKey := geminiQuotaTierKeyForAccount(account)
|
||||
if tierKey == "" {
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
|
||||
policy := s.Policy(ctx)
|
||||
return policy.QuotaForTier(account.GeminiTierID())
|
||||
return policy.QuotaForTier(tierKey)
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
|
||||
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
|
||||
if s == nil || account == nil || account.Platform != PlatformGemini {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
tierKey := geminiQuotaTierKeyForAccount(account)
|
||||
if strings.TrimSpace(tierKey) == "" {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
return s.CooldownForTier(ctx, tierKey)
|
||||
}
|
||||
|
||||
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
|
||||
return &GeminiQuotaPolicy{
|
||||
tiers: map[string]GeminiTierPolicy{
|
||||
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
|
||||
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
|
||||
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
|
||||
// --- AI Studio / API Key (per-model) ---
|
||||
// aistudio_free:
|
||||
// - gemini_pro: 50 RPD / 2 RPM
|
||||
// - gemini_flash: 1500 RPD / 15 RPM
|
||||
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
|
||||
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
|
||||
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
|
||||
|
||||
// --- Google One (shared pool) ---
|
||||
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
|
||||
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
|
||||
// --- GCP Code Assist (shared pool) ---
|
||||
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
// Backward-compatible overrides:
|
||||
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
|
||||
// - Otherwise apply per-model overrides.
|
||||
if override.ProRPD != nil {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
|
||||
if policy.Quota.SharedRPD > 0 {
|
||||
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
||||
} else {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
|
||||
}
|
||||
}
|
||||
if override.FlashRPD != nil {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
|
||||
if policy.Quota.SharedRPD > 0 {
|
||||
// No separate flash RPD for shared tiers.
|
||||
} else {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
|
||||
}
|
||||
}
|
||||
if override.CooldownMinutes != nil {
|
||||
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
|
||||
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
|
||||
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
|
||||
if p == nil || len(rules) == 0 {
|
||||
return
|
||||
}
|
||||
for rawID, override := range rules {
|
||||
tierID := normalizeGeminiTierID(rawID)
|
||||
if tierID == "" {
|
||||
continue
|
||||
}
|
||||
policy, ok := p.tiers[tierID]
|
||||
if !ok {
|
||||
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
|
||||
}
|
||||
|
||||
if override.SharedRPD != nil {
|
||||
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
|
||||
}
|
||||
if override.SharedRPM != nil {
|
||||
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
|
||||
}
|
||||
if override.GeminiPro != nil {
|
||||
if override.GeminiPro.RPD != nil {
|
||||
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
|
||||
}
|
||||
if override.GeminiPro.RPM != nil {
|
||||
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
|
||||
}
|
||||
}
|
||||
if override.GeminiFlash != nil {
|
||||
if override.GeminiFlash.RPD != nil {
|
||||
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
|
||||
}
|
||||
if override.GeminiFlash.RPM != nil {
|
||||
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
|
||||
}
|
||||
}
|
||||
|
||||
p.tiers[tierID] = policy
|
||||
}
|
||||
}
|
||||
|
||||
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
|
||||
policy, ok := p.policyForTier(tierID)
|
||||
if !ok {
|
||||
return GeminiDailyQuota{}, false
|
||||
return GeminiQuota{}, false
|
||||
}
|
||||
return policy.Quota, true
|
||||
}
|
||||
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
normalized := normalizeGeminiTierID(tierID)
|
||||
if normalized == "" {
|
||||
normalized = "LEGACY"
|
||||
}
|
||||
if policy, ok := p.tiers[normalized]; ok {
|
||||
return policy, true
|
||||
}
|
||||
policy, ok := p.tiers["LEGACY"]
|
||||
return policy, ok
|
||||
return GeminiTierPolicy{}, false
|
||||
}
|
||||
|
||||
func normalizeGeminiTierID(tierID string) string {
|
||||
return strings.ToUpper(strings.TrimSpace(tierID))
|
||||
tierID = strings.TrimSpace(tierID)
|
||||
if tierID == "" {
|
||||
return ""
|
||||
}
|
||||
// Prefer canonical mapping (handles legacy tier strings).
|
||||
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
|
||||
return canonical
|
||||
}
|
||||
// Accept older policy keys that used uppercase names.
|
||||
switch strings.ToUpper(tierID) {
|
||||
case "AISTUDIO_FREE":
|
||||
return GeminiTierAIStudioFree
|
||||
case "AISTUDIO_PAID":
|
||||
return GeminiTierAIStudioPaid
|
||||
case "GOOGLE_ONE_FREE":
|
||||
return GeminiTierGoogleOneFree
|
||||
case "GOOGLE_AI_PRO":
|
||||
return GeminiTierGoogleAIPro
|
||||
case "GOOGLE_AI_ULTRA":
|
||||
return GeminiTierGoogleAIUltra
|
||||
case "GCP_STANDARD":
|
||||
return GeminiTierGCPStandard
|
||||
case "GCP_ENTERPRISE":
|
||||
return GeminiTierGCPEnterprise
|
||||
}
|
||||
return strings.ToLower(tierID)
|
||||
}
|
||||
|
||||
func clampGeminiQuotaInt64(value int64) int64 {
|
||||
if value < 0 {
|
||||
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
|
||||
if value < -1 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
|
||||
return value
|
||||
}
|
||||
|
||||
func clampGeminiQuotaRPM(value int64) int64 {
|
||||
if value < 0 {
|
||||
return 0
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func geminiCooldownForTier(tierID string) time.Duration {
|
||||
policy := newGeminiQuotaPolicy()
|
||||
return policy.CooldownForTier(tierID)
|
||||
}
|
||||
|
||||
func geminiQuotaTierKeyForAccount(account *Account) string {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
|
||||
oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
|
||||
rawTier := strings.TrimSpace(account.GeminiTierID())
|
||||
|
||||
// Prefer the canonical tier stored in credentials.
|
||||
if tierID := canonicalGeminiTierIDForOAuthType(oauthType, rawTier); tierID != "" && tierID != GeminiTierGoogleOneUnknown {
|
||||
return tierID
|
||||
}
|
||||
|
||||
// Fallback defaults when tier_id is missing or unknown.
|
||||
switch oauthType {
|
||||
case "google_one":
|
||||
return GeminiTierGoogleOneFree
|
||||
case "code_assist":
|
||||
return GeminiTierGCPStandard
|
||||
case "ai_studio":
|
||||
return GeminiTierAIStudioFree
|
||||
default:
|
||||
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
|
||||
return GeminiTierAIStudioFree
|
||||
}
|
||||
}
|
||||
|
||||
func geminiModelClassFromName(model string) geminiModelClass {
|
||||
name := strings.ToLower(strings.TrimSpace(model))
|
||||
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {
|
||||
|
||||
@@ -487,7 +487,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
|
||||
return "", "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
return accessToken, "oauth", nil
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
apiKey := account.GetOpenAIApiKey()
|
||||
if apiKey == "" {
|
||||
return "", "", errors.New("api_key not found in credentials")
|
||||
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
|
||||
case AccountTypeOAuth:
|
||||
// OAuth accounts use ChatGPT internal API
|
||||
targetURL = chatgptCodexURL
|
||||
case AccountTypeApiKey:
|
||||
case AccountTypeAPIKey:
|
||||
// API Key accounts use Platform API or custom base URL
|
||||
baseURL := account.GetOpenAIBaseURL()
|
||||
if baseURL != "" {
|
||||
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
|
||||
}
|
||||
|
||||
// Handle upstream error (mark account status)
|
||||
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
shouldDisable := false
|
||||
if s.rateLimitService != nil {
|
||||
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
|
||||
}
|
||||
if shouldDisable {
|
||||
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
|
||||
}
|
||||
|
||||
// Return appropriate error response
|
||||
var errType, errMsg string
|
||||
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
|
||||
// OpenAIRecordUsageInput input for recording usage
|
||||
type OpenAIRecordUsageInput struct {
|
||||
Result *OpenAIForwardResult
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
User *User
|
||||
Account *Account
|
||||
Subscription *UserSubscription
|
||||
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
|
||||
// RecordUsage records usage and deducts balance
|
||||
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
|
||||
result := input.Result
|
||||
apiKey := input.ApiKey
|
||||
apiKey := input.APIKey
|
||||
user := input.User
|
||||
account := input.Account
|
||||
subscription := input.Subscription
|
||||
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
durationMs := int(result.Duration.Milliseconds())
|
||||
usageLog := &UsageLog{
|
||||
UserID: user.ID,
|
||||
ApiKeyID: apiKey.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
AccountID: account.ID,
|
||||
RequestID: result.RequestID,
|
||||
Model: result.Model,
|
||||
@@ -1020,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
|
||||
usageLog.SubscriptionID = &subscription.ID
|
||||
}
|
||||
|
||||
_ = s.usageLogRepo.Create(ctx, usageLog)
|
||||
|
||||
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
|
||||
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
|
||||
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
|
||||
s.deferredService.ScheduleLastUsedUpdate(account.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
shouldBill := inserted || err != nil
|
||||
|
||||
// Deduct based on billing type
|
||||
if isSubscriptionBilling {
|
||||
if cost.TotalCost > 0 {
|
||||
if shouldBill && cost.TotalCost > 0 {
|
||||
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
|
||||
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
|
||||
}
|
||||
} else {
|
||||
if cost.ActualCost > 0 {
|
||||
if shouldBill && cost.ActualCost > 0 {
|
||||
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
|
||||
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
@@ -18,6 +19,7 @@ type RateLimitService struct {
|
||||
usageRepo UsageLogRepository
|
||||
cfg *config.Config
|
||||
geminiQuotaService *GeminiQuotaService
|
||||
tempUnschedCache TempUnschedCache
|
||||
usageCacheMu sync.RWMutex
|
||||
usageCache map[int64]*geminiUsageCacheEntry
|
||||
}
|
||||
@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
|
||||
const geminiPrecheckCacheTTL = time.Minute
|
||||
|
||||
// NewRateLimitService 创建RateLimitService实例
|
||||
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
|
||||
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
|
||||
return &RateLimitService{
|
||||
accountRepo: accountRepo,
|
||||
usageRepo: usageRepo,
|
||||
cfg: cfg,
|
||||
geminiQuotaService: geminiQuotaService,
|
||||
tempUnschedCache: tempUnschedCache,
|
||||
usageCache: make(map[int64]*geminiUsageCacheEntry),
|
||||
}
|
||||
}
|
||||
@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
|
||||
return false
|
||||
}
|
||||
|
||||
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
|
||||
switch statusCode {
|
||||
case 401:
|
||||
// 认证失败:停止调度,记录错误
|
||||
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
|
||||
return true
|
||||
shouldDisable = true
|
||||
case 402:
|
||||
// 支付要求:余额不足或计费问题,停止调度
|
||||
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
|
||||
return true
|
||||
shouldDisable = true
|
||||
case 403:
|
||||
// 禁止访问:停止调度,记录错误
|
||||
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
|
||||
return true
|
||||
shouldDisable = true
|
||||
case 429:
|
||||
s.handle429(ctx, account, headers)
|
||||
return false
|
||||
shouldDisable = false
|
||||
case 529:
|
||||
s.handle529(ctx, account)
|
||||
return false
|
||||
shouldDisable = false
|
||||
default:
|
||||
// 其他5xx错误:记录但不停止调度
|
||||
if statusCode >= 500 {
|
||||
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
|
||||
}
|
||||
return false
|
||||
shouldDisable = false
|
||||
}
|
||||
|
||||
if tempMatched {
|
||||
return true
|
||||
}
|
||||
return shouldDisable
|
||||
}
|
||||
|
||||
// PreCheckUsage proactively checks local quota before dispatching a request.
|
||||
// Returns false when the account should be skipped.
|
||||
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
|
||||
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
|
||||
if account == nil || account.Platform != PlatformGemini {
|
||||
return true, nil
|
||||
}
|
||||
if s.usageRepo == nil || s.geminiQuotaService == nil {
|
||||
@@ -94,44 +104,99 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
|
||||
return true, nil
|
||||
}
|
||||
|
||||
var limit int64
|
||||
switch geminiModelClassFromName(requestedModel) {
|
||||
case geminiModelFlash:
|
||||
limit = quota.FlashRPD
|
||||
default:
|
||||
limit = quota.ProRPD
|
||||
}
|
||||
if limit <= 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return true, err
|
||||
modelClass := geminiModelClassFromName(requestedModel)
|
||||
|
||||
// 1) Daily quota precheck (RPD; resets at PST midnight)
|
||||
{
|
||||
var limit int64
|
||||
if quota.SharedRPD > 0 {
|
||||
limit = quota.SharedRPD
|
||||
} else {
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
limit = quota.FlashRPD
|
||||
default:
|
||||
limit = quota.ProRPD
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
start := geminiDailyWindowStart(now)
|
||||
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
|
||||
if !ok {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
totals = geminiAggregateUsage(stats)
|
||||
s.setGeminiUsageTotals(account.ID, start, now, totals)
|
||||
}
|
||||
|
||||
var used int64
|
||||
if quota.SharedRPD > 0 {
|
||||
used = totals.ProRequests + totals.FlashRequests
|
||||
} else {
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
used = totals.FlashRequests
|
||||
default:
|
||||
used = totals.ProRequests
|
||||
}
|
||||
}
|
||||
|
||||
if used >= limit {
|
||||
resetAt := geminiDailyResetTime(now)
|
||||
// NOTE:
|
||||
// - This is a local precheck to reduce upstream 429s.
|
||||
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
totals = geminiAggregateUsage(stats)
|
||||
s.setGeminiUsageTotals(account.ID, start, now, totals)
|
||||
}
|
||||
|
||||
var used int64
|
||||
switch geminiModelClassFromName(requestedModel) {
|
||||
case geminiModelFlash:
|
||||
used = totals.FlashRequests
|
||||
default:
|
||||
used = totals.ProRequests
|
||||
}
|
||||
|
||||
if used >= limit {
|
||||
resetAt := geminiDailyResetTime(now)
|
||||
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
|
||||
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
|
||||
// 2) Minute quota precheck (RPM; fixed window current minute)
|
||||
{
|
||||
var limit int64
|
||||
if quota.SharedRPM > 0 {
|
||||
limit = quota.SharedRPM
|
||||
} else {
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
limit = quota.FlashRPM
|
||||
default:
|
||||
limit = quota.ProRPM
|
||||
}
|
||||
}
|
||||
|
||||
if limit > 0 {
|
||||
start := now.Truncate(time.Minute)
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
|
||||
if err != nil {
|
||||
return true, err
|
||||
}
|
||||
totals := geminiAggregateUsage(stats)
|
||||
|
||||
var used int64
|
||||
if quota.SharedRPM > 0 {
|
||||
used = totals.ProRequests + totals.FlashRequests
|
||||
} else {
|
||||
switch modelClass {
|
||||
case geminiModelFlash:
|
||||
used = totals.FlashRequests
|
||||
default:
|
||||
used = totals.ProRequests
|
||||
}
|
||||
}
|
||||
|
||||
if used >= limit {
|
||||
resetAt := start.Add(time.Minute)
|
||||
// Do not persist "rate limited" status from local precheck. See note above.
|
||||
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
|
||||
if account == nil {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
|
||||
if s.geminiQuotaService == nil {
|
||||
return 5 * time.Minute
|
||||
}
|
||||
return s.geminiQuotaService.CooldownForAccount(ctx, account)
|
||||
}
|
||||
|
||||
// handleAuthError 处理认证类错误(401/403),停止账号调度
|
||||
@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
|
||||
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
|
||||
return s.accountRepo.ClearRateLimit(ctx, accountID)
|
||||
}
|
||||
|
||||
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
|
||||
if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil {
|
||||
return err
|
||||
}
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
|
||||
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
|
||||
now := time.Now().Unix()
|
||||
if s.tempUnschedCache != nil {
|
||||
state, err := s.tempUnschedCache.GetTempUnsched(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if state != nil && state.UntilUnix > now {
|
||||
return state, nil
|
||||
}
|
||||
}
|
||||
|
||||
account, err := s.accountRepo.GetByID(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if account.TempUnschedulableUntil == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if account.TempUnschedulableUntil.Unix() <= now {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
state := &TempUnschedState{
|
||||
UntilUnix: account.TempUnschedulableUntil.Unix(),
|
||||
}
|
||||
|
||||
if account.TempUnschedulableReason != "" {
|
||||
var parsed TempUnschedState
|
||||
if err := json.Unmarshal([]byte(account.TempUnschedulableReason), &parsed); err == nil {
|
||||
if parsed.UntilUnix == 0 {
|
||||
parsed.UntilUnix = state.UntilUnix
|
||||
}
|
||||
state = &parsed
|
||||
} else {
|
||||
state.ErrorMessage = account.TempUnschedulableReason
|
||||
}
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
|
||||
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
||||
func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if !account.ShouldHandleErrorCode(statusCode) {
|
||||
return false
|
||||
}
|
||||
return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
|
||||
}
|
||||
|
||||
const tempUnschedBodyMaxBytes = 64 << 10
|
||||
const tempUnschedMessageMaxBytes = 2048
|
||||
|
||||
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if !account.IsTempUnschedulableEnabled() {
|
||||
return false
|
||||
}
|
||||
rules := account.GetTempUnschedulableRules()
|
||||
if len(rules) == 0 {
|
||||
return false
|
||||
}
|
||||
if statusCode <= 0 || len(responseBody) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
body := responseBody
|
||||
if len(body) > tempUnschedBodyMaxBytes {
|
||||
body = body[:tempUnschedBodyMaxBytes]
|
||||
}
|
||||
bodyLower := strings.ToLower(string(body))
|
||||
|
||||
for idx, rule := range rules {
|
||||
if rule.ErrorCode != statusCode || len(rule.Keywords) == 0 {
|
||||
continue
|
||||
}
|
||||
matchedKeyword := matchTempUnschedKeyword(bodyLower, rule.Keywords)
|
||||
if matchedKeyword == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if s.triggerTempUnschedulable(ctx, account, rule, idx, statusCode, matchedKeyword, responseBody) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
|
||||
if bodyLower == "" {
|
||||
return ""
|
||||
}
|
||||
for _, keyword := range keywords {
|
||||
k := strings.TrimSpace(keyword)
|
||||
if k == "" {
|
||||
continue
|
||||
}
|
||||
if strings.Contains(bodyLower, strings.ToLower(k)) {
|
||||
return k
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account *Account, rule TempUnschedulableRule, ruleIndex int, statusCode int, matchedKeyword string, responseBody []byte) bool {
|
||||
if account == nil {
|
||||
return false
|
||||
}
|
||||
if rule.DurationMinutes <= 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
until := now.Add(time.Duration(rule.DurationMinutes) * time.Minute)
|
||||
|
||||
state := &TempUnschedState{
|
||||
UntilUnix: until.Unix(),
|
||||
TriggeredAtUnix: now.Unix(),
|
||||
StatusCode: statusCode,
|
||||
MatchedKeyword: matchedKeyword,
|
||||
RuleIndex: ruleIndex,
|
||||
ErrorMessage: truncateTempUnschedMessage(responseBody, tempUnschedMessageMaxBytes),
|
||||
}
|
||||
|
||||
reason := ""
|
||||
if raw, err := json.Marshal(state); err == nil {
|
||||
reason = string(raw)
|
||||
}
|
||||
if reason == "" {
|
||||
reason = strings.TrimSpace(state.ErrorMessage)
|
||||
}
|
||||
|
||||
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
|
||||
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
if s.tempUnschedCache != nil {
|
||||
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
|
||||
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
|
||||
return true
|
||||
}
|
||||
|
||||
func truncateTempUnschedMessage(body []byte, maxBytes int) string {
|
||||
if maxBytes <= 0 || len(body) == 0 {
|
||||
return ""
|
||||
}
|
||||
if len(body) > maxBytes {
|
||||
body = body[:maxBytes]
|
||||
}
|
||||
return strings.TrimSpace(string(body))
|
||||
}
|
||||
|
||||
@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SettingKeySiteName,
|
||||
SettingKeySiteLogo,
|
||||
SettingKeySiteSubtitle,
|
||||
SettingKeyApiBaseUrl,
|
||||
SettingKeyAPIBaseURL,
|
||||
SettingKeyContactInfo,
|
||||
SettingKeyDocUrl,
|
||||
SettingKeyDocURL,
|
||||
}
|
||||
|
||||
settings, err := s.settingRepo.GetMultiple(ctx, keys)
|
||||
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
|
||||
|
||||
// 邮件服务设置(只有非空才更新密码)
|
||||
updates[SettingKeySmtpHost] = settings.SmtpHost
|
||||
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
|
||||
updates[SettingKeySmtpUsername] = settings.SmtpUsername
|
||||
if settings.SmtpPassword != "" {
|
||||
updates[SettingKeySmtpPassword] = settings.SmtpPassword
|
||||
updates[SettingKeySMTPHost] = settings.SMTPHost
|
||||
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
|
||||
updates[SettingKeySMTPUsername] = settings.SMTPUsername
|
||||
if settings.SMTPPassword != "" {
|
||||
updates[SettingKeySMTPPassword] = settings.SMTPPassword
|
||||
}
|
||||
updates[SettingKeySmtpFrom] = settings.SmtpFrom
|
||||
updates[SettingKeySmtpFromName] = settings.SmtpFromName
|
||||
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
|
||||
updates[SettingKeySMTPFrom] = settings.SMTPFrom
|
||||
updates[SettingKeySMTPFromName] = settings.SMTPFromName
|
||||
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
|
||||
|
||||
// Cloudflare Turnstile 设置(只有非空才更新密钥)
|
||||
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
|
||||
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
|
||||
updates[SettingKeySiteName] = settings.SiteName
|
||||
updates[SettingKeySiteLogo] = settings.SiteLogo
|
||||
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
|
||||
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
|
||||
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
|
||||
updates[SettingKeyContactInfo] = settings.ContactInfo
|
||||
updates[SettingKeyDocUrl] = settings.DocUrl
|
||||
updates[SettingKeyDocURL] = settings.DocURL
|
||||
|
||||
// 默认配置
|
||||
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
|
||||
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
|
||||
|
||||
// Model fallback configuration
|
||||
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
|
||||
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
|
||||
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
|
||||
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
|
||||
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, updates)
|
||||
}
|
||||
|
||||
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
|
||||
SettingKeySiteLogo: "",
|
||||
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
|
||||
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
|
||||
SettingKeySmtpPort: "587",
|
||||
SettingKeySmtpUseTLS: "false",
|
||||
SettingKeySMTPPort: "587",
|
||||
SettingKeySMTPUseTLS: "false",
|
||||
// Model fallback defaults
|
||||
SettingKeyEnableModelFallback: "false",
|
||||
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
|
||||
SettingKeyFallbackModelOpenAI: "gpt-4o",
|
||||
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
|
||||
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
|
||||
}
|
||||
|
||||
return s.settingRepo.SetMultiple(ctx, defaults)
|
||||
@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
result := &SystemSettings{
|
||||
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
|
||||
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
|
||||
SmtpHost: settings[SettingKeySmtpHost],
|
||||
SmtpUsername: settings[SettingKeySmtpUsername],
|
||||
SmtpFrom: settings[SettingKeySmtpFrom],
|
||||
SmtpFromName: settings[SettingKeySmtpFromName],
|
||||
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
|
||||
SMTPHost: settings[SettingKeySMTPHost],
|
||||
SMTPUsername: settings[SettingKeySMTPUsername],
|
||||
SMTPFrom: settings[SettingKeySMTPFrom],
|
||||
SMTPFromName: settings[SettingKeySMTPFromName],
|
||||
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
|
||||
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
|
||||
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
|
||||
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
|
||||
SiteLogo: settings[SettingKeySiteLogo],
|
||||
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
|
||||
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
|
||||
APIBaseURL: settings[SettingKeyAPIBaseURL],
|
||||
ContactInfo: settings[SettingKeyContactInfo],
|
||||
DocUrl: settings[SettingKeyDocUrl],
|
||||
DocURL: settings[SettingKeyDocURL],
|
||||
}
|
||||
|
||||
// 解析整数类型
|
||||
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
|
||||
result.SmtpPort = port
|
||||
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
|
||||
result.SMTPPort = port
|
||||
} else {
|
||||
result.SmtpPort = 587
|
||||
result.SMTPPort = 587
|
||||
}
|
||||
|
||||
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
|
||||
@@ -246,9 +259,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
|
||||
}
|
||||
|
||||
// 敏感信息直接返回,方便测试连接时使用
|
||||
result.SmtpPassword = settings[SettingKeySmtpPassword]
|
||||
result.SMTPPassword = settings[SettingKeySMTPPassword]
|
||||
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
|
||||
|
||||
// Model fallback settings
|
||||
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
|
||||
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
|
||||
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
|
||||
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
|
||||
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
|
||||
return value
|
||||
}
|
||||
|
||||
// GenerateAdminApiKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
|
||||
// GenerateAdminAPIKey 生成新的管理员 API Key
|
||||
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
|
||||
// 生成 32 字节随机数 = 64 位十六进制字符
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
|
||||
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
|
||||
|
||||
// 存储到 settings 表
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
|
||||
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
|
||||
return "", fmt.Errorf("save admin api key: %w", err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKeyStatus 获取管理员 API Key 状态
|
||||
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
|
||||
// 返回脱敏的 key、是否存在、错误
|
||||
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", false, nil
|
||||
@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
|
||||
return maskedKey, true, nil
|
||||
}
|
||||
|
||||
// GetAdminApiKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// GetAdminAPIKey 获取完整的管理员 API Key(仅供内部验证使用)
|
||||
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
|
||||
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
|
||||
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
|
||||
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrSettingNotFound) {
|
||||
return "", nil // 未配置,返回空字符串
|
||||
@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// DeleteAdminApiKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
|
||||
// DeleteAdminAPIKey 删除管理员 API Key
|
||||
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
|
||||
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
|
||||
}
|
||||
|
||||
// IsModelFallbackEnabled 检查是否启用模型兜底机制
|
||||
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
|
||||
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
|
||||
if err != nil {
|
||||
return false // Default: disabled
|
||||
}
|
||||
return value == "true"
|
||||
}
|
||||
|
||||
// GetFallbackModel 获取指定平台的兜底模型
|
||||
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
|
||||
var key string
|
||||
var defaultModel string
|
||||
|
||||
switch platform {
|
||||
case PlatformAnthropic:
|
||||
key = SettingKeyFallbackModelAnthropic
|
||||
defaultModel = "claude-3-5-sonnet-20241022"
|
||||
case PlatformOpenAI:
|
||||
key = SettingKeyFallbackModelOpenAI
|
||||
defaultModel = "gpt-4o"
|
||||
case PlatformGemini:
|
||||
key = SettingKeyFallbackModelGemini
|
||||
defaultModel = "gemini-2.5-pro"
|
||||
case PlatformAntigravity:
|
||||
key = SettingKeyFallbackModelAntigravity
|
||||
defaultModel = "gemini-2.5-pro"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
value, err := s.settingRepo.GetValue(ctx, key)
|
||||
if err != nil || value == "" {
|
||||
return defaultModel
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
@@ -4,13 +4,13 @@ type SystemSettings struct {
|
||||
RegistrationEnabled bool
|
||||
EmailVerifyEnabled bool
|
||||
|
||||
SmtpHost string
|
||||
SmtpPort int
|
||||
SmtpUsername string
|
||||
SmtpPassword string
|
||||
SmtpFrom string
|
||||
SmtpFromName string
|
||||
SmtpUseTLS bool
|
||||
SMTPHost string
|
||||
SMTPPort int
|
||||
SMTPUsername string
|
||||
SMTPPassword string
|
||||
SMTPFrom string
|
||||
SMTPFromName string
|
||||
SMTPUseTLS bool
|
||||
|
||||
TurnstileEnabled bool
|
||||
TurnstileSiteKey string
|
||||
@@ -19,12 +19,19 @@ type SystemSettings struct {
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
ApiBaseUrl string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocUrl string
|
||||
DocURL string
|
||||
|
||||
DefaultConcurrency int
|
||||
DefaultBalance float64
|
||||
|
||||
// Model fallback configuration
|
||||
EnableModelFallback bool `json:"enable_model_fallback"`
|
||||
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
|
||||
FallbackModelOpenAI string `json:"fallback_model_openai"`
|
||||
FallbackModelGemini string `json:"fallback_model_gemini"`
|
||||
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
|
||||
}
|
||||
|
||||
type PublicSettings struct {
|
||||
@@ -35,8 +42,8 @@ type PublicSettings struct {
|
||||
SiteName string
|
||||
SiteLogo string
|
||||
SiteSubtitle string
|
||||
ApiBaseUrl string
|
||||
APIBaseURL string
|
||||
ContactInfo string
|
||||
DocUrl string
|
||||
DocURL string
|
||||
Version string
|
||||
}
|
||||
|
||||
22
backend/internal/service/temp_unsched.go
Normal file
22
backend/internal/service/temp_unsched.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// TempUnschedState 临时不可调度状态
|
||||
type TempUnschedState struct {
|
||||
UntilUnix int64 `json:"until_unix"` // 解除时间(Unix 时间戳)
|
||||
TriggeredAtUnix int64 `json:"triggered_at_unix"` // 触发时间(Unix 时间戳)
|
||||
StatusCode int `json:"status_code"` // 触发的错误码
|
||||
MatchedKeyword string `json:"matched_keyword"` // 匹配的关键词
|
||||
RuleIndex int `json:"rule_index"` // 触发的规则索引
|
||||
ErrorMessage string `json:"error_message"` // 错误消息
|
||||
}
|
||||
|
||||
// TempUnschedCache 临时不可调度缓存接口
|
||||
type TempUnschedCache interface {
|
||||
SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error
|
||||
GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
|
||||
DeleteTempUnsched(ctx context.Context, accountID int64) error
|
||||
}
|
||||
@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
|
||||
{
|
||||
name: "anthropic api-key - cannot refresh",
|
||||
platform: PlatformAnthropic,
|
||||
accType: AccountTypeApiKey,
|
||||
accType: AccountTypeAPIKey,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
|
||||
@@ -79,7 +79,7 @@ type ReleaseInfo struct {
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
HtmlURL string `json:"html_url"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
Assets []Asset `json:"assets,omitempty"`
|
||||
}
|
||||
|
||||
@@ -96,13 +96,13 @@ type GitHubRelease struct {
|
||||
Name string `json:"name"`
|
||||
Body string `json:"body"`
|
||||
PublishedAt string `json:"published_at"`
|
||||
HtmlUrl string `json:"html_url"`
|
||||
HTMLURL string `json:"html_url"`
|
||||
Assets []GitHubAsset `json:"assets"`
|
||||
}
|
||||
|
||||
type GitHubAsset struct {
|
||||
Name string `json:"name"`
|
||||
BrowserDownloadUrl string `json:"browser_download_url"`
|
||||
BrowserDownloadURL string `json:"browser_download_url"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
||||
for i, a := range release.Assets {
|
||||
assets[i] = Asset{
|
||||
Name: a.Name,
|
||||
DownloadURL: a.BrowserDownloadUrl,
|
||||
DownloadURL: a.BrowserDownloadURL,
|
||||
Size: a.Size,
|
||||
}
|
||||
}
|
||||
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
|
||||
Name: release.Name,
|
||||
Body: release.Body,
|
||||
PublishedAt: release.PublishedAt,
|
||||
HtmlURL: release.HtmlUrl,
|
||||
HTMLURL: release.HTMLURL,
|
||||
Assets: assets,
|
||||
},
|
||||
Cached: false,
|
||||
|
||||
@@ -10,7 +10,7 @@ const (
|
||||
type UsageLog struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
ApiKeyID int64
|
||||
APIKeyID int64
|
||||
AccountID int64
|
||||
RequestID string
|
||||
Model string
|
||||
@@ -42,7 +42,7 @@ type UsageLog struct {
|
||||
CreatedAt time.Time
|
||||
|
||||
User *User
|
||||
ApiKey *ApiKey
|
||||
APIKey *APIKey
|
||||
Account *Account
|
||||
Group *Group
|
||||
Subscription *UserSubscription
|
||||
|
||||
@@ -2,9 +2,11 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
dbent "github.com/Wei-Shaw/sub2api/ent"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
@@ -17,7 +19,7 @@ var (
|
||||
// CreateUsageLogRequest 创建使用日志请求
|
||||
type CreateUsageLogRequest struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
ApiKeyID int64 `json:"api_key_id"`
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
AccountID int64 `json:"account_id"`
|
||||
RequestID string `json:"request_id"`
|
||||
Model string `json:"model"`
|
||||
@@ -54,20 +56,34 @@ type UsageStats struct {
|
||||
type UsageService struct {
|
||||
usageRepo UsageLogRepository
|
||||
userRepo UserRepository
|
||||
entClient *dbent.Client
|
||||
}
|
||||
|
||||
// NewUsageService 创建使用统计服务实例
|
||||
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService {
|
||||
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
|
||||
return &UsageService{
|
||||
usageRepo: usageRepo,
|
||||
userRepo: userRepo,
|
||||
entClient: entClient,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建使用日志
|
||||
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
|
||||
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
|
||||
tx, err := s.entClient.Tx(ctx)
|
||||
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
|
||||
return nil, fmt.Errorf("begin transaction: %w", err)
|
||||
}
|
||||
|
||||
txCtx := ctx
|
||||
if err == nil {
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
txCtx = dbent.NewTxContext(ctx, tx)
|
||||
}
|
||||
|
||||
// 验证用户存在
|
||||
_, err := s.userRepo.GetByID(ctx, req.UserID)
|
||||
_, err = s.userRepo.GetByID(txCtx, req.UserID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
||||
// 创建使用日志
|
||||
usageLog := &UsageLog{
|
||||
UserID: req.UserID,
|
||||
ApiKeyID: req.ApiKeyID,
|
||||
APIKeyID: req.APIKeyID,
|
||||
AccountID: req.AccountID,
|
||||
RequestID: req.RequestID,
|
||||
Model: req.Model,
|
||||
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
|
||||
DurationMs: req.DurationMs,
|
||||
}
|
||||
|
||||
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
|
||||
inserted, err := s.usageRepo.Create(txCtx, usageLog)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create usage log: %w", err)
|
||||
}
|
||||
|
||||
// 扣除用户余额
|
||||
if req.ActualCost > 0 {
|
||||
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
|
||||
if inserted && req.ActualCost > 0 {
|
||||
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
|
||||
return nil, fmt.Errorf("update user balance: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if tx != nil {
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit transaction: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return usageLog, nil
|
||||
}
|
||||
|
||||
@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
|
||||
return logs, pagination, nil
|
||||
}
|
||||
|
||||
// ListByApiKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
|
||||
// ListByAPIKey 获取API Key的使用日志列表
|
||||
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
|
||||
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list usage logs: %w", err)
|
||||
}
|
||||
@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetStatsByApiKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||
// GetStatsByAPIKey 获取API Key的使用统计
|
||||
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
|
||||
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key stats: %w", err)
|
||||
}
|
||||
@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
|
||||
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
|
||||
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
|
||||
@@ -21,7 +21,7 @@ type User struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
ApiKeys []ApiKey
|
||||
APIKeys []APIKey
|
||||
Subscriptions []UserSubscription
|
||||
}
|
||||
|
||||
|
||||
@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
|
||||
Enabled: input.Enabled,
|
||||
}
|
||||
|
||||
if err := validateDefinitionPattern(def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.defRepo.Create(ctx, def); err != nil {
|
||||
return nil, fmt.Errorf("create definition: %w", err)
|
||||
}
|
||||
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
|
||||
def.Enabled = *input.Enabled
|
||||
}
|
||||
|
||||
if err := validateDefinitionPattern(def); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.defRepo.Update(ctx, def); err != nil {
|
||||
return nil, fmt.Errorf("update definition: %w", err)
|
||||
}
|
||||
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
|
||||
// Pattern validation
|
||||
if v.Pattern != nil && *v.Pattern != "" && value != "" {
|
||||
re, err := regexp.Compile(*v.Pattern)
|
||||
if err == nil && !re.MatchString(value) {
|
||||
if err != nil {
|
||||
return validationError(def.Name + " has an invalid pattern")
|
||||
}
|
||||
if !re.MatchString(value) {
|
||||
msg := def.Name + " format is invalid"
|
||||
if v.Message != nil && *v.Message != "" {
|
||||
msg = *v.Message
|
||||
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func validateDefinitionPattern(def *UserAttributeDefinition) error {
|
||||
if def == nil {
|
||||
return nil
|
||||
}
|
||||
if def.Validation.Pattern == nil {
|
||||
return nil
|
||||
}
|
||||
pattern := strings.TrimSpace(*def.Validation.Pattern)
|
||||
if pattern == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := regexp.Compile(pattern); err != nil {
|
||||
return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", fmt.Sprintf("invalid pattern for %s: %v", def.Name, err))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
|
||||
// Core services
|
||||
NewAuthService,
|
||||
NewUserService,
|
||||
NewApiKeyService,
|
||||
NewAPIKeyService,
|
||||
NewGroupService,
|
||||
NewAccountService,
|
||||
NewProxyService,
|
||||
|
||||
Reference in New Issue
Block a user