Merge PR #142: feat: 多平台网关优化与用户系统增强

主要功能:
- Gemini OAuth 配额系统优化:Google One tier 自动推断
- Antigravity 网关增强:Thinking Block 重试、Claude 模型 signature 透传
- 账号调度改进:临时不可调度功能、负载感知调度优化
- 前端用户体验:账号管理界面优化、使用教程改进

冲突解决:保留 handleUpstreamError 的 prefix 参数和日志记录能力,
同时合并 PR 的 thinking block 重试和 model fallback 功能。
This commit is contained in:
shaw
2026-01-04 20:30:19 +08:00
183 changed files with 8223 additions and 3839 deletions

View File

@@ -1,3 +1,4 @@
// Package service provides business logic and domain services for the application.
package service
import (
@@ -29,6 +30,9 @@ type Account struct {
RateLimitResetAt *time.Time
OverloadUntil *time.Time
TempUnschedulableUntil *time.Time
TempUnschedulableReason string
SessionWindowStart *time.Time
SessionWindowEnd *time.Time
SessionWindowStatus string
@@ -39,6 +43,13 @@ type Account struct {
Groups []*Group
}
type TempUnschedulableRule struct {
ErrorCode int `json:"error_code"`
Keywords []string `json:"keywords"`
DurationMinutes int `json:"duration_minutes"`
Description string `json:"description"`
}
func (a *Account) IsActive() bool {
return a.Status == StatusActive
}
@@ -54,6 +65,9 @@ func (a *Account) IsSchedulable() bool {
if a.RateLimitResetAt != nil && now.Before(*a.RateLimitResetAt) {
return false
}
if a.TempUnschedulableUntil != nil && now.Before(*a.TempUnschedulableUntil) {
return false
}
return true
}
@@ -92,10 +106,7 @@ func (a *Account) GeminiOAuthType() string {
func (a *Account) GeminiTierID() string {
tierID := strings.TrimSpace(a.GetCredential("tier_id"))
if tierID == "" {
return ""
}
return strings.ToUpper(tierID)
return tierID
}
func (a *Account) IsGeminiCodeAssist() bool {
@@ -163,6 +174,114 @@ func (a *Account) GetCredentialAsTime(key string) *time.Time {
return nil
}
func (a *Account) IsTempUnschedulableEnabled() bool {
if a.Credentials == nil {
return false
}
raw, ok := a.Credentials["temp_unschedulable_enabled"]
if !ok || raw == nil {
return false
}
enabled, ok := raw.(bool)
return ok && enabled
}
func (a *Account) GetTempUnschedulableRules() []TempUnschedulableRule {
if a.Credentials == nil {
return nil
}
raw, ok := a.Credentials["temp_unschedulable_rules"]
if !ok || raw == nil {
return nil
}
arr, ok := raw.([]any)
if !ok {
return nil
}
rules := make([]TempUnschedulableRule, 0, len(arr))
for _, item := range arr {
entry, ok := item.(map[string]any)
if !ok || entry == nil {
continue
}
rule := TempUnschedulableRule{
ErrorCode: parseTempUnschedInt(entry["error_code"]),
Keywords: parseTempUnschedStrings(entry["keywords"]),
DurationMinutes: parseTempUnschedInt(entry["duration_minutes"]),
Description: parseTempUnschedString(entry["description"]),
}
if rule.ErrorCode <= 0 || rule.DurationMinutes <= 0 || len(rule.Keywords) == 0 {
continue
}
rules = append(rules, rule)
}
return rules
}
func parseTempUnschedString(value any) string {
s, ok := value.(string)
if !ok {
return ""
}
return strings.TrimSpace(s)
}
func parseTempUnschedStrings(value any) []string {
if value == nil {
return nil
}
var raw []string
switch v := value.(type) {
case []string:
raw = v
case []any:
raw = make([]string, 0, len(v))
for _, item := range v {
if s, ok := item.(string); ok {
raw = append(raw, s)
}
}
default:
return nil
}
out := make([]string, 0, len(raw))
for _, item := range raw {
s := strings.TrimSpace(item)
if s != "" {
out = append(out, s)
}
}
return out
}
func parseTempUnschedInt(value any) int {
switch v := value.(type) {
case int:
return v
case int64:
return int(v)
case float64:
return int(v)
case json.Number:
if i, err := v.Int64(); err == nil {
return int(i)
}
case string:
if i, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
return i
}
}
return 0
}
func (a *Account) GetModelMapping() map[string]string {
if a.Credentials == nil {
return nil
@@ -206,7 +325,7 @@ func (a *Account) GetMappedModel(requestedModel string) string {
}
func (a *Account) GetBaseURL() string {
if a.Type != AccountTypeApiKey {
if a.Type != AccountTypeAPIKey {
return ""
}
baseURL := a.GetCredential("base_url")
@@ -229,7 +348,7 @@ func (a *Account) GetExtraString(key string) string {
}
func (a *Account) IsCustomErrorCodesEnabled() bool {
if a.Type != AccountTypeApiKey || a.Credentials == nil {
if a.Type != AccountTypeAPIKey || a.Credentials == nil {
return false
}
if v, ok := a.Credentials["custom_error_codes_enabled"]; ok {
@@ -301,14 +420,14 @@ func (a *Account) IsOpenAIOAuth() bool {
}
func (a *Account) IsOpenAIApiKey() bool {
return a.IsOpenAI() && a.Type == AccountTypeApiKey
return a.IsOpenAI() && a.Type == AccountTypeAPIKey
}
func (a *Account) GetOpenAIBaseURL() string {
if !a.IsOpenAI() {
return ""
}
if a.Type == AccountTypeApiKey {
if a.Type == AccountTypeAPIKey {
baseURL := a.GetCredential("base_url")
if baseURL != "" {
return baseURL

View File

@@ -49,6 +49,8 @@ type AccountRepository interface {
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error

View File

@@ -139,6 +139,14 @@ func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until tim
panic("unexpected SetOverloaded call")
}
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}

View File

@@ -369,7 +369,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
}
// For API Key accounts with model mapping, map the model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mapping := account.GetModelMapping()
if len(mapping) > 0 {
if mappedModel, exists := mapping[testModelID]; exists {
@@ -393,7 +393,7 @@ func (s *AccountTestService) testGeminiAccountConnection(c *gin.Context, account
var err error
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
req, err = s.buildGeminiAPIKeyRequest(ctx, account, testModelID, payload)
case AccountTypeOAuth:
req, err = s.buildGeminiOAuthRequest(ctx, account, testModelID, payload)

View File

@@ -12,16 +12,18 @@ import (
)
type UsageLogRepository interface {
Create(ctx context.Context, log *UsageLog) error
// Create creates a usage log and returns whether it was actually inserted.
// inserted is false when the insert was skipped due to conflict (idempotent retries).
Create(ctx context.Context, log *UsageLog) (inserted bool, err error)
GetByID(ctx context.Context, id int64) (*UsageLog, error)
Delete(ctx context.Context, id int64) error
ListByUser(ctx context.Context, userID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccount(ctx context.Context, accountID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error)
ListByUserAndTimeRange(ctx context.Context, userID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByApiKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAPIKeyAndTimeRange(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByAccountAndTimeRange(ctx context.Context, accountID int64, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
ListByModelAndTimeRange(ctx context.Context, modelName string, startTime, endTime time.Time) ([]UsageLog, *pagination.PaginationResult, error)
@@ -32,10 +34,10 @@ type UsageLogRepository interface {
GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error)
GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID int64) ([]usagestats.TrendDataPoint, error)
GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID int64) ([]usagestats.ModelStat, error)
GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error)
GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error)
GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error)
GetBatchUserUsageStats(ctx context.Context, userIDs []int64) (map[int64]*usagestats.BatchUserUsageStats, error)
GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error)
GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error)
// User dashboard stats
GetUserDashboardStats(ctx context.Context, userID int64) (*usagestats.UserDashboardStats, error)
@@ -51,7 +53,7 @@ type UsageLogRepository interface {
// Aggregated stats (optimized)
GetUserStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetApiKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAPIKeyStatsAggregated(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetAccountStatsAggregated(ctx context.Context, accountID int64, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetModelStatsAggregated(ctx context.Context, modelName string, startTime, endTime time.Time) (*usagestats.UsageStats, error)
GetDailyStatsAggregated(ctx context.Context, userID int64, startTime, endTime time.Time) ([]map[string]any, error)
@@ -105,6 +107,8 @@ type UsageProgress struct {
ResetsAt *time.Time `json:"resets_at"` // 重置时间
RemainingSeconds int `json:"remaining_seconds"` // 距重置剩余秒数
WindowStats *WindowStats `json:"window_stats,omitempty"` // 窗口期统计(从窗口开始到当前的使用量)
UsedRequests int64 `json:"used_requests,omitempty"`
LimitRequests int64 `json:"limit_requests,omitempty"`
}
// AntigravityModelQuota Antigravity 单个模型的配额信息
@@ -115,12 +119,16 @@ type AntigravityModelQuota struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiSharedDaily *UsageProgress `json:"gemini_shared_daily,omitempty"` // Gemini shared pool RPD (Google One / Code Assist)
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
GeminiSharedMinute *UsageProgress `json:"gemini_shared_minute,omitempty"` // Gemini shared pool RPM (Google One / Code Assist)
GeminiProMinute *UsageProgress `json:"gemini_pro_minute,omitempty"` // Gemini Pro RPM
GeminiFlashMinute *UsageProgress `json:"gemini_flash_minute,omitempty"` // Gemini Flash RPM
// Antigravity 多模型配额
AntigravityQuota map[string]*AntigravityModelQuota `json:"antigravity_quota,omitempty"`
@@ -256,17 +264,44 @@ func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Accou
return usage, nil
}
start := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
dayStart := geminiDailyWindowStart(now)
stats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, dayStart, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
}
totals := geminiAggregateUsage(stats)
resetAt := geminiDailyResetTime(now)
dayTotals := geminiAggregateUsage(stats)
dailyResetAt := geminiDailyResetTime(now)
usage.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
// Daily window (RPD)
if quota.SharedRPD > 0 {
totalReq := dayTotals.ProRequests + dayTotals.FlashRequests
totalTokens := dayTotals.ProTokens + dayTotals.FlashTokens
totalCost := dayTotals.ProCost + dayTotals.FlashCost
usage.GeminiSharedDaily = buildGeminiUsageProgress(totalReq, quota.SharedRPD, dailyResetAt, totalTokens, totalCost, now)
} else {
usage.GeminiProDaily = buildGeminiUsageProgress(dayTotals.ProRequests, quota.ProRPD, dailyResetAt, dayTotals.ProTokens, dayTotals.ProCost, now)
usage.GeminiFlashDaily = buildGeminiUsageProgress(dayTotals.FlashRequests, quota.FlashRPD, dailyResetAt, dayTotals.FlashTokens, dayTotals.FlashCost, now)
}
// Minute window (RPM) - fixed-window approximation: current minute [truncate(now), truncate(now)+1m)
minuteStart := now.Truncate(time.Minute)
minuteResetAt := minuteStart.Add(time.Minute)
minuteStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, minuteStart, now, 0, 0, account.ID)
if err != nil {
return nil, fmt.Errorf("get gemini minute usage stats failed: %w", err)
}
minuteTotals := geminiAggregateUsage(minuteStats)
if quota.SharedRPM > 0 {
totalReq := minuteTotals.ProRequests + minuteTotals.FlashRequests
totalTokens := minuteTotals.ProTokens + minuteTotals.FlashTokens
totalCost := minuteTotals.ProCost + minuteTotals.FlashCost
usage.GeminiSharedMinute = buildGeminiUsageProgress(totalReq, quota.SharedRPM, minuteResetAt, totalTokens, totalCost, now)
} else {
usage.GeminiProMinute = buildGeminiUsageProgress(minuteTotals.ProRequests, quota.ProRPM, minuteResetAt, minuteTotals.ProTokens, minuteTotals.ProCost, now)
usage.GeminiFlashMinute = buildGeminiUsageProgress(minuteTotals.FlashRequests, quota.FlashRPM, minuteResetAt, minuteTotals.FlashTokens, minuteTotals.FlashCost, now)
}
return usage, nil
}
@@ -506,6 +541,7 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
}
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
// limit <= 0 means "no local quota window" (unknown or unlimited).
if limit <= 0 {
return nil
}
@@ -519,6 +555,8 @@ func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64
Utilization: utilization,
ResetsAt: &resetCopy,
RemainingSeconds: remainingSeconds,
UsedRequests: used,
LimitRequests: limit,
WindowStats: &WindowStats{
Requests: used,
Tokens: tokens,

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"log"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
@@ -19,7 +20,7 @@ type AdminService interface {
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
DeleteUser(ctx context.Context, id int64) error
UpdateUserBalance(ctx context.Context, userID int64, balance float64, operation string, notes string) (*User, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error)
GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error)
GetUserUsageStats(ctx context.Context, userID int64, period string) (any, error)
// Group management
@@ -30,7 +31,7 @@ type AdminService interface {
CreateGroup(ctx context.Context, input *CreateGroupInput) (*Group, error)
UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error)
DeleteGroup(ctx context.Context, id int64) error
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error)
GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error)
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
@@ -65,7 +66,7 @@ type AdminService interface {
ExpireRedeemCode(ctx context.Context, id int64) (*RedeemCode, error)
}
// Input types for admin operations
// CreateUserInput represents input for creating a new user via admin operations.
type CreateUserInput struct {
Email string
Password string
@@ -122,18 +123,22 @@ type CreateAccountInput struct {
Concurrency int
Priority int
GroupIDs []int64
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
type UpdateAccountInput struct {
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
Name string
Type string // Account type: oauth, setup-token, apikey
Credentials map[string]any
Extra map[string]any
ProxyID *int64
Concurrency *int // 使用指针区分"未提供"和"设置为0"
Priority *int // 使用指针区分"未提供"和"设置为0"
Status string
GroupIDs *[]int64
SkipMixedChannelCheck bool // 跳过混合渠道检查(用户已确认风险)
}
// BulkUpdateAccountsInput describes the payload for bulk updating accounts.
@@ -147,6 +152,9 @@ type BulkUpdateAccountsInput struct {
GroupIDs *[]int64
Credentials map[string]any
Extra map[string]any
// SkipMixedChannelCheck skips the mixed channel risk check when binding groups.
// This should only be set when the caller has explicitly confirmed the risk.
SkipMixedChannelCheck bool
}
// BulkUpdateAccountResult captures the result for a single account update.
@@ -220,7 +228,7 @@ type adminServiceImpl struct {
groupRepo GroupRepository
accountRepo AccountRepository
proxyRepo ProxyRepository
apiKeyRepo ApiKeyRepository
apiKeyRepo APIKeyRepository
redeemCodeRepo RedeemCodeRepository
billingCacheService *BillingCacheService
proxyProber ProxyExitInfoProber
@@ -232,7 +240,7 @@ func NewAdminService(
groupRepo GroupRepository,
accountRepo AccountRepository,
proxyRepo ProxyRepository,
apiKeyRepo ApiKeyRepository,
apiKeyRepo APIKeyRepository,
redeemCodeRepo RedeemCodeRepository,
billingCacheService *BillingCacheService,
proxyProber ProxyExitInfoProber,
@@ -430,7 +438,7 @@ func (s *adminServiceImpl) UpdateUserBalance(ctx context.Context, userID int64,
return user, nil
}
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetUserAPIKeys(ctx context.Context, userID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
@@ -583,7 +591,7 @@ func (s *adminServiceImpl) DeleteGroup(ctx context.Context, id int64) error {
return nil
}
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]ApiKey, int64, error) {
func (s *adminServiceImpl) GetGroupAPIKeys(ctx context.Context, groupID int64, page, pageSize int) ([]APIKey, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
keys, result, err := s.apiKeyRepo.ListByGroupID(ctx, groupID, params)
if err != nil {
@@ -620,6 +628,29 @@ func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
break
}
}
}
}
// 检查混合渠道风险(除非用户已确认)
if len(groupIDs) > 0 && !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, 0, input.Platform, groupIDs); err != nil {
return nil, err
}
}
account := &Account{
Name: input.Name,
Platform: input.Platform,
@@ -637,22 +668,6 @@ func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccou
}
// 绑定分组
groupIDs := input.GroupIDs
// 如果没有指定分组,自动绑定对应平台的默认分组
if len(groupIDs) == 0 {
defaultGroupName := input.Platform + "-default"
groups, err := s.groupRepo.ListActiveByPlatform(ctx, input.Platform)
if err == nil {
for _, g := range groups {
if g.Name == defaultGroupName {
groupIDs = []int64{g.ID}
log.Printf("[CreateAccount] Auto-binding account %d to default group %s (ID: %d)", account.ID, defaultGroupName, g.ID)
break
}
}
}
}
if len(groupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, groupIDs); err != nil {
return nil, err
@@ -703,6 +718,13 @@ func (s *adminServiceImpl) UpdateAccount(ctx context.Context, id int64, input *U
return nil, fmt.Errorf("get group: %w", err)
}
}
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
if err := s.checkMixedChannelRisk(ctx, account.ID, account.Platform, *input.GroupIDs); err != nil {
return nil, err
}
}
}
if err := s.accountRepo.Update(ctx, account); err != nil {
@@ -731,6 +753,20 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
return result, nil
}
// Preload account platforms for mixed channel risk checks if group bindings are requested.
platformByID := map[int64]string{}
if input.GroupIDs != nil && !input.SkipMixedChannelCheck {
accounts, err := s.accountRepo.GetByIDs(ctx, input.AccountIDs)
if err != nil {
return nil, err
}
for _, account := range accounts {
if account != nil {
platformByID[account.ID] = account.Platform
}
}
}
// Prepare bulk updates for columns and JSONB fields.
repoUpdates := AccountBulkUpdate{
Credentials: input.Credentials,
@@ -762,6 +798,29 @@ func (s *adminServiceImpl) BulkUpdateAccounts(ctx context.Context, input *BulkUp
entry := BulkUpdateAccountResult{AccountID: accountID}
if input.GroupIDs != nil {
// 检查混合渠道风险(除非用户已确认)
if !input.SkipMixedChannelCheck {
platform := platformByID[accountID]
if platform == "" {
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
platform = account.Platform
}
if err := s.checkMixedChannelRisk(ctx, accountID, platform, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
result.Failed++
result.Results = append(result.Results, entry)
continue
}
}
if err := s.accountRepo.BindGroups(ctx, accountID, *input.GroupIDs); err != nil {
entry.Success = false
entry.Error = err.Error()
@@ -1006,3 +1065,77 @@ func (s *adminServiceImpl) TestProxy(ctx context.Context, id int64) (*ProxyTestR
Country: exitInfo.Country,
}, nil
}
// checkMixedChannelRisk 检查分组中是否存在混合渠道Antigravity + Anthropic
// 如果存在混合,返回错误提示用户确认
func (s *adminServiceImpl) checkMixedChannelRisk(ctx context.Context, currentAccountID int64, currentAccountPlatform string, groupIDs []int64) error {
// 判断当前账号的渠道类型(基于 platform 字段,而不是 type 字段)
currentPlatform := getAccountPlatform(currentAccountPlatform)
if currentPlatform == "" {
// 不是 Antigravity 或 Anthropic无需检查
return nil
}
// 检查每个分组中的其他账号
for _, groupID := range groupIDs {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return fmt.Errorf("get accounts in group %d: %w", groupID, err)
}
// 检查是否存在不同渠道的账号
for _, account := range accounts {
if currentAccountID > 0 && account.ID == currentAccountID {
continue // 跳过当前账号
}
otherPlatform := getAccountPlatform(account.Platform)
if otherPlatform == "" {
continue // 不是 Antigravity 或 Anthropic跳过
}
// 检测混合渠道
if currentPlatform != otherPlatform {
group, _ := s.groupRepo.GetByID(ctx, groupID)
groupName := fmt.Sprintf("Group %d", groupID)
if group != nil {
groupName = group.Name
}
return &MixedChannelError{
GroupID: groupID,
GroupName: groupName,
CurrentPlatform: currentPlatform,
OtherPlatform: otherPlatform,
}
}
}
}
return nil
}
// getAccountPlatform 根据账号 platform 判断混合渠道检查用的平台标识
func getAccountPlatform(accountPlatform string) string {
switch strings.ToLower(strings.TrimSpace(accountPlatform)) {
case PlatformAntigravity:
return "Antigravity"
case PlatformAnthropic, "claude":
return "Anthropic"
default:
return ""
}
}
// MixedChannelError 混合渠道错误
type MixedChannelError struct {
GroupID int64
GroupName string
CurrentPlatform string
OtherPlatform string
}
func (e *MixedChannelError) Error() string {
return fmt.Sprintf("mixed_channel_warning: Group '%s' contains both %s and %s accounts. Using mixed channels in the same context may cause thinking block signature validation issues, which will fallback to non-thinking mode for historical messages.",
e.GroupName, e.CurrentPlatform, e.OtherPlatform)
}

View File

@@ -81,6 +81,7 @@ type AntigravityGatewayService struct {
tokenProvider *AntigravityTokenProvider
rateLimitService *RateLimitService
httpUpstream HTTPUpstream
settingService *SettingService
}
func NewAntigravityGatewayService(
@@ -89,12 +90,14 @@ func NewAntigravityGatewayService(
tokenProvider *AntigravityTokenProvider,
rateLimitService *RateLimitService,
httpUpstream HTTPUpstream,
settingService *SettingService,
) *AntigravityGatewayService {
return &AntigravityGatewayService{
accountRepo: accountRepo,
tokenProvider: tokenProvider,
rateLimitService: rateLimitService,
httpUpstream: httpUpstream,
settingService: settingService,
}
}
@@ -324,6 +327,22 @@ func (s *AntigravityGatewayService) unwrapV1InternalResponse(body []byte) ([]byt
return body, nil
}
// isModelNotFoundError 检测是否为模型不存在的 404 错误
func isModelNotFoundError(statusCode int, body []byte) bool {
if statusCode != 404 {
return false
}
bodyStr := strings.ToLower(string(body))
keywords := []string{"model not found", "unknown model", "not found"}
for _, keyword := range keywords {
if strings.Contains(bodyStr, keyword) {
return true
}
}
return true // 404 without specific message also treated as model not found
}
// Forward 转发 Claude 协议请求Claude → Gemini 转换)
func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context, account *Account, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -417,16 +436,56 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}
defer func() { _ = resp.Body.Close() }()
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
// 优先检测 thinking block 的 signature 相关错误400并重试一次
// Antigravity /v1internal 链路在部分场景会对 thought/thinking signature 做严格校验,
// 当历史消息携带的 signature 不合法时会直接 400去除 thinking 后可继续完成请求。
if resp.StatusCode == http.StatusBadRequest && isSignatureRelatedError(respBody) {
retryClaudeReq := claudeReq
retryClaudeReq.Messages = append([]antigravity.ClaudeMessage(nil), claudeReq.Messages...)
stripped, stripErr := stripThinkingFromClaudeRequest(&retryClaudeReq)
if stripErr == nil && stripped {
log.Printf("Antigravity account %d: detected signature-related 400, retrying once without thinking blocks", account.ID)
retryGeminiBody, txErr := antigravity.TransformClaudeToGemini(&retryClaudeReq, projectID, mappedModel)
if txErr == nil {
retryReq, buildErr := antigravity.NewAPIRequest(ctx, action, accessToken, retryGeminiBody)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// Retry success: continue normal success flow with the new response.
if retryResp.StatusCode < 400 {
_ = resp.Body.Close()
resp = retryResp
respBody = nil
} else {
// Retry still errored: replace error context with retry response.
retryBody, _ := io.ReadAll(io.LimitReader(retryResp.Body, 2<<20))
_ = retryResp.Body.Close()
respBody = retryBody
resp = retryResp
}
} else {
log.Printf("Antigravity account %d: signature retry request failed: %v", account.ID, retryErr)
}
}
}
}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
// 处理错误响应(重试后仍失败或不触发重试)
if resp.StatusCode >= 400 {
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
return nil, s.writeMappedClaudeError(c, resp.StatusCode, respBody)
}
}
requestID := resp.Header.Get("x-request-id")
@@ -461,6 +520,122 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
}, nil
}
func isSignatureRelatedError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractAntigravityErrorMessage(respBody)))
if msg == "" {
// Fallback: best-effort scan of the raw payload.
msg = strings.ToLower(string(respBody))
}
// Keep this intentionally broad: different upstreams may use "signature" or "thought_signature".
return strings.Contains(msg, "thought_signature") || strings.Contains(msg, "signature")
}
func extractAntigravityErrorMessage(body []byte) string {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return ""
}
// Google-style: {"error": {"message": "..."}}
if errObj, ok := payload["error"].(map[string]any); ok {
if msg, ok := errObj["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
}
// Fallback: top-level message
if msg, ok := payload["message"].(string); ok && strings.TrimSpace(msg) != "" {
return msg
}
return ""
}
// stripThinkingFromClaudeRequest converts thinking blocks to text blocks in a Claude Messages request.
// This preserves the thinking content while avoiding signature validation errors.
// Note: redacted_thinking blocks are removed because they cannot be converted to text.
// It also disables top-level `thinking` to prevent dummy-thought injection during retry.
func stripThinkingFromClaudeRequest(req *antigravity.ClaudeRequest) (bool, error) {
if req == nil {
return false, nil
}
changed := false
if req.Thinking != nil {
req.Thinking = nil
changed = true
}
for i := range req.Messages {
raw := req.Messages[i].Content
if len(raw) == 0 {
continue
}
// If content is a string, nothing to strip.
var str string
if json.Unmarshal(raw, &str) == nil {
continue
}
// Otherwise treat as an array of blocks and convert thinking blocks to text.
var blocks []map[string]any
if err := json.Unmarshal(raw, &blocks); err != nil {
continue
}
filtered := make([]map[string]any, 0, len(blocks))
modifiedAny := false
for _, block := range blocks {
t, _ := block["type"].(string)
switch t {
case "thinking":
// Convert thinking to text, skip if empty
thinkingText, _ := block["thinking"].(string)
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
case "redacted_thinking":
// Remove redacted_thinking (cannot convert encrypted content)
modifiedAny = true
case "":
// Handle untyped block with "thinking" field
if thinkingText, hasThinking := block["thinking"].(string); hasThinking {
if thinkingText != "" {
filtered = append(filtered, map[string]any{
"type": "text",
"text": thinkingText,
})
}
modifiedAny = true
} else {
filtered = append(filtered, block)
}
default:
filtered = append(filtered, block)
}
}
if !modifiedAny {
continue
}
newRaw, err := json.Marshal(filtered)
if err != nil {
return changed, err
}
req.Messages[i].Content = newRaw
changed = true
}
return changed, nil
}
// ForwardGemini 转发 Gemini 协议请求
func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Context, account *Account, originalModel string, action string, stream bool, body []byte) (*ForwardResult, error) {
startTime := time.Now()
@@ -574,14 +749,40 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
defer func() { _ = resp.Body.Close() }()
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
// 处理错误响应
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
// 模型兜底:模型不存在且开启 fallback 时,自动用 fallback 模型重试一次
if s.settingService != nil && s.settingService.IsModelFallbackEnabled(ctx) &&
isModelNotFoundError(resp.StatusCode, respBody) {
fallbackModel := s.settingService.GetFallbackModel(ctx, PlatformAntigravity)
if fallbackModel != "" && fallbackModel != mappedModel {
log.Printf("[Antigravity] Model not found (%s), retrying with fallback model %s (account: %s)", mappedModel, fallbackModel, account.Name)
// 关闭原始响应释放连接respBody 已读取到内存)
_ = resp.Body.Close()
fallbackWrapped, err := s.wrapV1InternalRequest(projectID, fallbackModel, body)
if err == nil {
fallbackReq, err := antigravity.NewAPIRequest(ctx, upstreamAction, accessToken, fallbackWrapped)
if err == nil {
fallbackResp, err := s.httpUpstream.Do(fallbackReq, proxyURL, account.ID, account.Concurrency)
if err == nil && fallbackResp.StatusCode < 400 {
resp = fallbackResp
} else if fallbackResp != nil {
_ = fallbackResp.Body.Close()
}
}
}
}
}
// fallback 成功:继续按正常响应处理
if resp.StatusCode < 400 {
goto handleSuccess
}
s.handleUpstreamError(ctx, prefix, account, resp.StatusCode, resp.Header, respBody)
if s.shouldFailoverUpstreamError(resp.StatusCode) {
@@ -589,6 +790,10 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
// 解包并返回错误
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
unwrapped, _ := s.unwrapV1InternalResponse(respBody)
contentType := resp.Header.Get("Content-Type")
if contentType == "" {
@@ -598,6 +803,12 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
return nil, fmt.Errorf("antigravity upstream error: %d", resp.StatusCode)
}
handleSuccess:
requestID := resp.Header.Get("x-request-id")
if requestID != "" {
c.Header("x-request-id", requestID)
}
var usage *ClaudeUsage
var firstTokenMs *int

View File

@@ -2,7 +2,7 @@ package service
import "time"
type ApiKey struct {
type APIKey struct {
ID int64
UserID int64
Key string
@@ -15,6 +15,6 @@ type ApiKey struct {
Group *Group
}
func (k *ApiKey) IsActive() bool {
func (k *APIKey) IsActive() bool {
return k.Status == StatusActive
}

View File

@@ -14,39 +14,39 @@ import (
)
var (
ErrApiKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrApiKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrApiKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrApiKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrApiKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
)
const (
apiKeyMaxErrorsPerHour = 20
)
type ApiKeyRepository interface {
Create(ctx context.Context, key *ApiKey) error
GetByID(ctx context.Context, id int64) (*ApiKey, error)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetOwnerID 仅获取 API Key 的所有者 ID用于删除前的轻量级权限验证
GetOwnerID(ctx context.Context, id int64) (int64, error)
GetByKey(ctx context.Context, key string) (*ApiKey, error)
Update(ctx context.Context, key *ApiKey) error
GetByKey(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error)
SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
}
// ApiKeyCache defines cache operations for API key service
type ApiKeyCache interface {
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
@@ -55,40 +55,40 @@ type ApiKeyCache interface {
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
}
// CreateApiKeyRequest 创建API Key请求
type CreateApiKeyRequest struct {
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
}
// UpdateApiKeyRequest 更新API Key请求
type UpdateApiKeyRequest struct {
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
}
// ApiKeyService API Key服务
type ApiKeyService struct {
apiKeyRepo ApiKeyRepository
// APIKeyService API Key服务
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
cache ApiKeyCache
cache APIKeyCache
cfg *config.Config
}
// NewApiKeyService 创建API Key服务实例
func NewApiKeyService(
apiKeyRepo ApiKeyRepository,
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
cache ApiKeyCache,
cache APIKeyCache,
cfg *config.Config,
) *ApiKeyService {
return &ApiKeyService{
) *APIKeyService {
return &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
@@ -99,7 +99,7 @@ func NewApiKeyService(
}
// GenerateKey 生成随机API Key
func (s *ApiKeyService) GenerateKey() (string, error) {
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
@@ -107,7 +107,7 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.ApiKeyPrefix
prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" {
prefix = "sk-"
}
@@ -117,10 +117,10 @@ func (s *ApiKeyService) GenerateKey() (string, error) {
}
// ValidateCustomKey 验证自定义API Key格式
func (s *ApiKeyService) ValidateCustomKey(key string) error {
func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrApiKeyTooShort
return ErrAPIKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
@@ -131,14 +131,14 @@ func (s *ApiKeyService) ValidateCustomKey(key string) error {
c == '_' || c == '-' {
continue
}
return ErrApiKeyInvalidChars
return ErrAPIKeyInvalidChars
}
return nil
}
// checkApiKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64) error {
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
@@ -150,14 +150,14 @@ func (s *ApiKeyService) checkApiKeyRateLimit(ctx context.Context, userID int64)
}
if count >= apiKeyMaxErrorsPerHour {
return ErrApiKeyRateLimited
return ErrAPIKeyRateLimited
}
return nil
}
// incrementApiKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID int64) {
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
@@ -168,7 +168,7 @@ func (s *ApiKeyService) incrementApiKeyErrorCount(ctx context.Context, userID in
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
@@ -179,7 +179,7 @@ func (s *ApiKeyService) canUserBindGroup(ctx context.Context, user *User, group
}
// Create 创建API Key
func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -204,7 +204,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkApiKeyRateLimit(ctx, userID); err != nil {
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
@@ -220,8 +220,8 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
if exists {
// Key已存在增加错误计数
s.incrementApiKeyErrorCount(ctx, userID)
return nil, ErrApiKeyExists
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
}
key = *req.CustomKey
@@ -235,7 +235,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// 创建API Key记录
apiKey := &ApiKey{
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
@@ -251,7 +251,7 @@ func (s *ApiKeyService) Create(ctx context.Context, userID int64, req CreateApiK
}
// List 获取用户的API Key列表
func (s *ApiKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
@@ -259,7 +259,7 @@ func (s *ApiKeyService) List(ctx context.Context, userID int64, params paginatio
return keys, pagination, nil
}
func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
@@ -272,7 +272,7 @@ func (s *ApiKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKe
}
// GetByID 根据ID获取API Key
func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -281,7 +281,7 @@ func (s *ApiKeyService) GetByID(ctx context.Context, id int64) (*ApiKey, error)
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
// 尝试从Redis缓存获取
cacheKey := fmt.Sprintf("apikey:%s", key)
@@ -301,7 +301,7 @@ func (s *ApiKeyService) GetByKey(ctx context.Context, key string) (*ApiKey, erro
}
// Update 更新API Key
func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateApiKeyRequest) (*ApiKey, error) {
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
@@ -353,8 +353,8 @@ func (s *ApiKeyService) Update(ctx context.Context, id int64, userID int64, req
// Delete 删除API Key
// 优化:使用 GetOwnerID 替代 GetByID 进行权限验证,
// 避免加载完整 ApiKey 对象及其关联数据User、Group提升删除操作的性能
func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 避免加载完整 APIKey 对象及其关联数据User、Group提升删除操作的性能
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
// 仅获取所有者 ID 用于权限验证,而非加载完整对象
ownerID, err := s.apiKeyRepo.GetOwnerID(ctx, id)
if err != nil {
@@ -379,7 +379,7 @@ func (s *ApiKeyService) Delete(ctx context.Context, id int64, userID int64) erro
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *User, error) {
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
@@ -406,7 +406,7 @@ func (s *ApiKeyService) ValidateKey(ctx context.Context, key string) (*ApiKey, *
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
@@ -423,7 +423,7 @@ func (s *ApiKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
@@ -460,7 +460,7 @@ func (s *ApiKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
@@ -469,8 +469,8 @@ func (s *ApiKeyService) canUserBindGroupInternal(user *User, group *Group, subsc
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *ApiKeyService) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
keys, err := s.apiKeyRepo.SearchApiKeys(ctx, userID, keyword, limit)
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}

View File

@@ -1,7 +1,7 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 ApiKeyService.Delete 方法在各种场景下的行为,
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
@@ -16,12 +16,12 @@ import (
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 ApiKeyRepository 接口的测试桩实现。
// 用于隔离测试 ApiKeyService.Delete 方法,避免依赖真实数据库。
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - ownerID: 模拟 GetOwnerID 返回的所有者 ID
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrApiKeyNotFound
// - ownerErr: 模拟 GetOwnerID 返回的错误(如 ErrAPIKeyNotFound
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
@@ -33,11 +33,11 @@ type apiKeyRepoStub struct {
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
@@ -47,11 +47,11 @@ func (s *apiKeyRepoStub) GetOwnerID(ctx context.Context, id int64) (int64, error
return s.ownerID, s.ownerErr
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*ApiKey, error) {
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *ApiKey) error {
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
@@ -64,7 +64,7 @@ func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
@@ -80,12 +80,12 @@ func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, err
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]ApiKey, *pagination.PaginationResult, error) {
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchApiKeys(ctx context.Context, userID int64, keyword string, limit int) ([]ApiKey, error) {
panic("unexpected SearchApiKeys call")
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
@@ -96,7 +96,7 @@ func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int
panic("unexpected CountByGroupID call")
}
// apiKeyCacheStub 是 ApiKeyCache 接口的测试桩实现。
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
@@ -142,7 +142,7 @@ func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 1}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
@@ -160,7 +160,7 @@ func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{ownerID: 7}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
@@ -170,17 +170,17 @@ func TestApiKeyService_Delete_Success(t *testing.T) {
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetOwnerID 返回 ErrApiKeyNotFound 错误
// - 返回 ErrApiKeyNotFound 错误(被 fmt.Errorf 包装)
// - GetOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{ownerErr: ErrApiKeyNotFound}
repo := &apiKeyRepoStub{ownerErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrApiKeyNotFound)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
}
@@ -198,7 +198,7 @@ func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &ApiKeyService{apiKeyRepo: repo, cache: cache}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)

View File

@@ -445,7 +445,7 @@ func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *ApiKey, group *Group, subscription *UserSubscription) error {
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil

View File

@@ -430,7 +430,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformAnthropic,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -455,7 +455,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformAnthropic
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -674,7 +674,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformOpenAI,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -699,7 +699,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformOpenAI
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID
@@ -893,7 +893,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
account := &Account{
Name: defaultName(src.Name, src.ID),
Platform: PlatformGemini,
Type: AccountTypeApiKey,
Type: AccountTypeAPIKey,
Credentials: credentials,
Extra: extra,
ProxyID: proxyID,
@@ -918,7 +918,7 @@ func (s *CRSSyncService) SyncFromCRS(ctx context.Context, input SyncFromCRSInput
existing.Extra = mergeMap(existing.Extra, extra)
existing.Name = defaultName(src.Name, src.ID)
existing.Platform = PlatformGemini
existing.Type = AccountTypeApiKey
existing.Type = AccountTypeAPIKey
existing.Credentials = mergeMap(existing.Credentials, credentials)
if proxyID != nil {
existing.ProxyID = proxyID

View File

@@ -43,8 +43,8 @@ func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTi
return stats, nil
}
func (s *DashboardService) GetApiKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.ApiKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetApiKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
@@ -67,8 +67,8 @@ func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs [
return stats, nil
}
func (s *DashboardService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -28,7 +28,7 @@ const (
const (
AccountTypeOAuth = "oauth" // OAuth类型账号full scope: profile + inference
AccountTypeSetupToken = "setup-token" // Setup Token类型账号inference only scope
AccountTypeApiKey = "apikey" // API Key类型账号
AccountTypeAPIKey = "apikey" // API Key类型账号
)
// Redeem type constants
@@ -64,13 +64,13 @@ const (
SettingKeyEmailVerifyEnabled = "email_verify_enabled" // 是否开启邮件验证
// 邮件服务设置
SettingKeySmtpHost = "smtp_host" // SMTP服务器地址
SettingKeySmtpPort = "smtp_port" // SMTP端口
SettingKeySmtpUsername = "smtp_username" // SMTP用户名
SettingKeySmtpPassword = "smtp_password" // SMTP密码加密存储
SettingKeySmtpFrom = "smtp_from" // 发件人地址
SettingKeySmtpFromName = "smtp_from_name" // 发件人名称
SettingKeySmtpUseTLS = "smtp_use_tls" // 是否使用TLS
SettingKeySMTPHost = "smtp_host" // SMTP服务器地址
SettingKeySMTPPort = "smtp_port" // SMTP端口
SettingKeySMTPUsername = "smtp_username" // SMTP用户名
SettingKeySMTPPassword = "smtp_password" // SMTP密码加密存储
SettingKeySMTPFrom = "smtp_from" // 发件人地址
SettingKeySMTPFromName = "smtp_from_name" // 发件人名称
SettingKeySMTPUseTLS = "smtp_use_tls" // 是否使用TLS
// Cloudflare Turnstile 设置
SettingKeyTurnstileEnabled = "turnstile_enabled" // 是否启用 Turnstile 验证
@@ -81,20 +81,27 @@ const (
SettingKeySiteName = "site_name" // 网站名称
SettingKeySiteLogo = "site_logo" // 网站Logo (base64)
SettingKeySiteSubtitle = "site_subtitle" // 网站副标题
SettingKeyApiBaseUrl = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyAPIBaseURL = "api_base_url" // API端点地址用于客户端配置和导入
SettingKeyContactInfo = "contact_info" // 客服联系方式
SettingKeyDocUrl = "doc_url" // 文档链接
SettingKeyDocURL = "doc_url" // 文档链接
// 默认配置
SettingKeyDefaultConcurrency = "default_concurrency" // 新用户默认并发量
SettingKeyDefaultBalance = "default_balance" // 新用户默认余额
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
SettingKeyAdminAPIKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
// Model fallback settings
SettingKeyEnableModelFallback = "enable_model_fallback"
SettingKeyFallbackModelAnthropic = "fallback_model_anthropic"
SettingKeyFallbackModelOpenAI = "fallback_model_openai"
SettingKeyFallbackModelGemini = "fallback_model_gemini"
SettingKeyFallbackModelAntigravity = "fallback_model_antigravity"
)
// Admin API Key prefix (distinct from user "sk-" keys)
const AdminApiKeyPrefix = "admin-"
// AdminAPIKeyPrefix is the prefix for admin API keys (distinct from user "sk-" keys).
const AdminAPIKeyPrefix = "admin-"

View File

@@ -40,8 +40,8 @@ const (
maxVerifyCodeAttempts = 5
)
// SmtpConfig SMTP配置
type SmtpConfig struct {
// SMTPConfig SMTP配置
type SMTPConfig struct {
Host string
Port int
Username string
@@ -65,16 +65,16 @@ func NewEmailService(settingRepo SettingRepository, cache EmailCache) *EmailServ
}
}
// GetSmtpConfig 从数据库获取SMTP配置
func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
// GetSMTPConfig 从数据库获取SMTP配置
func (s *EmailService) GetSMTPConfig(ctx context.Context) (*SMTPConfig, error) {
keys := []string{
SettingKeySmtpHost,
SettingKeySmtpPort,
SettingKeySmtpUsername,
SettingKeySmtpPassword,
SettingKeySmtpFrom,
SettingKeySmtpFromName,
SettingKeySmtpUseTLS,
SettingKeySMTPHost,
SettingKeySMTPPort,
SettingKeySMTPUsername,
SettingKeySMTPPassword,
SettingKeySMTPFrom,
SettingKeySMTPFromName,
SettingKeySMTPUseTLS,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -82,34 +82,34 @@ func (s *EmailService) GetSmtpConfig(ctx context.Context) (*SmtpConfig, error) {
return nil, fmt.Errorf("get smtp settings: %w", err)
}
host := settings[SettingKeySmtpHost]
host := settings[SettingKeySMTPHost]
if host == "" {
return nil, ErrEmailNotConfigured
}
port := 587 // 默认端口
if portStr := settings[SettingKeySmtpPort]; portStr != "" {
if portStr := settings[SettingKeySMTPPort]; portStr != "" {
if p, err := strconv.Atoi(portStr); err == nil {
port = p
}
}
useTLS := settings[SettingKeySmtpUseTLS] == "true"
useTLS := settings[SettingKeySMTPUseTLS] == "true"
return &SmtpConfig{
return &SMTPConfig{
Host: host,
Port: port,
Username: settings[SettingKeySmtpUsername],
Password: settings[SettingKeySmtpPassword],
From: settings[SettingKeySmtpFrom],
FromName: settings[SettingKeySmtpFromName],
Username: settings[SettingKeySMTPUsername],
Password: settings[SettingKeySMTPPassword],
From: settings[SettingKeySMTPFrom],
FromName: settings[SettingKeySMTPFromName],
UseTLS: useTLS,
}, nil
}
// SendEmail 发送邮件(使用数据库中保存的配置)
func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string) error {
config, err := s.GetSmtpConfig(ctx)
config, err := s.GetSMTPConfig(ctx)
if err != nil {
return err
}
@@ -117,7 +117,7 @@ func (s *EmailService) SendEmail(ctx context.Context, to, subject, body string)
}
// SendEmailWithConfig 使用指定配置发送邮件
func (s *EmailService) SendEmailWithConfig(config *SmtpConfig, to, subject, body string) error {
func (s *EmailService) SendEmailWithConfig(config *SMTPConfig, to, subject, body string) error {
from := config.From
if config.FromName != "" {
from = fmt.Sprintf("%s <%s>", config.FromName, config.From)
@@ -306,8 +306,8 @@ func (s *EmailService) buildVerifyCodeEmailBody(code, siteName string) string {
`, siteName, code)
}
// TestSmtpConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSmtpConnectionWithConfig(config *SmtpConfig) error {
// TestSMTPConnectionWithConfig 使用指定配置测试SMTP连接
func (s *EmailService) TestSMTPConnectionWithConfig(config *SMTPConfig) error {
addr := fmt.Sprintf("%s:%d", config.Host, config.Port)
if config.UseTLS {

View File

@@ -136,6 +136,12 @@ func (m *mockAccountRepoForPlatform) SetRateLimited(ctx context.Context, id int6
func (m *mockAccountRepoForPlatform) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForPlatform) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForPlatform) ClearRateLimit(ctx context.Context, id int64) error {
return nil
}
@@ -276,7 +282,7 @@ func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
@@ -617,7 +623,7 @@ func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeAPIKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},

View File

@@ -1,6 +1,7 @@
package service
import (
"bytes"
"encoding/json"
"fmt"
)
@@ -70,3 +71,224 @@ func ParseGatewayRequest(body []byte) (*ParsedRequest, error) {
return parsed, nil
}
// FilterThinkingBlocks removes thinking blocks from request body
// Returns filtered body or original body if filtering fails (fail-safe)
// This prevents 400 errors from invalid thinking block signatures
//
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
// (blocks with missing/empty/dummy signatures that would cause 400 errors)
func FilterThinkingBlocks(body []byte) []byte {
return filterThinkingBlocksInternal(body, false)
}
// FilterThinkingBlocksForRetry removes thinking blocks from HISTORICAL messages for retry scenarios.
// This is used when upstream returns signature-related 400 errors.
//
// Key insight:
// - User's thinking.type = "enabled" should be PRESERVED (user's intent)
// - Only HISTORICAL assistant messages have thinking blocks with signatures
// - These signatures may be invalid when switching accounts/platforms
// - New responses will generate fresh thinking blocks without signature issues
//
// Strategy:
// - Keep thinking.type = "enabled" (preserve user intent)
// - Remove thinking/redacted_thinking blocks from historical assistant messages
// - Ensure no message has empty content after filtering
func FilterThinkingBlocksForRetry(body []byte) []byte {
// Fast path: check for presence of thinking-related keys in messages
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
// DO NOT modify thinking.type - preserve user's intent to use thinking mode
// The issue is with historical message signatures, not the thinking mode itself
messages, ok := req["messages"].([]any)
if !ok {
return body
}
modified := false
newMessages := make([]any, 0, len(messages))
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
newMessages = append(newMessages, msg)
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
// String content or other format - keep as is
newMessages = append(newMessages, msg)
continue
}
newContent := make([]any, 0, len(content))
modifiedThisMsg := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
// Remove thinking/redacted_thinking blocks from historical messages
// These have signatures that may be invalid across different accounts
if blockType == "thinking" || blockType == "redacted_thinking" {
modifiedThisMsg = true
continue
}
newContent = append(newContent, block)
}
if modifiedThisMsg {
modified = true
// Handle empty content after filtering
if len(newContent) == 0 {
// For assistant messages, skip entirely (remove from conversation)
// For user messages, add placeholder to avoid empty content error
if role == "user" {
newContent = append(newContent, map[string]any{
"type": "text",
"text": "(content removed)",
})
msgMap["content"] = newContent
newMessages = append(newMessages, msgMap)
}
// Skip assistant messages with empty content (don't append)
continue
}
msgMap["content"] = newContent
}
newMessages = append(newMessages, msgMap)
}
if modified {
req["messages"] = newMessages
}
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}
// filterThinkingBlocksInternal removes invalid thinking blocks from request
// Strategy:
// - When thinking.type != "enabled": Remove all thinking blocks
// - When thinking.type == "enabled": Only remove thinking blocks without valid signatures
func filterThinkingBlocksInternal(body []byte, _ bool) []byte {
// Fast path: if body doesn't contain "thinking", skip parsing
if !bytes.Contains(body, []byte(`"type":"thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "thinking"`)) &&
!bytes.Contains(body, []byte(`"type":"redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"type": "redacted_thinking"`)) &&
!bytes.Contains(body, []byte(`"thinking":`)) &&
!bytes.Contains(body, []byte(`"thinking" :`)) {
return body
}
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return body
}
// Check if thinking is enabled
thinkingEnabled := false
if thinking, ok := req["thinking"].(map[string]any); ok {
if thinkType, ok := thinking["type"].(string); ok && thinkType == "enabled" {
thinkingEnabled = true
}
}
messages, ok := req["messages"].([]any)
if !ok {
return body
}
filtered := false
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
newContent := make([]any, 0, len(content))
filteredThisMessage := false
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
newContent = append(newContent, block)
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" || blockType == "redacted_thinking" {
// When thinking is enabled and this is an assistant message,
// only keep thinking blocks with valid signatures
if thinkingEnabled && role == "assistant" {
signature, _ := blockMap["signature"].(string)
if signature != "" && signature != "skip_thought_signature_validator" {
newContent = append(newContent, block)
continue
}
}
filtered = true
filteredThisMessage = true
continue
}
// Handle blocks without type discriminator but with "thinking" key
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
filtered = true
filteredThisMessage = true
continue
}
}
newContent = append(newContent, block)
}
if filteredThisMessage {
msgMap["content"] = newContent
}
}
if !filtered {
return body
}
newBody, err := json.Marshal(req)
if err != nil {
return body
}
return newBody
}

View File

@@ -1,6 +1,7 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
@@ -38,3 +39,115 @@ func TestParseGatewayRequest_InvalidStreamType(t *testing.T) {
_, err := ParseGatewayRequest(body)
require.Error(t, err)
}
func TestFilterThinkingBlocks(t *testing.T) {
containsThinkingBlock := func(body []byte) bool {
var req map[string]any
if err := json.Unmarshal(body, &req); err != nil {
return false
}
messages, ok := req["messages"].([]any)
if !ok {
return false
}
for _, msg := range messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
content, ok := msgMap["content"].([]any)
if !ok {
continue
}
for _, block := range content {
blockMap, ok := block.(map[string]any)
if !ok {
continue
}
blockType, _ := blockMap["type"].(string)
if blockType == "thinking" {
return true
}
if blockType == "" {
if _, hasThinking := blockMap["thinking"]; hasThinking {
return true
}
}
}
}
return false
}
tests := []struct {
name string
input string
shouldFilter bool
expectError bool
}{
{
name: "filters thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"},{"type":"thinking","thinking":"internal","signature":"invalid"},{"type":"text","text":"World"}]}]}`,
shouldFilter: true,
},
{
name: "handles no thinking blocks",
input: `{"model":"claude-3-5-sonnet-20241022","messages":[{"role":"user","content":[{"type":"text","text":"Hello"}]}]}`,
shouldFilter: false,
},
{
name: "handles invalid JSON gracefully",
input: `{invalid json`,
shouldFilter: false,
expectError: true,
},
{
name: "handles multiple messages with thinking blocks",
input: `{"messages":[{"role":"user","content":[{"type":"text","text":"A"}]},{"role":"assistant","content":[{"type":"thinking","thinking":"think"},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "filters thinking blocks without type discriminator",
input: `{"messages":[{"role":"assistant","content":[{"thinking":{"text":"internal"}},{"type":"text","text":"B"}]}]}`,
shouldFilter: true,
},
{
name: "does not filter tool_use input fields named thinking",
input: `{"messages":[{"role":"user","content":[{"type":"tool_use","id":"t1","name":"foo","input":{"thinking":"keepme","x":1}},{"type":"text","text":"Hello"}]}]}`,
shouldFilter: false,
},
{
name: "handles empty messages array",
input: `{"messages":[]}`,
shouldFilter: false,
},
{
name: "handles missing messages field",
input: `{"model":"claude-3"}`,
shouldFilter: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FilterThinkingBlocks([]byte(tt.input))
if tt.expectError {
// For invalid JSON, should return original
require.Equal(t, tt.input, string(result))
return
}
if tt.shouldFilter {
require.False(t, containsThinkingBlock(result))
} else {
// Ensure we don't rewrite JSON when no filtering is needed.
require.Equal(t, tt.input, string(result))
}
// Verify valid JSON returned (unless input was invalid)
var parsed map[string]any
err := json.Unmarshal(result, &parsed)
require.NoError(t, err)
})
}
}

View File

@@ -543,7 +543,7 @@ func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, gro
for _, item := range available {
result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL)
}
return &AccountSelectionResult{
@@ -579,7 +579,7 @@ func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates
for _, acc := range ordered {
result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
if err == nil && result.Acquired {
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL)
}
return &AccountSelectionResult{
@@ -710,7 +710,7 @@ func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
@@ -783,7 +783,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
}
// 4. 建立粘性绑定
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
@@ -799,7 +799,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
if err == nil && accountID > 0 {
if _, excluded := excludedIDs[accountID]; !excluded {
@@ -875,7 +875,7 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
}
// 4. 建立粘性绑定
if sessionHash != "" {
if sessionHash != "" && s.cache != nil {
if err := s.cache.SetSessionAccountID(ctx, sessionHash, selected.ID, stickySessionTTL); err != nil {
log.Printf("set session account failed: session=%s account_id=%d err=%v", sessionHash, selected.ID, err)
}
@@ -907,7 +907,7 @@ func (s *GatewayService) GetAccessToken(ctx context.Context, account *Account) (
case AccountTypeOAuth, AccountTypeSetupToken:
// Both oauth and setup-token use OAuth token flow
return s.getOAuthToken(ctx, account)
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -1045,7 +1045,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 应用模型映射仅对apikey类型账号
originalModel := reqModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
// 替换请求体中的模型名
@@ -1082,8 +1082,45 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
return nil, fmt.Errorf("upstream request failed: %w", err)
}
// 检查是否需要重试
if resp.StatusCode >= 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
// 优先检测thinking block签名错误400并重试一次
if resp.StatusCode == 400 {
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr == nil {
_ = resp.Body.Close()
if s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error, retrying with filtered thinking blocks", account.ID)
// 过滤thinking blocks并重试使用更激进的过滤
filteredBody := FilterThinkingBlocksForRetry(body)
retryReq, buildErr := s.buildUpstreamRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
// 使用重试后的响应,继续后续处理
if retryResp.StatusCode < 400 {
log.Printf("Account %d: signature error retry succeeded", account.ID)
} else {
log.Printf("Account %d: signature error retry returned status %d", account.ID, retryResp.StatusCode)
}
resp = retryResp
break
}
log.Printf("Account %d: signature error retry failed: %v", account.ID, retryErr)
} else {
log.Printf("Account %d: signature error retry build request failed: %v", account.ID, buildErr)
}
// 重试失败,恢复原始响应体继续处理
resp.Body = io.NopCloser(bytes.NewReader(respBody))
break
}
// 不是thinking签名错误恢复响应体
resp.Body = io.NopCloser(bytes.NewReader(respBody))
}
}
// 检查是否需要通用重试排除400因为400已经在上面特殊处理过了
if resp.StatusCode >= 400 && resp.StatusCode != 400 && s.shouldRetryUpstreamError(account, resp.StatusCode) {
if attempt < maxRetries {
log.Printf("Account %d: upstream error %d, retry %d/%d after %v",
account.ID, resp.StatusCode, attempt, maxRetries, retryDelay)
@@ -1096,6 +1133,13 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
}
// 不需要重试(成功或不可重试的错误),跳出循环
// DEBUG: 输出响应 headers用于检测 rate limit 信息)
if account.Platform == PlatformGemini && resp.StatusCode < 400 {
log.Printf("[DEBUG] Gemini API Response Headers for account %d:", account.ID)
for k, v := range resp.Header {
log.Printf("[DEBUG] %s: %v", k, v)
}
}
break
}
defer func() { _ = resp.Body.Close() }()
@@ -1119,7 +1163,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover默认关闭以保持语义
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
respBody, readErr := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
@@ -1179,7 +1223,7 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标URL
targetURL := claudeAPIURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages"
}
@@ -1243,10 +1287,10 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理anthropic-beta headerOAuth账号需要特殊处理
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
@@ -1313,12 +1357,12 @@ func requestNeedsBetaFeatures(body []byte) bool {
return false
}
func defaultApiKeyBetaHeader(body []byte) string {
func defaultAPIKeyBetaHeader(body []byte) string {
modelID := gjson.GetBytes(body, "model").String()
if strings.Contains(strings.ToLower(modelID), "haiku") {
return claude.ApiKeyHaikuBetaHeader
return claude.APIKeyHaikuBetaHeader
}
return claude.ApiKeyBetaHeader
return claude.APIKeyBetaHeader
}
func truncateForLog(b []byte, maxBytes int) string {
@@ -1335,6 +1379,41 @@ func truncateForLog(b []byte, maxBytes int) string {
return s
}
// isThinkingBlockSignatureError 检测是否是thinking block相关错误
// 这类错误可以通过过滤thinking blocks并重试来解决
func (s *GatewayService) isThinkingBlockSignatureError(respBody []byte) bool {
msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
if msg == "" {
return false
}
// Log for debugging
log.Printf("[SignatureCheck] Checking error message: %s", msg)
// 检测signature相关的错误更宽松的匹配
// 例如: "Invalid `signature` in `thinking` block", "***.signature" 等
if strings.Contains(msg, "signature") {
log.Printf("[SignatureCheck] Detected signature error")
return true
}
// 检测 thinking block 顺序/类型错误
// 例如: "Expected `thinking` or `redacted_thinking`, but found `text`"
if strings.Contains(msg, "expected") && (strings.Contains(msg, "thinking") || strings.Contains(msg, "redacted_thinking")) {
log.Printf("[SignatureCheck] Detected thinking block type error")
return true
}
// 检测空消息内容错误(可能是过滤 thinking blocks 后导致的)
// 例如: "all messages must have non-empty content"
if strings.Contains(msg, "non-empty content") || strings.Contains(msg, "empty content") {
log.Printf("[SignatureCheck] Detected empty content error")
return true
}
return false
}
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
// 只对“可能是兼容性差异导致”的 400 允许切换,避免无意义重试。
// 默认保守:无法识别则不切换。
@@ -1383,7 +1462,13 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
body, _ := io.ReadAll(resp.Body)
// 处理上游错误,标记账号状态
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// 根据状态码返回适当的自定义错误响应(不透传上游详细信息)
var errType, errMsg string
@@ -1695,7 +1780,7 @@ func (s *GatewayService) replaceModelInResponseBody(body []byte, fromModel, toMo
// RecordUsageInput 记录使用量的输入参数
type RecordUsageInput struct {
Result *ForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription // 可选:订阅信息
@@ -1704,7 +1789,7 @@ type RecordUsageInput struct {
// RecordUsage 记录使用量并扣费(或更新订阅用量)
func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -1741,7 +1826,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1771,7 +1856,8 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
usageLog.SubscriptionID = &subscription.ID
}
if err := s.usageLogRepo.Create(ctx, usageLog); err != nil {
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if err != nil {
log.Printf("Create usage log failed: %v", err)
}
@@ -1781,10 +1867,12 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
return nil
}
shouldBill := inserted || err != nil
// 根据计费类型执行扣费
if isSubscriptionBilling {
// 订阅模式:更新订阅用量(使用 TotalCost 原始费用,不考虑倍率)
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
if err := s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost); err != nil {
log.Printf("Increment subscription usage failed: %v", err)
}
@@ -1793,7 +1881,7 @@ func (s *GatewayService) RecordUsage(ctx context.Context, input *RecordUsageInpu
}
} else {
// 余额模式:扣除用户余额(使用 ActualCost 考虑倍率后的费用)
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
if err := s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost); err != nil {
log.Printf("Deduct balance failed: %v", err)
}
@@ -1826,7 +1914,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
}
// 应用模型映射(仅对 apikey 类型账号)
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
if reqModel != "" {
mappedModel := account.GetMappedModel(reqModel)
if mappedModel != reqModel {
@@ -1863,17 +1951,35 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Request failed")
return fmt.Errorf("upstream request failed: %w", err)
}
defer func() {
_ = resp.Body.Close()
}()
// 读取响应体
respBody, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
// 检测 thinking block 签名错误400并重试一次过滤 thinking blocks
if resp.StatusCode == 400 && s.isThinkingBlockSignatureError(respBody) {
log.Printf("Account %d: detected thinking block signature error on count_tokens, retrying with filtered thinking blocks", account.ID)
filteredBody := FilterThinkingBlocks(body)
retryReq, buildErr := s.buildCountTokensRequest(ctx, c, account, filteredBody, token, tokenType, reqModel)
if buildErr == nil {
retryResp, retryErr := s.httpUpstream.Do(retryReq, proxyURL, account.ID, account.Concurrency)
if retryErr == nil {
resp = retryResp
respBody, err = io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
s.countTokensError(c, http.StatusBadGateway, "upstream_error", "Failed to read response")
return err
}
}
}
}
// 处理错误响应
if resp.StatusCode >= 400 {
// 标记账号状态429/529等
@@ -1912,7 +2018,7 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Context, account *Account, body []byte, token, tokenType, modelID string) (*http.Request, error) {
// 确定目标 URL
targetURL := claudeAPICountTokensURL
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
baseURL := account.GetBaseURL()
targetURL = baseURL + "/v1/messages/count_tokens"
}
@@ -1971,10 +2077,10 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForAPIKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
if beta := defaultAPIKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}

View File

@@ -273,7 +273,7 @@ func (s *GeminiMessagesCompatService) SelectAccountForAIStudioEndpoints(ctx cont
return 999
}
switch a.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
if strings.TrimSpace(a.GetCredential("api_key")) != "" {
return 0
}
@@ -351,7 +351,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
originalModel := req.Model
mappedModel := req.Model
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(req.Model)
}
@@ -374,7 +374,7 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -539,7 +539,14 @@ func (s *GeminiMessagesCompatService) Forward(ctx context.Context, c *gin.Contex
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -614,7 +621,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}
mappedModel := originalModel
if account.Type == AccountTypeApiKey {
if account.Type == AccountTypeAPIKey {
mappedModel = account.GetMappedModel(originalModel)
}
@@ -636,7 +643,7 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
var buildReq func(ctx context.Context) (*http.Request, string, error)
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
buildReq = func(ctx context.Context) (*http.Request, string, error) {
apiKey := account.GetCredential("api_key")
if strings.TrimSpace(apiKey) == "" {
@@ -825,6 +832,10 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
if resp.StatusCode >= 400 {
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
tempMatched := false
if s.rateLimitService != nil {
tempMatched = s.rateLimitService.HandleTempUnschedulable(ctx, account, resp.StatusCode, respBody)
}
s.handleGeminiUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// Best-effort fallback for OAuth tokens missing AI Studio scopes when calling countTokens.
@@ -842,6 +853,9 @@ func (s *GeminiMessagesCompatService) ForwardNative(ctx context.Context, c *gin.
}, nil
}
if tempMatched {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
if s.shouldFailoverGeminiUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
@@ -1614,6 +1628,15 @@ type UpstreamHTTPResult struct {
}
func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Context, resp *http.Response, isOAuth bool) (*ClaudeUsage, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ========================================")
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
@@ -1644,6 +1667,15 @@ func (s *GeminiMessagesCompatService) handleNativeNonStreamingResponse(c *gin.Co
}
func (s *GeminiMessagesCompatService) handleNativeStreamingResponse(c *gin.Context, resp *http.Response, startTime time.Time, isOAuth bool) (*geminiNativeStreamResult, error) {
// Log response headers for debugging
log.Printf("[GeminiAPI] ========== Streaming Response Headers ==========")
for key, values := range resp.Header {
if strings.HasPrefix(strings.ToLower(key), "x-ratelimit") {
log.Printf("[GeminiAPI] %s: %v", key, values)
}
}
log.Printf("[GeminiAPI] ====================================================")
c.Status(resp.StatusCode)
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
@@ -1758,7 +1790,7 @@ func (s *GeminiMessagesCompatService) ForwardAIStudioGET(ctx context.Context, ac
}
switch account.Type {
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := strings.TrimSpace(account.GetCredential("api_key"))
if apiKey == "" {
return nil, errors.New("gemini api_key not configured")
@@ -2177,10 +2209,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
parts := make([]any, 0)
switch content := mm["content"].(type) {
case string:
if strings.TrimSpace(content) != "" {
parts = append(parts, map[string]any{"text": content})
}
// 字符串形式的 content保留所有内容包括空白
parts = append(parts, map[string]any{"text": content})
case []any:
// 如果只有一个 block不过滤空白让上游 API 报错)
singleBlock := len(content) == 1
for _, block := range content {
bm, ok := block.(map[string]any)
if !ok {
@@ -2189,8 +2223,12 @@ func convertClaudeMessagesToGeminiContents(messages any, toolUseIDToName map[str
bt, _ := bm["type"].(string)
switch bt {
case "text":
if text, ok := bm["text"].(string); ok && strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
if text, ok := bm["text"].(string); ok {
// 单个 block 时保留所有内容(包括空白)
// 多个 blocks 时过滤掉空白
if singleBlock || strings.TrimSpace(text) != "" {
parts = append(parts, map[string]any{"text": text})
}
}
case "tool_use":
id, _ := bm["id"].(string)

View File

@@ -121,6 +121,12 @@ func (m *mockAccountRepoForGemini) SetRateLimited(ctx context.Context, id int64,
func (m *mockAccountRepoForGemini) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
return nil
}
func (m *mockAccountRepoForGemini) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearTempUnschedulable(ctx context.Context, id int64) error {
return nil
}
func (m *mockAccountRepoForGemini) ClearRateLimit(ctx context.Context, id int64) error { return nil }
func (m *mockAccountRepoForGemini) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
return nil
@@ -275,7 +281,7 @@ func TestGeminiMessagesCompatService_SelectAccountForModelWithExclusions_OAuthPr
repo := &mockAccountRepoForGemini{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Type: AccountTypeApiKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 1, Platform: PlatformGemini, Type: AccountTypeAPIKey, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
{ID: 2, Platform: PlatformGemini, Type: AccountTypeOAuth, Priority: 1, Status: StatusActive, Schedulable: true, LastUsedAt: nil},
},
accountsByID: map[int64]*Account{},

View File

@@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"io"
"log"
"net/http"
"regexp"
"strconv"
@@ -18,12 +19,23 @@ import (
)
const (
TierAIPremium = "AI_PREMIUM"
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
TierFree = "FREE"
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
// Canonical tier IDs used by sub2api (2026-aligned).
GeminiTierGoogleOneFree = "google_one_free"
GeminiTierGoogleAIPro = "google_ai_pro"
GeminiTierGoogleAIUltra = "google_ai_ultra"
GeminiTierGCPStandard = "gcp_standard"
GeminiTierGCPEnterprise = "gcp_enterprise"
GeminiTierAIStudioFree = "aistudio_free"
GeminiTierAIStudioPaid = "aistudio_paid"
GeminiTierGoogleOneUnknown = "google_one_unknown"
// Legacy/compat tier IDs that may exist in historical data or upstream responses.
legacyTierAIPremium = "AI_PREMIUM"
legacyTierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
legacyTierGoogleOneBasic = "GOOGLE_ONE_BASIC"
legacyTierFree = "FREE"
legacyTierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
legacyTierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
)
const (
@@ -84,7 +96,7 @@ type GeminiAuthURLResult struct {
State string `json:"state"`
}
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType string) (*GeminiAuthURLResult, error) {
func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64, redirectURI, projectID, oauthType, tierID string) (*GeminiAuthURLResult, error) {
state, err := geminicli.GenerateState()
if err != nil {
return nil, fmt.Errorf("failed to generate state: %w", err)
@@ -109,14 +121,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
// - google_one: uses configured OAuth client when provided; otherwise falls back to built-in client.
// - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" || oauthType == "google_one" {
if oauthType == "code_assist" {
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
@@ -127,6 +139,7 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
ProxyURL: proxyURL,
RedirectURI: redirectURI,
ProjectID: strings.TrimSpace(projectID),
TierID: canonicalGeminiTierIDForOAuthType(oauthType, tierID),
OAuthType: oauthType,
CreatedAt: time.Now(),
}
@@ -146,9 +159,9 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
}
// Redirect URI strategy:
// - code_assist: use Gemini CLI redirect URI (codeassist.google.com/authcode)
// - ai_studio: use localhost callback for manual copy/paste flow
if oauthType == "code_assist" {
// - built-in Gemini CLI OAuth client: use upstream redirect URI (codeassist.google.com/authcode)
// - custom OAuth client: use localhost callback for manual copy/paste flow
if isBuiltinClient {
redirectURI = geminicli.GeminiCLIRedirectURI
} else {
redirectURI = geminicli.AIStudioOAuthRedirectURI
@@ -174,6 +187,9 @@ type GeminiExchangeCodeInput struct {
Code string
ProxyID *int64
OAuthType string // "code_assist" 或 "ai_studio"
// TierID is a user-selected tier to be used when auto detection is unavailable or fails.
// If empty, the service will fall back to the tier stored in the OAuth session (if any).
TierID string
}
type GeminiTokenInfo struct {
@@ -185,7 +201,7 @@ type GeminiTokenInfo struct {
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
TierID string `json:"tier_id,omitempty"` // Canonical tier id (e.g. google_one_free, gcp_standard, aistudio_free)
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
}
@@ -204,6 +220,90 @@ func validateTierID(tierID string) error {
return nil
}
func canonicalGeminiTierID(raw string) string {
raw = strings.TrimSpace(raw)
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch lower {
case GeminiTierGoogleOneFree,
GeminiTierGoogleAIPro,
GeminiTierGoogleAIUltra,
GeminiTierGCPStandard,
GeminiTierGCPEnterprise,
GeminiTierAIStudioFree,
GeminiTierAIStudioPaid,
GeminiTierGoogleOneUnknown:
return lower
}
upper := strings.ToUpper(raw)
switch upper {
// Google One legacy tiers
case legacyTierAIPremium:
return GeminiTierGoogleAIPro
case legacyTierGoogleOneUnlimited:
return GeminiTierGoogleAIUltra
case legacyTierFree, legacyTierGoogleOneBasic, legacyTierGoogleOneStandard:
return GeminiTierGoogleOneFree
case legacyTierGoogleOneUnknown:
return GeminiTierGoogleOneUnknown
// Code Assist legacy tiers
case "STANDARD", "PRO", "LEGACY":
return GeminiTierGCPStandard
case "ENTERPRISE", "ULTRA":
return GeminiTierGCPEnterprise
}
// Some Code Assist responses use kebab-case tier identifiers.
switch lower {
case "standard-tier", "pro-tier":
return GeminiTierGCPStandard
case "ultra-tier":
return GeminiTierGCPEnterprise
}
return ""
}
func canonicalGeminiTierIDForOAuthType(oauthType, tierID string) string {
oauthType = strings.ToLower(strings.TrimSpace(oauthType))
canonical := canonicalGeminiTierID(tierID)
if canonical == "" {
return ""
}
switch oauthType {
case "google_one":
switch canonical {
case GeminiTierGoogleOneFree, GeminiTierGoogleAIPro, GeminiTierGoogleAIUltra:
return canonical
default:
return ""
}
case "code_assist":
switch canonical {
case GeminiTierGCPStandard, GeminiTierGCPEnterprise:
return canonical
default:
return ""
}
case "ai_studio":
switch canonical {
case GeminiTierAIStudioFree, GeminiTierAIStudioPaid:
return canonical
default:
return ""
}
default:
// Unknown oauth type: accept canonical tier.
return canonical
}
}
// extractTierIDFromAllowedTiers extracts tierID from LoadCodeAssist response
// Prioritizes IsDefault tier, falls back to first non-empty tier
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
@@ -229,45 +329,61 @@ func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string
// inferGoogleOneTier infers Google One tier from Drive storage limit
func inferGoogleOneTier(storageBytes int64) string {
log.Printf("[GeminiOAuth] inferGoogleOneTier - input: %d bytes (%.2f TB)", storageBytes, float64(storageBytes)/float64(TB))
if storageBytes <= 0 {
return TierGoogleOneUnknown
log.Printf("[GeminiOAuth] inferGoogleOneTier - storageBytes <= 0, returning UNKNOWN")
return GeminiTierGoogleOneUnknown
}
if storageBytes > StorageTierUnlimited {
return TierGoogleOneUnlimited
log.Printf("[GeminiOAuth] inferGoogleOneTier - > %d bytes (100TB), returning UNLIMITED", StorageTierUnlimited)
return GeminiTierGoogleAIUltra
}
if storageBytes >= StorageTierAIPremium {
return TierAIPremium
}
if storageBytes >= StorageTierStandard {
return TierGoogleOneStandard
}
if storageBytes >= StorageTierBasic {
return TierGoogleOneBasic
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (2TB), returning google_ai_pro", StorageTierAIPremium)
return GeminiTierGoogleAIPro
}
if storageBytes >= StorageTierFree {
return TierFree
log.Printf("[GeminiOAuth] inferGoogleOneTier - >= %d bytes (15GB), returning FREE", StorageTierFree)
return GeminiTierGoogleOneFree
}
return TierGoogleOneUnknown
log.Printf("[GeminiOAuth] inferGoogleOneTier - < %d bytes (15GB), returning UNKNOWN", StorageTierFree)
return GeminiTierGoogleOneUnknown
}
// fetchGoogleOneTier fetches Google One tier from Drive API
// FetchGoogleOneTier fetches Google One tier from Drive API.
// Note: LoadCodeAssist API is NOT called for Google One accounts because:
// 1. It's designed for GCP IAM (enterprise), not personal Google accounts
// 2. Personal accounts will get 403/404 from cloudaicompanion.googleapis.com
// 3. Google consumer (Google One) and enterprise (GCP) systems are physically isolated
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
log.Printf("[GeminiOAuth] Starting FetchGoogleOneTier (Google One personal account)")
// Use Drive API to infer tier from storage quota (requires drive.readonly scope)
log.Printf("[GeminiOAuth] Calling Drive API for storage quota...")
driveClient := geminicli.NewDriveClient()
storageInfo, err := driveClient.GetStorageQuota(ctx, accessToken, proxyURL)
if err != nil {
// Check if it's a 403 (scope not granted)
if strings.Contains(err.Error(), "status 403") {
fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
return TierGoogleOneUnknown, nil, err
log.Printf("[GeminiOAuth] Drive API scope not available (403): %v", err)
return GeminiTierGoogleOneUnknown, nil, err
}
// Other errors
fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
return TierGoogleOneUnknown, nil, err
log.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v", err)
return GeminiTierGoogleOneUnknown, nil, err
}
log.Printf("[GeminiOAuth] Drive API response - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
tierID := inferGoogleOneTier(storageInfo.Limit)
log.Printf("[GeminiOAuth] Inferred tier from storage: %s", tierID)
return tierID, storageInfo, nil
}
@@ -326,11 +442,16 @@ func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
log.Printf("[GeminiOAuth] ========== ExchangeCode START ==========")
log.Printf("[GeminiOAuth] SessionID: %s", input.SessionID)
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
log.Printf("[GeminiOAuth] ERROR: Session not found or expired")
return nil, fmt.Errorf("session not found or expired")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
log.Printf("[GeminiOAuth] ERROR: Invalid state")
return nil, fmt.Errorf("invalid state")
}
@@ -341,6 +462,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
proxyURL = proxy.URL()
}
}
log.Printf("[GeminiOAuth] ProxyURL: %s", proxyURL)
redirectURI := session.RedirectURI
@@ -349,6 +471,8 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
if oauthType == "" {
oauthType = "code_assist"
}
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
log.Printf("[GeminiOAuth] Project ID from session: %s", session.ProjectID)
// If the session was created for AI Studio OAuth, ensure a custom OAuth client is configured.
if oauthType == "ai_studio" {
@@ -374,8 +498,13 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
tokenResp, err := s.oauthClient.ExchangeCode(ctx, oauthType, input.Code, session.CodeVerifier, redirectURI, proxyURL)
if err != nil {
log.Printf("[GeminiOAuth] ERROR: Failed to exchange code: %v", err)
return nil, fmt.Errorf("failed to exchange code: %w", err)
}
log.Printf("[GeminiOAuth] Token exchange successful")
log.Printf("[GeminiOAuth] Token scope: %s", tokenResp.Scope)
log.Printf("[GeminiOAuth] Token expires_in: %d seconds", tokenResp.ExpiresIn)
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
@@ -391,43 +520,91 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
projectID := sessionProjectID
var tierID string
fallbackTierID := canonicalGeminiTierIDForOAuthType(oauthType, input.TierID)
if fallbackTierID == "" {
fallbackTierID = canonicalGeminiTierIDForOAuthType(oauthType, session.TierID)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection START ==========")
log.Printf("[GeminiOAuth] OAuth Type: %s", oauthType)
// 对于 code_assist 模式project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id配额由 Google 网关自动识别
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
switch oauthType {
case "code_assist":
log.Printf("[GeminiOAuth] Processing code_assist OAuth type")
if projectID == "" {
log.Printf("[GeminiOAuth] No project_id provided, attempting to fetch from LoadCodeAssist API...")
var err error
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch project_id: %v", err)
} else {
log.Printf("[GeminiOAuth] Successfully fetched project_id: %s, tier_id: %s", projectID, tierID)
}
} else {
log.Printf("[GeminiOAuth] User provided project_id: %s, fetching tier_id...", projectID)
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
log.Printf("[GeminiOAuth] WARNING: Failed to fetch tier_id: %v", err)
} else {
tierID = fetchedTierID
log.Printf("[GeminiOAuth] Successfully fetched tier_id: %s", tierID)
}
}
if strings.TrimSpace(projectID) == "" {
log.Printf("[GeminiOAuth] ERROR: Missing project_id for Code Assist OAuth")
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// tierID 缺失时使用默认值
// Prefer auto-detected tier; fall back to user-selected tier.
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
if tierID == "" {
tierID = "LEGACY"
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGCPStandard
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
log.Printf("[GeminiOAuth] Final code_assist result - project_id: %s, tier_id: %s", projectID, tierID)
case "google_one":
log.Printf("[GeminiOAuth] Processing google_one OAuth type")
log.Printf("[GeminiOAuth] Attempting to fetch Google One tier from Drive API...")
// Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
var storageInfo *geminicli.DriveStorageInfo
var err error
tierID, storageInfo, err = s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown
log.Printf("[GeminiOAuth] WARNING: Failed to fetch Drive tier: %v", err)
tierID = ""
} else {
log.Printf("[GeminiOAuth] Successfully fetched Drive tier: %s", tierID)
if storageInfo != nil {
log.Printf("[GeminiOAuth] Drive storage - Limit: %d bytes (%.2f TB), Usage: %d bytes (%.2f GB)",
storageInfo.Limit, float64(storageInfo.Limit)/float64(TB),
storageInfo.Usage, float64(storageInfo.Usage)/float64(GB))
}
}
tierID = canonicalGeminiTierIDForOAuthType(oauthType, tierID)
if tierID == "" || tierID == GeminiTierGoogleOneUnknown {
if fallbackTierID != "" {
tierID = fallbackTierID
log.Printf("[GeminiOAuth] Using fallback tier_id from user/session: %s", tierID)
} else {
tierID = GeminiTierGoogleOneFree
log.Printf("[GeminiOAuth] Using default tier_id: %s", tierID)
}
}
fmt.Printf("[GeminiOAuth] Google One tierID after normalization: %s\n", tierID)
// Store Drive info in extra field for caching
if storageInfo != nil {
@@ -447,12 +624,25 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
},
}
log.Printf("[GeminiOAuth] ========== ExchangeCode END (google_one with storage info) ==========")
return tokenInfo, nil
}
}
// ai_studio 模式不设置 tierID保持为空
return &GeminiTokenInfo{
case "ai_studio":
// No automatic tier detection for AI Studio OAuth; rely on user selection.
if fallbackTierID != "" {
tierID = fallbackTierID
} else {
tierID = GeminiTierAIStudioFree
}
default:
log.Printf("[GeminiOAuth] Processing %s OAuth type (no tier detection)", oauthType)
}
log.Printf("[GeminiOAuth] ========== Account Type Detection END ==========")
result := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
@@ -462,7 +652,10 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
}, nil
}
log.Printf("[GeminiOAuth] Final result - OAuth Type: %s, Project ID: %s, Tier ID: %s", result.OAuthType, result.ProjectID, result.TierID)
log.Printf("[GeminiOAuth] ========== ExchangeCode END ==========")
return result, nil
}
func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refreshToken, proxyURL string) (*GeminiTokenInfo, error) {
@@ -558,6 +751,17 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
err = nil
}
}
// Backward compatibility for google_one:
// - New behavior: when a custom OAuth client is configured, google_one will use it.
// - Old behavior: google_one always used the built-in Gemini CLI OAuth client.
// If an existing account was authorized with the built-in client, refreshing with the custom client
// will fail with "unauthorized_client". Retry with the built-in client (code_assist path forces it).
if err != nil && oauthType == "google_one" && strings.Contains(err.Error(), "unauthorized_client") && s.GetOAuthConfig().AIStudioOAuthEnabled {
if alt, altErr := s.RefreshToken(ctx, "code_assist", refreshToken, proxyURL); altErr == nil {
tokenInfo = alt
err = nil
}
}
if err != nil {
// Provide a more actionable error for common OAuth client mismatch issues.
if strings.Contains(err.Error(), "unauthorized_client") {
@@ -583,13 +787,14 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
case "code_assist":
// 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
tokenInfo.TierID = canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
}
if tokenInfo.TierID == "" {
tokenInfo.TierID = GeminiTierGCPStandard
}
// 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || tokenInfo.TierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
@@ -598,9 +803,10 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
if tierID != "" {
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" {
tokenInfo.TierID = canonical
}
}
}
}
@@ -609,6 +815,7 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
case "google_one":
canonicalExistingTier := canonicalGeminiTierIDForOAuthType(oauthType, existingTierID)
// Check if tier cache is stale (> 24 hours)
needsRefresh := true
if account.Extra != nil {
@@ -617,30 +824,37 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
if time.Since(updatedAt) <= 24*time.Hour {
needsRefresh = false
// Use cached tier
if existingTierID != "" {
tokenInfo.TierID = existingTierID
}
tokenInfo.TierID = canonicalExistingTier
}
}
}
}
if tokenInfo.TierID == "" {
tokenInfo.TierID = canonicalExistingTier
}
if needsRefresh {
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
if err == nil && storageInfo != nil {
tokenInfo.TierID = tierID
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
if err == nil {
if canonical := canonicalGeminiTierIDForOAuthType(oauthType, tierID); canonical != "" && canonical != GeminiTierGoogleOneUnknown {
tokenInfo.TierID = canonical
}
if storageInfo != nil {
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
}
}
}
}
if tokenInfo.TierID == "" || tokenInfo.TierID == GeminiTierGoogleOneUnknown {
if canonicalExistingTier != "" {
tokenInfo.TierID = canonicalExistingTier
} else {
// Fallback to cached or unknown
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = TierGoogleOneUnknown
}
tokenInfo.TierID = GeminiTierGoogleOneFree
}
}
}
@@ -669,6 +883,9 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
fmt.Printf("[GeminiOAuth] Storing tier_id: %s\n", tokenInfo.TierID)
} else {
fmt.Printf("[GeminiOAuth] Invalid tier_id %s: %v\n", tokenInfo.TierID, err)
}
// Silently skip invalid tier_id (don't block account creation)
}
@@ -698,7 +915,13 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
// First try to get tier from currentTier/paidTier fields
if tier := loadResp.GetTier(); tier != "" {
tierID = tier
} else {
// Fallback to extracting from allowedTiers
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
}
// If LoadCodeAssist returned a project, use it

View File

@@ -1,50 +1,129 @@
package service
import "testing"
import (
"context"
"net/url"
"strings"
"testing"
func TestInferGoogleOneTier(t *testing.T) {
tests := []struct {
name string
storageBytes int64
expectedTier string
}{
{"Negative storage", -1, TierGoogleOneUnknown},
{"Zero storage", 0, TierGoogleOneUnknown},
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/geminicli"
)
// Free tier boundary (15GB)
{"Below free tier", 10 * GB, TierGoogleOneUnknown},
{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
{"Free tier (15GB)", StorageTierFree, TierFree},
func TestGeminiOAuthService_GenerateAuthURL_RedirectURIStrategy(t *testing.T) {
t.Parallel()
// Basic tier boundary (100GB)
{"Between free and basic", 50 * GB, TierFree},
{"Just below basic tier", StorageTierBasic - 1, TierFree},
{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},
type testCase struct {
name string
cfg *config.Config
oauthType string
projectID string
wantClientID string
wantRedirect string
wantScope string
wantProjectID string
wantErrSubstr string
}
// Standard tier boundary (200GB)
{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},
// AI Premium tier boundary (2TB)
{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},
// Unlimited tier boundary (> 100TB)
{"Between premium and unlimited", 50 * TB, TierAIPremium},
{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
tests := []testCase{
{
name: "google_one uses built-in client when not configured and redirects to upstream",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "google_one",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "",
},
{
name: "google_one uses custom client when configured and redirects to localhost",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "google_one",
wantClientID: "custom-client-id",
wantRedirect: geminicli.AIStudioOAuthRedirectURI,
wantScope: geminicli.DefaultGoogleOneScopes,
wantProjectID: "",
},
{
name: "code_assist always forces built-in client even when custom client configured",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{
ClientID: "custom-client-id",
ClientSecret: "custom-client-secret",
},
},
},
oauthType: "code_assist",
projectID: "my-gcp-project",
wantClientID: geminicli.GeminiCLIOAuthClientID,
wantRedirect: geminicli.GeminiCLIRedirectURI,
wantScope: geminicli.DefaultCodeAssistScopes,
wantProjectID: "my-gcp-project",
},
{
name: "ai_studio requires custom client",
cfg: &config.Config{
Gemini: config.GeminiConfig{
OAuth: config.GeminiOAuthConfig{},
},
},
oauthType: "ai_studio",
wantErrSubstr: "AI Studio OAuth requires a custom OAuth Client",
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
result := inferGoogleOneTier(tt.storageBytes)
if result != tt.expectedTier {
t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
tt.storageBytes, result, tt.expectedTier)
t.Parallel()
svc := NewGeminiOAuthService(nil, nil, nil, tt.cfg)
got, err := svc.GenerateAuthURL(context.Background(), nil, "https://example.com/auth/callback", tt.projectID, tt.oauthType, "")
if tt.wantErrSubstr != "" {
if err == nil {
t.Fatalf("expected error containing %q, got nil", tt.wantErrSubstr)
}
if !strings.Contains(err.Error(), tt.wantErrSubstr) {
t.Fatalf("expected error containing %q, got: %v", tt.wantErrSubstr, err)
}
return
}
if err != nil {
t.Fatalf("GenerateAuthURL returned error: %v", err)
}
parsed, err := url.Parse(got.AuthURL)
if err != nil {
t.Fatalf("failed to parse auth_url: %v", err)
}
q := parsed.Query()
if gotState := q.Get("state"); gotState != got.State {
t.Fatalf("state mismatch: query=%q result=%q", gotState, got.State)
}
if gotClientID := q.Get("client_id"); gotClientID != tt.wantClientID {
t.Fatalf("client_id mismatch: got=%q want=%q", gotClientID, tt.wantClientID)
}
if gotRedirect := q.Get("redirect_uri"); gotRedirect != tt.wantRedirect {
t.Fatalf("redirect_uri mismatch: got=%q want=%q", gotRedirect, tt.wantRedirect)
}
if gotScope := q.Get("scope"); gotScope != tt.wantScope {
t.Fatalf("scope mismatch: got=%q want=%q", gotScope, tt.wantScope)
}
if gotProjectID := q.Get("project_id"); gotProjectID != tt.wantProjectID {
t.Fatalf("project_id mismatch: got=%q want=%q", gotProjectID, tt.wantProjectID)
}
})
}

View File

@@ -20,13 +20,24 @@ const (
geminiModelFlash geminiModelClass = "flash"
)
type GeminiDailyQuota struct {
ProRPD int64
FlashRPD int64
type GeminiQuota struct {
// SharedRPD is a shared requests-per-day pool across models.
// When SharedRPD > 0, callers should treat ProRPD/FlashRPD as not applicable for daily quota checks.
SharedRPD int64 `json:"shared_rpd,omitempty"`
// SharedRPM is a shared requests-per-minute pool across models.
// When SharedRPM > 0, callers should treat ProRPM/FlashRPM as not applicable for minute quota checks.
SharedRPM int64 `json:"shared_rpm,omitempty"`
// Per-model quotas (AI Studio / API key).
// A value of -1 means "unlimited" (pay-as-you-go).
ProRPD int64 `json:"pro_rpd,omitempty"`
ProRPM int64 `json:"pro_rpm,omitempty"`
FlashRPD int64 `json:"flash_rpd,omitempty"`
FlashRPM int64 `json:"flash_rpm,omitempty"`
}
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
Quota GeminiQuota
Cooldown time.Duration
}
@@ -45,10 +56,27 @@ type GeminiUsageTotals struct {
const geminiQuotaCacheTTL = time.Minute
type geminiQuotaOverrides struct {
type geminiQuotaOverridesV1 struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
type geminiQuotaOverridesV2 struct {
QuotaRules map[string]geminiQuotaRuleOverride `json:"quota_rules"`
}
type geminiQuotaRuleOverride struct {
SharedRPD *int64 `json:"shared_rpd,omitempty"`
SharedRPM *int64 `json:"rpm,omitempty"`
GeminiPro *geminiModelQuotaOverride `json:"gemini_pro,omitempty"`
GeminiFlash *geminiModelQuotaOverride `json:"gemini_flash,omitempty"`
Desc *string `json:"desc,omitempty"`
}
type geminiModelQuotaOverride struct {
RPD *int64 `json:"rpd,omitempty"`
RPM *int64 `json:"rpm,omitempty"`
}
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
@@ -82,11 +110,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if s.cfg != nil {
policy.ApplyOverrides(s.cfg.Gemini.Quota.Tiers)
if strings.TrimSpace(s.cfg.Gemini.Quota.Policy) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(s.cfg.Gemini.Quota.Policy), &overrides); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
raw := []byte(s.cfg.Gemini.Quota.Policy)
var overridesV2 geminiQuotaOverridesV2
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else {
policy.ApplyOverrides(overrides.Tiers)
var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse config policy failed: %v", err)
} else {
policy.ApplyOverrides(overridesV1.Tiers)
}
}
}
}
@@ -96,11 +130,17 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
if err != nil && !errors.Is(err, ErrSettingNotFound) {
log.Printf("gemini quota: load setting failed: %v", err)
} else if strings.TrimSpace(value) != "" {
var overrides geminiQuotaOverrides
if err := json.Unmarshal([]byte(value), &overrides); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
raw := []byte(value)
var overridesV2 geminiQuotaOverridesV2
if err := json.Unmarshal(raw, &overridesV2); err == nil && len(overridesV2.QuotaRules) > 0 {
policy.ApplyQuotaRulesOverrides(overridesV2.QuotaRules)
} else {
policy.ApplyOverrides(overrides.Tiers)
var overridesV1 geminiQuotaOverridesV1
if err := json.Unmarshal(raw, &overridesV1); err != nil {
log.Printf("gemini quota: parse setting failed: %v", err)
} else {
policy.ApplyOverrides(overridesV1.Tiers)
}
}
}
}
@@ -113,12 +153,20 @@ func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
return policy
}
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
if account == nil || !account.IsGeminiCodeAssist() {
return GeminiDailyQuota{}, false
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiQuota, bool) {
if account == nil || account.Platform != PlatformGemini {
return GeminiQuota{}, false
}
// Map (oauth_type + tier_id) to a canonical policy tier key.
// This keeps the policy table stable even if upstream tier_id strings vary.
tierKey := geminiQuotaTierKeyForAccount(account)
if tierKey == "" {
return GeminiQuota{}, false
}
policy := s.Policy(ctx)
return policy.QuotaForTier(account.GeminiTierID())
return policy.QuotaForTier(tierKey)
}
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
@@ -126,12 +174,36 @@ func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string)
return policy.CooldownForTier(tierID)
}
func (s *GeminiQuotaService) CooldownForAccount(ctx context.Context, account *Account) time.Duration {
if s == nil || account == nil || account.Platform != PlatformGemini {
return 5 * time.Minute
}
tierKey := geminiQuotaTierKeyForAccount(account)
if strings.TrimSpace(tierKey) == "" {
return 5 * time.Minute
}
return s.CooldownForTier(ctx, tierKey)
}
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
return &GeminiQuotaPolicy{
tiers: map[string]GeminiTierPolicy{
"LEGACY": {Quota: GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500}, Cooldown: 30 * time.Minute},
"PRO": {Quota: GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000}, Cooldown: 5 * time.Minute},
"ULTRA": {Quota: GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0}, Cooldown: 5 * time.Minute},
// --- AI Studio / API Key (per-model) ---
// aistudio_free:
// - gemini_pro: 50 RPD / 2 RPM
// - gemini_flash: 1500 RPD / 15 RPM
GeminiTierAIStudioFree: {Quota: GeminiQuota{ProRPD: 50, ProRPM: 2, FlashRPD: 1500, FlashRPM: 15}, Cooldown: 30 * time.Minute},
// aistudio_paid: -1 means "unlimited/pay-as-you-go" for RPD.
GeminiTierAIStudioPaid: {Quota: GeminiQuota{ProRPD: -1, ProRPM: 1000, FlashRPD: -1, FlashRPM: 2000}, Cooldown: 5 * time.Minute},
// --- Google One (shared pool) ---
GeminiTierGoogleOneFree: {Quota: GeminiQuota{SharedRPD: 1000, SharedRPM: 60}, Cooldown: 30 * time.Minute},
GeminiTierGoogleAIPro: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGoogleAIUltra: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
// --- GCP Code Assist (shared pool) ---
GeminiTierGCPStandard: {Quota: GeminiQuota{SharedRPD: 1500, SharedRPM: 120}, Cooldown: 5 * time.Minute},
GeminiTierGCPEnterprise: {Quota: GeminiQuota{SharedRPD: 2000, SharedRPM: 120}, Cooldown: 5 * time.Minute},
},
}
}
@@ -149,11 +221,22 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
// Backward-compatible overrides:
// - If the tier uses shared quota, interpret pro_rpd as shared_rpd.
// - Otherwise apply per-model overrides.
if override.ProRPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
if policy.Quota.SharedRPD > 0 {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
} else {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.ProRPD)
}
}
if override.FlashRPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
if policy.Quota.SharedRPD > 0 {
// No separate flash RPD for shared tiers.
} else {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.FlashRPD)
}
}
if override.CooldownMinutes != nil {
minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
@@ -163,10 +246,51 @@ func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuo
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
func (p *GeminiQuotaPolicy) ApplyQuotaRulesOverrides(rules map[string]geminiQuotaRuleOverride) {
if p == nil || len(rules) == 0 {
return
}
for rawID, override := range rules {
tierID := normalizeGeminiTierID(rawID)
if tierID == "" {
continue
}
policy, ok := p.tiers[tierID]
if !ok {
policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
}
if override.SharedRPD != nil {
policy.Quota.SharedRPD = clampGeminiQuotaInt64WithUnlimited(*override.SharedRPD)
}
if override.SharedRPM != nil {
policy.Quota.SharedRPM = clampGeminiQuotaRPM(*override.SharedRPM)
}
if override.GeminiPro != nil {
if override.GeminiPro.RPD != nil {
policy.Quota.ProRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiPro.RPD)
}
if override.GeminiPro.RPM != nil {
policy.Quota.ProRPM = clampGeminiQuotaRPM(*override.GeminiPro.RPM)
}
}
if override.GeminiFlash != nil {
if override.GeminiFlash.RPD != nil {
policy.Quota.FlashRPD = clampGeminiQuotaInt64WithUnlimited(*override.GeminiFlash.RPD)
}
if override.GeminiFlash.RPM != nil {
policy.Quota.FlashRPM = clampGeminiQuotaRPM(*override.GeminiFlash.RPM)
}
}
p.tiers[tierID] = policy
}
}
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiQuota, bool) {
policy, ok := p.policyForTier(tierID)
if !ok {
return GeminiDailyQuota{}, false
return GeminiQuota{}, false
}
return policy.Quota, true
}
@@ -184,22 +308,43 @@ func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool
return GeminiTierPolicy{}, false
}
normalized := normalizeGeminiTierID(tierID)
if normalized == "" {
normalized = "LEGACY"
}
if policy, ok := p.tiers[normalized]; ok {
return policy, true
}
policy, ok := p.tiers["LEGACY"]
return policy, ok
return GeminiTierPolicy{}, false
}
func normalizeGeminiTierID(tierID string) string {
return strings.ToUpper(strings.TrimSpace(tierID))
tierID = strings.TrimSpace(tierID)
if tierID == "" {
return ""
}
// Prefer canonical mapping (handles legacy tier strings).
if canonical := canonicalGeminiTierID(tierID); canonical != "" {
return canonical
}
// Accept older policy keys that used uppercase names.
switch strings.ToUpper(tierID) {
case "AISTUDIO_FREE":
return GeminiTierAIStudioFree
case "AISTUDIO_PAID":
return GeminiTierAIStudioPaid
case "GOOGLE_ONE_FREE":
return GeminiTierGoogleOneFree
case "GOOGLE_AI_PRO":
return GeminiTierGoogleAIPro
case "GOOGLE_AI_ULTRA":
return GeminiTierGoogleAIUltra
case "GCP_STANDARD":
return GeminiTierGCPStandard
case "GCP_ENTERPRISE":
return GeminiTierGCPEnterprise
}
return strings.ToLower(tierID)
}
func clampGeminiQuotaInt64(value int64) int64 {
if value < 0 {
func clampGeminiQuotaInt64WithUnlimited(value int64) int64 {
if value < -1 {
return 0
}
return value
@@ -212,11 +357,46 @@ func clampGeminiQuotaInt(value int) int {
return value
}
func clampGeminiQuotaRPM(value int64) int64 {
if value < 0 {
return 0
}
return value
}
func geminiCooldownForTier(tierID string) time.Duration {
policy := newGeminiQuotaPolicy()
return policy.CooldownForTier(tierID)
}
func geminiQuotaTierKeyForAccount(account *Account) string {
if account == nil || account.Platform != PlatformGemini {
return ""
}
// Note: GeminiOAuthType() already defaults legacy (project_id present) to code_assist.
oauthType := strings.ToLower(strings.TrimSpace(account.GeminiOAuthType()))
rawTier := strings.TrimSpace(account.GeminiTierID())
// Prefer the canonical tier stored in credentials.
if tierID := canonicalGeminiTierIDForOAuthType(oauthType, rawTier); tierID != "" && tierID != GeminiTierGoogleOneUnknown {
return tierID
}
// Fallback defaults when tier_id is missing or unknown.
switch oauthType {
case "google_one":
return GeminiTierGoogleOneFree
case "code_assist":
return GeminiTierGCPStandard
case "ai_studio":
return GeminiTierAIStudioFree
default:
// API Key accounts (type=apikey) have empty oauth_type and are treated as AI Studio.
return GeminiTierAIStudioFree
}
}
func geminiModelClassFromName(model string) geminiModelClass {
name := strings.ToLower(strings.TrimSpace(model))
if strings.Contains(name, "flash") || strings.Contains(name, "lite") {

View File

@@ -487,7 +487,7 @@ func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Acco
return "", "", errors.New("access_token not found in credentials")
}
return accessToken, "oauth", nil
case AccountTypeApiKey:
case AccountTypeAPIKey:
apiKey := account.GetOpenAIApiKey()
if apiKey == "" {
return "", "", errors.New("api_key not found in credentials")
@@ -627,7 +627,7 @@ func (s *OpenAIGatewayService) buildUpstreamRequest(ctx context.Context, c *gin.
case AccountTypeOAuth:
// OAuth accounts use ChatGPT internal API
targetURL = chatgptCodexURL
case AccountTypeApiKey:
case AccountTypeAPIKey:
// API Key accounts use Platform API or custom base URL
baseURL := account.GetOpenAIBaseURL()
if baseURL != "" {
@@ -703,7 +703,13 @@ func (s *OpenAIGatewayService) handleErrorResponse(ctx context.Context, resp *ht
}
// Handle upstream error (mark account status)
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
shouldDisable := false
if s.rateLimitService != nil {
shouldDisable = s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, body)
}
if shouldDisable {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
// Return appropriate error response
var errType, errMsg string
@@ -940,7 +946,7 @@ func (s *OpenAIGatewayService) replaceModelInResponseBody(body []byte, fromModel
// OpenAIRecordUsageInput input for recording usage
type OpenAIRecordUsageInput struct {
Result *OpenAIForwardResult
ApiKey *ApiKey
APIKey *APIKey
User *User
Account *Account
Subscription *UserSubscription
@@ -949,7 +955,7 @@ type OpenAIRecordUsageInput struct {
// RecordUsage records usage and deducts balance
func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRecordUsageInput) error {
result := input.Result
apiKey := input.ApiKey
apiKey := input.APIKey
user := input.User
account := input.Account
subscription := input.Subscription
@@ -991,7 +997,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
durationMs := int(result.Duration.Milliseconds())
usageLog := &UsageLog{
UserID: user.ID,
ApiKeyID: apiKey.ID,
APIKeyID: apiKey.ID,
AccountID: account.ID,
RequestID: result.RequestID,
Model: result.Model,
@@ -1020,22 +1026,23 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.SubscriptionID = &subscription.ID
}
_ = s.usageLogRepo.Create(ctx, usageLog)
inserted, err := s.usageLogRepo.Create(ctx, usageLog)
if s.cfg != nil && s.cfg.RunMode == config.RunModeSimple {
log.Printf("[SIMPLE MODE] Usage recorded (not billed): user=%d, tokens=%d", usageLog.UserID, usageLog.TotalTokens())
s.deferredService.ScheduleLastUsedUpdate(account.ID)
return nil
}
shouldBill := inserted || err != nil
// Deduct based on billing type
if isSubscriptionBilling {
if cost.TotalCost > 0 {
if shouldBill && cost.TotalCost > 0 {
_ = s.userSubRepo.IncrementUsage(ctx, subscription.ID, cost.TotalCost)
s.billingCacheService.QueueUpdateSubscriptionUsage(user.ID, *apiKey.GroupID, cost.TotalCost)
}
} else {
if cost.ActualCost > 0 {
if shouldBill && cost.ActualCost > 0 {
_ = s.userRepo.DeductBalance(ctx, user.ID, cost.ActualCost)
s.billingCacheService.QueueDeductBalance(user.ID, cost.ActualCost)
}

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"encoding/json"
"log"
"net/http"
"strconv"
@@ -18,6 +19,7 @@ type RateLimitService struct {
usageRepo UsageLogRepository
cfg *config.Config
geminiQuotaService *GeminiQuotaService
tempUnschedCache TempUnschedCache
usageCacheMu sync.RWMutex
usageCache map[int64]*geminiUsageCacheEntry
}
@@ -31,12 +33,13 @@ type geminiUsageCacheEntry struct {
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService 创建RateLimitService实例
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService, tempUnschedCache TempUnschedCache) *RateLimitService {
return &RateLimitService{
accountRepo: accountRepo,
usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
tempUnschedCache: tempUnschedCache,
usageCache: make(map[int64]*geminiUsageCacheEntry),
}
}
@@ -51,38 +54,45 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
return false
}
tempMatched := s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
switch statusCode {
case 401:
// 认证失败:停止调度,记录错误
s.handleAuthError(ctx, account, "Authentication failed (401): invalid or expired credentials")
return true
shouldDisable = true
case 402:
// 支付要求:余额不足或计费问题,停止调度
s.handleAuthError(ctx, account, "Payment required (402): insufficient balance or billing issue")
return true
shouldDisable = true
case 403:
// 禁止访问:停止调度,记录错误
s.handleAuthError(ctx, account, "Access forbidden (403): account may be suspended or lack permissions")
return true
shouldDisable = true
case 429:
s.handle429(ctx, account, headers)
return false
shouldDisable = false
case 529:
s.handle529(ctx, account)
return false
shouldDisable = false
default:
// 其他5xx错误记录但不停止调度
if statusCode >= 500 {
log.Printf("Account %d received upstream error %d", account.ID, statusCode)
}
return false
shouldDisable = false
}
if tempMatched {
return true
}
return shouldDisable
}
// PreCheckUsage proactively checks local quota before dispatching a request.
// Returns false when the account should be skipped.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
if account == nil || account.Platform != PlatformGemini {
return true, nil
}
if s.usageRepo == nil || s.geminiQuotaService == nil {
@@ -94,44 +104,99 @@ func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account,
return true, nil
}
var limit int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
if limit <= 0 {
return true, nil
}
now := time.Now()
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
modelClass := geminiModelClassFromName(requestedModel)
// 1) Daily quota precheck (RPD; resets at PST midnight)
{
var limit int64
if quota.SharedRPD > 0 {
limit = quota.SharedRPD
} else {
switch modelClass {
case geminiModelFlash:
limit = quota.FlashRPD
default:
limit = quota.ProRPD
}
}
if limit > 0 {
start := geminiDailyWindowStart(now)
totals, ok := s.getGeminiUsageTotals(account.ID, start, now)
if !ok {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64
if quota.SharedRPD > 0 {
used = totals.ProRequests + totals.FlashRequests
} else {
switch modelClass {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
// NOTE:
// - This is a local precheck to reduce upstream 429s.
// - Do NOT mark the account as rate-limited here; rate_limit_reset_at should reflect real upstream 429s.
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
return false, nil
}
}
totals = geminiAggregateUsage(stats)
s.setGeminiUsageTotals(account.ID, start, now, totals)
}
var used int64
switch geminiModelClassFromName(requestedModel) {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
if used >= limit {
resetAt := geminiDailyResetTime(now)
if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
// 2) Minute quota precheck (RPM; fixed window current minute)
{
var limit int64
if quota.SharedRPM > 0 {
limit = quota.SharedRPM
} else {
switch modelClass {
case geminiModelFlash:
limit = quota.FlashRPM
default:
limit = quota.ProRPM
}
}
if limit > 0 {
start := now.Truncate(time.Minute)
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, start, now, 0, 0, account.ID)
if err != nil {
return true, err
}
totals := geminiAggregateUsage(stats)
var used int64
if quota.SharedRPM > 0 {
used = totals.ProRequests + totals.FlashRequests
} else {
switch modelClass {
case geminiModelFlash:
used = totals.FlashRequests
default:
used = totals.ProRequests
}
}
if used >= limit {
resetAt := start.Add(time.Minute)
// Do not persist "rate limited" status from local precheck. See note above.
log.Printf("[Gemini PreCheck] Account %d reached minute quota (%d/%d), skip until %v", account.ID, used, limit, resetAt)
return false, nil
}
}
log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
return false, nil
}
return true, nil
@@ -176,7 +241,10 @@ func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account)
if account == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForTier(ctx, account.GeminiTierID())
if s.geminiQuotaService == nil {
return 5 * time.Minute
}
return s.geminiQuotaService.CooldownForAccount(ctx, account)
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
@@ -287,3 +355,183 @@ func (s *RateLimitService) UpdateSessionWindow(ctx context.Context, account *Acc
func (s *RateLimitService) ClearRateLimit(ctx context.Context, accountID int64) error {
return s.accountRepo.ClearRateLimit(ctx, accountID)
}
func (s *RateLimitService) ClearTempUnschedulable(ctx context.Context, accountID int64) error {
if err := s.accountRepo.ClearTempUnschedulable(ctx, accountID); err != nil {
return err
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.DeleteTempUnsched(ctx, accountID); err != nil {
log.Printf("DeleteTempUnsched failed for account %d: %v", accountID, err)
}
}
return nil
}
func (s *RateLimitService) GetTempUnschedStatus(ctx context.Context, accountID int64) (*TempUnschedState, error) {
now := time.Now().Unix()
if s.tempUnschedCache != nil {
state, err := s.tempUnschedCache.GetTempUnsched(ctx, accountID)
if err != nil {
return nil, err
}
if state != nil && state.UntilUnix > now {
return state, nil
}
}
account, err := s.accountRepo.GetByID(ctx, accountID)
if err != nil {
return nil, err
}
if account.TempUnschedulableUntil == nil {
return nil, nil
}
if account.TempUnschedulableUntil.Unix() <= now {
return nil, nil
}
state := &TempUnschedState{
UntilUnix: account.TempUnschedulableUntil.Unix(),
}
if account.TempUnschedulableReason != "" {
var parsed TempUnschedState
if err := json.Unmarshal([]byte(account.TempUnschedulableReason), &parsed); err == nil {
if parsed.UntilUnix == 0 {
parsed.UntilUnix = state.UntilUnix
}
state = &parsed
} else {
state.ErrorMessage = account.TempUnschedulableReason
}
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, accountID, state); err != nil {
log.Printf("SetTempUnsched failed for account %d: %v", accountID, err)
}
}
return state, nil
}
func (s *RateLimitService) HandleTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
if account == nil {
return false
}
if !account.ShouldHandleErrorCode(statusCode) {
return false
}
return s.tryTempUnschedulable(ctx, account, statusCode, responseBody)
}
const tempUnschedBodyMaxBytes = 64 << 10
const tempUnschedMessageMaxBytes = 2048
func (s *RateLimitService) tryTempUnschedulable(ctx context.Context, account *Account, statusCode int, responseBody []byte) bool {
if account == nil {
return false
}
if !account.IsTempUnschedulableEnabled() {
return false
}
rules := account.GetTempUnschedulableRules()
if len(rules) == 0 {
return false
}
if statusCode <= 0 || len(responseBody) == 0 {
return false
}
body := responseBody
if len(body) > tempUnschedBodyMaxBytes {
body = body[:tempUnschedBodyMaxBytes]
}
bodyLower := strings.ToLower(string(body))
for idx, rule := range rules {
if rule.ErrorCode != statusCode || len(rule.Keywords) == 0 {
continue
}
matchedKeyword := matchTempUnschedKeyword(bodyLower, rule.Keywords)
if matchedKeyword == "" {
continue
}
if s.triggerTempUnschedulable(ctx, account, rule, idx, statusCode, matchedKeyword, responseBody) {
return true
}
}
return false
}
func matchTempUnschedKeyword(bodyLower string, keywords []string) string {
if bodyLower == "" {
return ""
}
for _, keyword := range keywords {
k := strings.TrimSpace(keyword)
if k == "" {
continue
}
if strings.Contains(bodyLower, strings.ToLower(k)) {
return k
}
}
return ""
}
func (s *RateLimitService) triggerTempUnschedulable(ctx context.Context, account *Account, rule TempUnschedulableRule, ruleIndex int, statusCode int, matchedKeyword string, responseBody []byte) bool {
if account == nil {
return false
}
if rule.DurationMinutes <= 0 {
return false
}
now := time.Now()
until := now.Add(time.Duration(rule.DurationMinutes) * time.Minute)
state := &TempUnschedState{
UntilUnix: until.Unix(),
TriggeredAtUnix: now.Unix(),
StatusCode: statusCode,
MatchedKeyword: matchedKeyword,
RuleIndex: ruleIndex,
ErrorMessage: truncateTempUnschedMessage(responseBody, tempUnschedMessageMaxBytes),
}
reason := ""
if raw, err := json.Marshal(state); err == nil {
reason = string(raw)
}
if reason == "" {
reason = strings.TrimSpace(state.ErrorMessage)
}
if err := s.accountRepo.SetTempUnschedulable(ctx, account.ID, until, reason); err != nil {
log.Printf("SetTempUnschedulable failed for account %d: %v", account.ID, err)
return false
}
if s.tempUnschedCache != nil {
if err := s.tempUnschedCache.SetTempUnsched(ctx, account.ID, state); err != nil {
log.Printf("SetTempUnsched cache failed for account %d: %v", account.ID, err)
}
}
log.Printf("Account %d temp unschedulable until %v (rule %d, code %d)", account.ID, until, ruleIndex, statusCode)
return true
}
func truncateTempUnschedMessage(body []byte, maxBytes int) string {
if maxBytes <= 0 || len(body) == 0 {
return ""
}
if len(body) > maxBytes {
body = body[:maxBytes]
}
return strings.TrimSpace(string(body))
}

View File

@@ -61,9 +61,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SettingKeySiteName,
SettingKeySiteLogo,
SettingKeySiteSubtitle,
SettingKeyApiBaseUrl,
SettingKeyAPIBaseURL,
SettingKeyContactInfo,
SettingKeyDocUrl,
SettingKeyDocURL,
}
settings, err := s.settingRepo.GetMultiple(ctx, keys)
@@ -79,9 +79,9 @@ func (s *SettingService) GetPublicSettings(ctx context.Context) (*PublicSettings
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}, nil
}
@@ -94,15 +94,15 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeyEmailVerifyEnabled] = strconv.FormatBool(settings.EmailVerifyEnabled)
// 邮件服务设置(只有非空才更新密码)
updates[SettingKeySmtpHost] = settings.SmtpHost
updates[SettingKeySmtpPort] = strconv.Itoa(settings.SmtpPort)
updates[SettingKeySmtpUsername] = settings.SmtpUsername
if settings.SmtpPassword != "" {
updates[SettingKeySmtpPassword] = settings.SmtpPassword
updates[SettingKeySMTPHost] = settings.SMTPHost
updates[SettingKeySMTPPort] = strconv.Itoa(settings.SMTPPort)
updates[SettingKeySMTPUsername] = settings.SMTPUsername
if settings.SMTPPassword != "" {
updates[SettingKeySMTPPassword] = settings.SMTPPassword
}
updates[SettingKeySmtpFrom] = settings.SmtpFrom
updates[SettingKeySmtpFromName] = settings.SmtpFromName
updates[SettingKeySmtpUseTLS] = strconv.FormatBool(settings.SmtpUseTLS)
updates[SettingKeySMTPFrom] = settings.SMTPFrom
updates[SettingKeySMTPFromName] = settings.SMTPFromName
updates[SettingKeySMTPUseTLS] = strconv.FormatBool(settings.SMTPUseTLS)
// Cloudflare Turnstile 设置(只有非空才更新密钥)
updates[SettingKeyTurnstileEnabled] = strconv.FormatBool(settings.TurnstileEnabled)
@@ -115,14 +115,21 @@ func (s *SettingService) UpdateSettings(ctx context.Context, settings *SystemSet
updates[SettingKeySiteName] = settings.SiteName
updates[SettingKeySiteLogo] = settings.SiteLogo
updates[SettingKeySiteSubtitle] = settings.SiteSubtitle
updates[SettingKeyApiBaseUrl] = settings.ApiBaseUrl
updates[SettingKeyAPIBaseURL] = settings.APIBaseURL
updates[SettingKeyContactInfo] = settings.ContactInfo
updates[SettingKeyDocUrl] = settings.DocUrl
updates[SettingKeyDocURL] = settings.DocURL
// 默认配置
updates[SettingKeyDefaultConcurrency] = strconv.Itoa(settings.DefaultConcurrency)
updates[SettingKeyDefaultBalance] = strconv.FormatFloat(settings.DefaultBalance, 'f', 8, 64)
// Model fallback configuration
updates[SettingKeyEnableModelFallback] = strconv.FormatBool(settings.EnableModelFallback)
updates[SettingKeyFallbackModelAnthropic] = settings.FallbackModelAnthropic
updates[SettingKeyFallbackModelOpenAI] = settings.FallbackModelOpenAI
updates[SettingKeyFallbackModelGemini] = settings.FallbackModelGemini
updates[SettingKeyFallbackModelAntigravity] = settings.FallbackModelAntigravity
return s.settingRepo.SetMultiple(ctx, updates)
}
@@ -198,8 +205,14 @@ func (s *SettingService) InitializeDefaultSettings(ctx context.Context) error {
SettingKeySiteLogo: "",
SettingKeyDefaultConcurrency: strconv.Itoa(s.cfg.Default.UserConcurrency),
SettingKeyDefaultBalance: strconv.FormatFloat(s.cfg.Default.UserBalance, 'f', 8, 64),
SettingKeySmtpPort: "587",
SettingKeySmtpUseTLS: "false",
SettingKeySMTPPort: "587",
SettingKeySMTPUseTLS: "false",
// Model fallback defaults
SettingKeyEnableModelFallback: "false",
SettingKeyFallbackModelAnthropic: "claude-3-5-sonnet-20241022",
SettingKeyFallbackModelOpenAI: "gpt-4o",
SettingKeyFallbackModelGemini: "gemini-2.5-pro",
SettingKeyFallbackModelAntigravity: "gemini-2.5-pro",
}
return s.settingRepo.SetMultiple(ctx, defaults)
@@ -210,26 +223,26 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
result := &SystemSettings{
RegistrationEnabled: settings[SettingKeyRegistrationEnabled] == "true",
EmailVerifyEnabled: settings[SettingKeyEmailVerifyEnabled] == "true",
SmtpHost: settings[SettingKeySmtpHost],
SmtpUsername: settings[SettingKeySmtpUsername],
SmtpFrom: settings[SettingKeySmtpFrom],
SmtpFromName: settings[SettingKeySmtpFromName],
SmtpUseTLS: settings[SettingKeySmtpUseTLS] == "true",
SMTPHost: settings[SettingKeySMTPHost],
SMTPUsername: settings[SettingKeySMTPUsername],
SMTPFrom: settings[SettingKeySMTPFrom],
SMTPFromName: settings[SettingKeySMTPFromName],
SMTPUseTLS: settings[SettingKeySMTPUseTLS] == "true",
TurnstileEnabled: settings[SettingKeyTurnstileEnabled] == "true",
TurnstileSiteKey: settings[SettingKeyTurnstileSiteKey],
SiteName: s.getStringOrDefault(settings, SettingKeySiteName, "Sub2API"),
SiteLogo: settings[SettingKeySiteLogo],
SiteSubtitle: s.getStringOrDefault(settings, SettingKeySiteSubtitle, "Subscription to API Conversion Platform"),
ApiBaseUrl: settings[SettingKeyApiBaseUrl],
APIBaseURL: settings[SettingKeyAPIBaseURL],
ContactInfo: settings[SettingKeyContactInfo],
DocUrl: settings[SettingKeyDocUrl],
DocURL: settings[SettingKeyDocURL],
}
// 解析整数类型
if port, err := strconv.Atoi(settings[SettingKeySmtpPort]); err == nil {
result.SmtpPort = port
if port, err := strconv.Atoi(settings[SettingKeySMTPPort]); err == nil {
result.SMTPPort = port
} else {
result.SmtpPort = 587
result.SMTPPort = 587
}
if concurrency, err := strconv.Atoi(settings[SettingKeyDefaultConcurrency]); err == nil {
@@ -246,9 +259,16 @@ func (s *SettingService) parseSettings(settings map[string]string) *SystemSettin
}
// 敏感信息直接返回,方便测试连接时使用
result.SmtpPassword = settings[SettingKeySmtpPassword]
result.SMTPPassword = settings[SettingKeySMTPPassword]
result.TurnstileSecretKey = settings[SettingKeyTurnstileSecretKey]
// Model fallback settings
result.EnableModelFallback = settings[SettingKeyEnableModelFallback] == "true"
result.FallbackModelAnthropic = s.getStringOrDefault(settings, SettingKeyFallbackModelAnthropic, "claude-3-5-sonnet-20241022")
result.FallbackModelOpenAI = s.getStringOrDefault(settings, SettingKeyFallbackModelOpenAI, "gpt-4o")
result.FallbackModelGemini = s.getStringOrDefault(settings, SettingKeyFallbackModelGemini, "gemini-2.5-pro")
result.FallbackModelAntigravity = s.getStringOrDefault(settings, SettingKeyFallbackModelAntigravity, "gemini-2.5-pro")
return result
}
@@ -278,28 +298,28 @@ func (s *SettingService) GetTurnstileSecretKey(ctx context.Context) string {
return value
}
// GenerateAdminApiKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminApiKey(ctx context.Context) (string, error) {
// GenerateAdminAPIKey 生成新的管理员 API Key
func (s *SettingService) GenerateAdminAPIKey(ctx context.Context) (string, error) {
// 生成 32 字节随机数 = 64 位十六进制字符
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
key := AdminApiKeyPrefix + hex.EncodeToString(bytes)
key := AdminAPIKeyPrefix + hex.EncodeToString(bytes)
// 存储到 settings 表
if err := s.settingRepo.Set(ctx, SettingKeyAdminApiKey, key); err != nil {
if err := s.settingRepo.Set(ctx, SettingKeyAdminAPIKey, key); err != nil {
return "", fmt.Errorf("save admin api key: %w", err)
}
return key, nil
}
// GetAdminApiKeyStatus 获取管理员 API Key 状态
// GetAdminAPIKeyStatus 获取管理员 API Key 状态
// 返回脱敏的 key、是否存在、错误
func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKeyStatus(ctx context.Context) (maskedKey string, exists bool, err error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", false, nil
@@ -320,10 +340,10 @@ func (s *SettingService) GetAdminApiKeyStatus(ctx context.Context) (maskedKey st
return maskedKey, true, nil
}
// GetAdminApiKey 获取完整的管理员 API Key仅供内部验证使用
// GetAdminAPIKey 获取完整的管理员 API Key仅供内部验证使用
// 如果未配置返回空字符串和 nil 错误,只有数据库错误时才返回 error
func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminApiKey)
func (s *SettingService) GetAdminAPIKey(ctx context.Context) (string, error) {
key, err := s.settingRepo.GetValue(ctx, SettingKeyAdminAPIKey)
if err != nil {
if errors.Is(err, ErrSettingNotFound) {
return "", nil // 未配置,返回空字符串
@@ -333,7 +353,45 @@ func (s *SettingService) GetAdminApiKey(ctx context.Context) (string, error) {
return key, nil
}
// DeleteAdminApiKey 删除管理员 API Key
func (s *SettingService) DeleteAdminApiKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminApiKey)
// DeleteAdminAPIKey 删除管理员 API Key
func (s *SettingService) DeleteAdminAPIKey(ctx context.Context) error {
return s.settingRepo.Delete(ctx, SettingKeyAdminAPIKey)
}
// IsModelFallbackEnabled 检查是否启用模型兜底机制
func (s *SettingService) IsModelFallbackEnabled(ctx context.Context) bool {
value, err := s.settingRepo.GetValue(ctx, SettingKeyEnableModelFallback)
if err != nil {
return false // Default: disabled
}
return value == "true"
}
// GetFallbackModel 获取指定平台的兜底模型
func (s *SettingService) GetFallbackModel(ctx context.Context, platform string) string {
var key string
var defaultModel string
switch platform {
case PlatformAnthropic:
key = SettingKeyFallbackModelAnthropic
defaultModel = "claude-3-5-sonnet-20241022"
case PlatformOpenAI:
key = SettingKeyFallbackModelOpenAI
defaultModel = "gpt-4o"
case PlatformGemini:
key = SettingKeyFallbackModelGemini
defaultModel = "gemini-2.5-pro"
case PlatformAntigravity:
key = SettingKeyFallbackModelAntigravity
defaultModel = "gemini-2.5-pro"
default:
return ""
}
value, err := s.settingRepo.GetValue(ctx, key)
if err != nil || value == "" {
return defaultModel
}
return value
}

View File

@@ -4,13 +4,13 @@ type SystemSettings struct {
RegistrationEnabled bool
EmailVerifyEnabled bool
SmtpHost string
SmtpPort int
SmtpUsername string
SmtpPassword string
SmtpFrom string
SmtpFromName string
SmtpUseTLS bool
SMTPHost string
SMTPPort int
SMTPUsername string
SMTPPassword string
SMTPFrom string
SMTPFromName string
SMTPUseTLS bool
TurnstileEnabled bool
TurnstileSiteKey string
@@ -19,12 +19,19 @@ type SystemSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
DefaultConcurrency int
DefaultBalance float64
// Model fallback configuration
EnableModelFallback bool `json:"enable_model_fallback"`
FallbackModelAnthropic string `json:"fallback_model_anthropic"`
FallbackModelOpenAI string `json:"fallback_model_openai"`
FallbackModelGemini string `json:"fallback_model_gemini"`
FallbackModelAntigravity string `json:"fallback_model_antigravity"`
}
type PublicSettings struct {
@@ -35,8 +42,8 @@ type PublicSettings struct {
SiteName string
SiteLogo string
SiteSubtitle string
ApiBaseUrl string
APIBaseURL string
ContactInfo string
DocUrl string
DocURL string
Version string
}

View File

@@ -0,0 +1,22 @@
package service
import (
"context"
)
// TempUnschedState 临时不可调度状态
type TempUnschedState struct {
UntilUnix int64 `json:"until_unix"` // 解除时间Unix 时间戳)
TriggeredAtUnix int64 `json:"triggered_at_unix"` // 触发时间Unix 时间戳)
StatusCode int `json:"status_code"` // 触发的错误码
MatchedKeyword string `json:"matched_keyword"` // 匹配的关键词
RuleIndex int `json:"rule_index"` // 触发的规则索引
ErrorMessage string `json:"error_message"` // 错误消息
}
// TempUnschedCache 临时不可调度缓存接口
type TempUnschedCache interface {
SetTempUnsched(ctx context.Context, accountID int64, state *TempUnschedState) error
GetTempUnsched(ctx context.Context, accountID int64) (*TempUnschedState, error)
DeleteTempUnsched(ctx context.Context, accountID int64) error
}

View File

@@ -197,7 +197,7 @@ func TestClaudeTokenRefresher_CanRefresh(t *testing.T) {
{
name: "anthropic api-key - cannot refresh",
platform: PlatformAnthropic,
accType: AccountTypeApiKey,
accType: AccountTypeAPIKey,
want: false,
},
{

View File

@@ -79,7 +79,7 @@ type ReleaseInfo struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlURL string `json:"html_url"`
HTMLURL string `json:"html_url"`
Assets []Asset `json:"assets,omitempty"`
}
@@ -96,13 +96,13 @@ type GitHubRelease struct {
Name string `json:"name"`
Body string `json:"body"`
PublishedAt string `json:"published_at"`
HtmlUrl string `json:"html_url"`
HTMLURL string `json:"html_url"`
Assets []GitHubAsset `json:"assets"`
}
type GitHubAsset struct {
Name string `json:"name"`
BrowserDownloadUrl string `json:"browser_download_url"`
BrowserDownloadURL string `json:"browser_download_url"`
Size int64 `json:"size"`
}
@@ -285,7 +285,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
for i, a := range release.Assets {
assets[i] = Asset{
Name: a.Name,
DownloadURL: a.BrowserDownloadUrl,
DownloadURL: a.BrowserDownloadURL,
Size: a.Size,
}
}
@@ -298,7 +298,7 @@ func (s *UpdateService) fetchLatestRelease(ctx context.Context) (*UpdateInfo, er
Name: release.Name,
Body: release.Body,
PublishedAt: release.PublishedAt,
HtmlURL: release.HtmlUrl,
HTMLURL: release.HTMLURL,
Assets: assets,
},
Cached: false,

View File

@@ -10,7 +10,7 @@ const (
type UsageLog struct {
ID int64
UserID int64
ApiKeyID int64
APIKeyID int64
AccountID int64
RequestID string
Model string
@@ -42,7 +42,7 @@ type UsageLog struct {
CreatedAt time.Time
User *User
ApiKey *ApiKey
APIKey *APIKey
Account *Account
Group *Group
Subscription *UserSubscription

View File

@@ -2,9 +2,11 @@ package service
import (
"context"
"errors"
"fmt"
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
@@ -17,7 +19,7 @@ var (
// CreateUsageLogRequest 创建使用日志请求
type CreateUsageLogRequest struct {
UserID int64 `json:"user_id"`
ApiKeyID int64 `json:"api_key_id"`
APIKeyID int64 `json:"api_key_id"`
AccountID int64 `json:"account_id"`
RequestID string `json:"request_id"`
Model string `json:"model"`
@@ -54,20 +56,34 @@ type UsageStats struct {
type UsageService struct {
usageRepo UsageLogRepository
userRepo UserRepository
entClient *dbent.Client
}
// NewUsageService 创建使用统计服务实例
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository) *UsageService {
func NewUsageService(usageRepo UsageLogRepository, userRepo UserRepository, entClient *dbent.Client) *UsageService {
return &UsageService{
usageRepo: usageRepo,
userRepo: userRepo,
entClient: entClient,
}
}
// Create 创建使用日志
func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*UsageLog, error) {
// 使用数据库事务保证「使用日志插入」与「扣费」的原子性,避免重复扣费或漏扣风险。
tx, err := s.entClient.Tx(ctx)
if err != nil && !errors.Is(err, dbent.ErrTxStarted) {
return nil, fmt.Errorf("begin transaction: %w", err)
}
txCtx := ctx
if err == nil {
defer func() { _ = tx.Rollback() }()
txCtx = dbent.NewTxContext(ctx, tx)
}
// 验证用户存在
_, err := s.userRepo.GetByID(ctx, req.UserID)
_, err = s.userRepo.GetByID(txCtx, req.UserID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
@@ -75,7 +91,7 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
// 创建使用日志
usageLog := &UsageLog{
UserID: req.UserID,
ApiKeyID: req.ApiKeyID,
APIKeyID: req.APIKeyID,
AccountID: req.AccountID,
RequestID: req.RequestID,
Model: req.Model,
@@ -96,17 +112,24 @@ func (s *UsageService) Create(ctx context.Context, req CreateUsageLogRequest) (*
DurationMs: req.DurationMs,
}
if err := s.usageRepo.Create(ctx, usageLog); err != nil {
inserted, err := s.usageRepo.Create(txCtx, usageLog)
if err != nil {
return nil, fmt.Errorf("create usage log: %w", err)
}
// 扣除用户余额
if req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(ctx, req.UserID, -req.ActualCost); err != nil {
if inserted && req.ActualCost > 0 {
if err := s.userRepo.UpdateBalance(txCtx, req.UserID, -req.ActualCost); err != nil {
return nil, fmt.Errorf("update user balance: %w", err)
}
}
if tx != nil {
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit transaction: %w", err)
}
}
return usageLog, nil
}
@@ -128,9 +151,9 @@ func (s *UsageService) ListByUser(ctx context.Context, userID int64, params pagi
return logs, pagination, nil
}
// ListByApiKey 获取API Key的使用日志列表
func (s *UsageService) ListByApiKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByApiKey(ctx, apiKeyID, params)
// ListByAPIKey 获取API Key的使用日志列表
func (s *UsageService) ListByAPIKey(ctx context.Context, apiKeyID int64, params pagination.PaginationParams) ([]UsageLog, *pagination.PaginationResult, error) {
logs, pagination, err := s.usageRepo.ListByAPIKey(ctx, apiKeyID, params)
if err != nil {
return nil, nil, fmt.Errorf("list usage logs: %w", err)
}
@@ -165,9 +188,9 @@ func (s *UsageService) GetStatsByUser(ctx context.Context, userID int64, startTi
}, nil
}
// GetStatsByApiKey 获取API Key的使用统计
func (s *UsageService) GetStatsByApiKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetApiKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
// GetStatsByAPIKey 获取API Key的使用统计
func (s *UsageService) GetStatsByAPIKey(ctx context.Context, apiKeyID int64, startTime, endTime time.Time) (*UsageStats, error) {
stats, err := s.usageRepo.GetAPIKeyStatsAggregated(ctx, apiKeyID, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get api key stats: %w", err)
}
@@ -270,9 +293,9 @@ func (s *UsageService) GetUserModelStats(ctx context.Context, userID int64, star
return stats, nil
}
// GetBatchApiKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchApiKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchApiKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchApiKeyUsageStats(ctx, apiKeyIDs)
// GetBatchAPIKeyUsageStats returns today/total actual_cost for given api keys.
func (s *UsageService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}

View File

@@ -21,7 +21,7 @@ type User struct {
CreatedAt time.Time
UpdatedAt time.Time
ApiKeys []ApiKey
APIKeys []APIKey
Subscriptions []UserSubscription
}

View File

@@ -56,6 +56,10 @@ func (s *UserAttributeService) CreateDefinition(ctx context.Context, input Creat
Enabled: input.Enabled,
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Create(ctx, def); err != nil {
return nil, fmt.Errorf("create definition: %w", err)
}
@@ -108,6 +112,10 @@ func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, i
def.Enabled = *input.Enabled
}
if err := validateDefinitionPattern(def); err != nil {
return nil, err
}
if err := s.defRepo.Update(ctx, def); err != nil {
return nil, fmt.Errorf("update definition: %w", err)
}
@@ -231,7 +239,10 @@ func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value
// Pattern validation
if v.Pattern != nil && *v.Pattern != "" && value != "" {
re, err := regexp.Compile(*v.Pattern)
if err == nil && !re.MatchString(value) {
if err != nil {
return validationError(def.Name + " has an invalid pattern")
}
if !re.MatchString(value) {
msg := def.Name + " format is invalid"
if v.Message != nil && *v.Message != "" {
msg = *v.Message
@@ -293,3 +304,20 @@ func isValidAttributeType(t UserAttributeType) bool {
}
return false
}
func validateDefinitionPattern(def *UserAttributeDefinition) error {
if def == nil {
return nil
}
if def.Validation.Pattern == nil {
return nil
}
pattern := strings.TrimSpace(*def.Validation.Pattern)
if pattern == "" {
return nil
}
if _, err := regexp.Compile(pattern); err != nil {
return infraerrors.BadRequest("INVALID_ATTRIBUTE_PATTERN", fmt.Sprintf("invalid pattern for %s: %v", def.Name, err))
}
return nil
}

View File

@@ -75,7 +75,7 @@ var ProviderSet = wire.NewSet(
// Core services
NewAuthService,
NewUserService,
NewApiKeyService,
NewAPIKeyService,
NewGroupService,
NewAccountService,
NewProxyService,