merge: 合并 test 分支到 test-dev,解决冲突

解决的冲突文件:
- wire_gen.go: 合并 ConcurrencyService/CRSSyncService 参数和 userAttributeHandler
- gateway_handler.go: 合并 pkg/errors 和 antigravity 导入
- gateway_service.go: 合并 validateUpstreamBaseURL 和 GetAvailableModels
- config.example.yaml: 合并 billing/turnstile 配置和额外 gateway 选项

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
yangjianbo
2026-01-03 11:36:31 +08:00
176 changed files with 27680 additions and 1952 deletions

View File

@@ -3,6 +3,7 @@ package service
import (
"encoding/json"
"strconv"
"strings"
"time"
)
@@ -78,6 +79,36 @@ func (a *Account) IsGemini() bool {
return a.Platform == PlatformGemini
}
// GeminiOAuthType returns the OAuth flavour recorded for a Gemini OAuth account.
//
// It returns "" for any account that is not a Gemini account of type OAuth.
// Otherwise it returns the trimmed "oauth_type" credential; when that
// credential is blank but a non-blank "project_id" credential exists, the
// account is reported as "code_assist".
func (a *Account) GeminiOAuthType() string {
	if a.Platform != PlatformGemini || a.Type != AccountTypeOAuth {
		return ""
	}

	// An explicitly stored type always wins.
	if declared := strings.TrimSpace(a.GetCredential("oauth_type")); declared != "" {
		return declared
	}

	// No stored type: the presence of a project ID implies Code Assist.
	if strings.TrimSpace(a.GetCredential("project_id")) != "" {
		return "code_assist"
	}
	return ""
}
// GeminiTierID returns the account's "tier_id" credential with surrounding
// whitespace removed and letters upper-cased. A blank credential yields "".
func (a *Account) GeminiTierID() string {
	trimmed := strings.TrimSpace(a.GetCredential("tier_id"))
	// Upper-casing an empty string is still empty, so no special case is needed.
	return strings.ToUpper(trimmed)
}
// IsGeminiCodeAssist reports whether the account is a Gemini OAuth account of
// the "code_assist" flavour.
//
// Non-Gemini and non-OAuth accounts are never Code Assist. When no OAuth type
// can be resolved, the account counts as Code Assist only if it carries a
// non-blank "project_id" credential.
func (a *Account) IsGeminiCodeAssist() bool {
	isGeminiOAuth := a.Platform == PlatformGemini && a.Type == AccountTypeOAuth
	if !isGeminiOAuth {
		return false
	}

	switch kind := a.GeminiOAuthType(); kind {
	case "":
		// Fall back to the project ID when the type is unknown.
		return strings.TrimSpace(a.GetCredential("project_id")) != ""
	default:
		return kind == "code_assist"
	}
}
// CanGetUsage reports whether the account's Type is AccountTypeOAuth.
// The platform is not considered here; only the account type is checked.
func (a *Account) CanGetUsage() bool {
return a.Type == AccountTypeOAuth
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -17,6 +17,9 @@ var (
type AccountRepository interface {
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error)
// GetByCRSAccountID finds an account previously synced from CRS.

View File

@@ -40,6 +40,10 @@ func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, erro
panic("unexpected GetByID call")
}
// GetByIDs is part of the repository stub and always panics: the tests that
// use this stub are not expected to perform a batch lookup, so any call is
// surfaced immediately as a failure.
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
panic("unexpected GetByIDs call")
}
// ExistsByID 返回预设的存在性检查结果。
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {

View File

@@ -93,10 +93,12 @@ type UsageProgress struct {
// UsageInfo 账号使用量信息
type UsageInfo struct {
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
UpdatedAt *time.Time `json:"updated_at,omitempty"` // 更新时间
FiveHour *UsageProgress `json:"five_hour"` // 5小时窗口
SevenDay *UsageProgress `json:"seven_day,omitempty"` // 7天窗口
SevenDaySonnet *UsageProgress `json:"seven_day_sonnet,omitempty"` // 7天Sonnet窗口
GeminiProDaily *UsageProgress `json:"gemini_pro_daily,omitempty"` // Gemini Pro 日配额
GeminiFlashDaily *UsageProgress `json:"gemini_flash_daily,omitempty"` // Gemini Flash 日配额
}
// ClaudeUsageResponse Anthropic API返回的usage结构
@@ -122,17 +124,19 @@ type ClaudeUsageFetcher interface {
// AccountUsageService 账号使用量查询服务
type AccountUsageService struct {
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
accountRepo AccountRepository
usageLogRepo UsageLogRepository
usageFetcher ClaudeUsageFetcher
geminiQuotaService *GeminiQuotaService
}
// NewAccountUsageService 创建AccountUsageService实例
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher) *AccountUsageService {
func NewAccountUsageService(accountRepo AccountRepository, usageLogRepo UsageLogRepository, usageFetcher ClaudeUsageFetcher, geminiQuotaService *GeminiQuotaService) *AccountUsageService {
return &AccountUsageService{
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
accountRepo: accountRepo,
usageLogRepo: usageLogRepo,
usageFetcher: usageFetcher,
geminiQuotaService: geminiQuotaService,
}
}
@@ -146,6 +150,10 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("get account failed: %w", err)
}
if account.Platform == PlatformGemini {
return s.getGeminiUsage(ctx, account)
}
// 只有oauth类型账号可以通过API获取usage(有profile scope)
if account.CanGetUsage() {
var apiResp *ClaudeUsageResponse
@@ -192,6 +200,36 @@ func (s *AccountUsageService) GetUsage(ctx context.Context, accountID int64) (*U
return nil, fmt.Errorf("account type %s does not support usage query", account.Type)
}
// getGeminiUsage assembles the usage report for a Gemini account from the
// locally stored usage logs.
//
// The returned UsageInfo always carries the current time in UpdatedAt. The
// Gemini Pro / Flash daily progress fields are filled only when both the quota
// service and the usage-log repository are configured and a quota is known for
// the account; otherwise the report is returned with just the timestamp.
// An error is returned only when querying the usage statistics fails.
func (s *AccountUsageService) getGeminiUsage(ctx context.Context, account *Account) (*UsageInfo, error) {
	now := time.Now()
	info := &UsageInfo{UpdatedAt: &now}

	// Nothing to compute without both collaborators.
	if s.geminiQuotaService == nil || s.usageLogRepo == nil {
		return info, nil
	}

	quota, found := s.geminiQuotaService.QuotaForAccount(ctx, account)
	if !found {
		return info, nil
	}

	// Aggregate this account's per-model stats from the start of the daily window until now.
	windowStart := geminiDailyWindowStart(now)
	modelStats, err := s.usageLogRepo.GetModelStatsWithFilters(ctx, windowStart, now, 0, 0, account.ID)
	if err != nil {
		return nil, fmt.Errorf("get gemini usage stats failed: %w", err)
	}

	totals := geminiAggregateUsage(modelStats)
	resetAt := geminiDailyResetTime(now)

	info.GeminiProDaily = buildGeminiUsageProgress(totals.ProRequests, quota.ProRPD, resetAt, totals.ProTokens, totals.ProCost, now)
	info.GeminiFlashDaily = buildGeminiUsageProgress(totals.FlashRequests, quota.FlashRPD, resetAt, totals.FlashTokens, totals.FlashCost, now)
	return info, nil
}
// addWindowStats 为 usage 数据添加窗口期统计
// 使用独立缓存(1 分钟),与 API 缓存分离
func (s *AccountUsageService) addWindowStats(ctx context.Context, account *Account, usage *UsageInfo) {
@@ -388,3 +426,25 @@ func (s *AccountUsageService) estimateSetupTokenUsage(account *Account) *UsageIn
// Setup Token无法获取7d数据
return info
}
// buildGeminiUsageProgress converts raw daily counters into a UsageProgress.
//
// It returns nil when limit is not positive. Otherwise Utilization is
// used/limit expressed as a percentage, ResetsAt points at resetAt, and
// RemainingSeconds is the whole number of seconds from now until resetAt,
// clamped at zero. The request, token and cost counters are passed through in
// WindowStats.
func buildGeminiUsageProgress(used, limit int64, resetAt time.Time, tokens int64, cost float64, now time.Time) *UsageProgress {
	if limit <= 0 {
		return nil
	}

	secondsLeft := int(resetAt.Sub(now).Seconds())
	if secondsLeft < 0 {
		// The reset moment is already in the past.
		secondsLeft = 0
	}

	window := &WindowStats{
		Requests: used,
		Tokens:   tokens,
		Cost:     cost,
	}

	// resetAt is a by-value parameter, so taking its address does not alias
	// the caller's variable.
	return &UsageProgress{
		Utilization:      (float64(used) / float64(limit)) * 100,
		ResetsAt:         &resetAt,
		RemainingSeconds: secondsLeft,
		WindowStats:      window,
	}
}

View File

@@ -13,7 +13,7 @@ import (
// AdminService interface defines admin management operations
type AdminService interface {
// User management
ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error)
ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error)
GetUser(ctx context.Context, id int64) (*User, error)
CreateUser(ctx context.Context, input *CreateUserInput) (*User, error)
UpdateUser(ctx context.Context, id int64, input *UpdateUserInput) (*User, error)
@@ -35,6 +35,7 @@ type AdminService interface {
// Account management
ListAccounts(ctx context.Context, page, pageSize int, platform, accountType, status, search string) ([]Account, int64, error)
GetAccount(ctx context.Context, id int64) (*Account, error)
GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error)
CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error)
UpdateAccount(ctx context.Context, id int64, input *UpdateAccountInput) (*Account, error)
DeleteAccount(ctx context.Context, id int64) error
@@ -69,7 +70,6 @@ type CreateUserInput struct {
Email string
Password string
Username string
Wechat string
Notes string
Balance float64
Concurrency int
@@ -80,7 +80,6 @@ type UpdateUserInput struct {
Email string
Password string
Username *string
Wechat *string
Notes *string
Balance *float64 // 使用指针区分"未提供"和"设置为0"
Concurrency *int // 使用指针区分"未提供"和"设置为0"
@@ -251,9 +250,9 @@ func NewAdminService(
}
// User management implementations
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, status, role, search string) ([]User, int64, error) {
func (s *adminServiceImpl) ListUsers(ctx context.Context, page, pageSize int, filters UserListFilters) ([]User, int64, error) {
params := pagination.PaginationParams{Page: page, PageSize: pageSize}
users, result, err := s.userRepo.ListWithFilters(ctx, params, status, role, search)
users, result, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, 0, err
}
@@ -268,7 +267,6 @@ func (s *adminServiceImpl) CreateUser(ctx context.Context, input *CreateUserInpu
user := &User{
Email: input.Email,
Username: input.Username,
Wechat: input.Wechat,
Notes: input.Notes,
Role: RoleUser, // Always create as regular user, never admin
Balance: input.Balance,
@@ -310,9 +308,6 @@ func (s *adminServiceImpl) UpdateUser(ctx context.Context, id int64, input *Upda
if input.Username != nil {
user.Username = *input.Username
}
if input.Wechat != nil {
user.Wechat = *input.Wechat
}
if input.Notes != nil {
user.Notes = *input.Notes
}
@@ -488,6 +483,11 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
subscriptionType = SubscriptionTypeStandard
}
// 限额字段0 和 nil 都表示"无限制"
dailyLimit := normalizeLimit(input.DailyLimitUSD)
weeklyLimit := normalizeLimit(input.WeeklyLimitUSD)
monthlyLimit := normalizeLimit(input.MonthlyLimitUSD)
group := &Group{
Name: input.Name,
Description: input.Description,
@@ -496,9 +496,9 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
IsExclusive: input.IsExclusive,
Status: StatusActive,
SubscriptionType: subscriptionType,
DailyLimitUSD: input.DailyLimitUSD,
WeeklyLimitUSD: input.WeeklyLimitUSD,
MonthlyLimitUSD: input.MonthlyLimitUSD,
DailyLimitUSD: dailyLimit,
WeeklyLimitUSD: weeklyLimit,
MonthlyLimitUSD: monthlyLimit,
}
if err := s.groupRepo.Create(ctx, group); err != nil {
return nil, err
@@ -506,6 +506,14 @@ func (s *adminServiceImpl) CreateGroup(ctx context.Context, input *CreateGroupIn
return group, nil
}
// normalizeLimit maps a missing, zero, or negative limit to nil, which means
// "unlimited". Any other value is returned unchanged (the same pointer).
func normalizeLimit(limit *float64) *float64 {
	switch {
	case limit == nil:
		return nil
	case *limit <= 0:
		return nil
	default:
		return limit
	}
}
func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *UpdateGroupInput) (*Group, error) {
group, err := s.groupRepo.GetByID(ctx, id)
if err != nil {
@@ -535,15 +543,15 @@ func (s *adminServiceImpl) UpdateGroup(ctx context.Context, id int64, input *Upd
if input.SubscriptionType != "" {
group.SubscriptionType = input.SubscriptionType
}
// 限额字段支持设置为nil清除限额或具体值
// 限额字段0 和 nil 都表示"无限制",正数表示具体限额
if input.DailyLimitUSD != nil {
group.DailyLimitUSD = input.DailyLimitUSD
group.DailyLimitUSD = normalizeLimit(input.DailyLimitUSD)
}
if input.WeeklyLimitUSD != nil {
group.WeeklyLimitUSD = input.WeeklyLimitUSD
group.WeeklyLimitUSD = normalizeLimit(input.WeeklyLimitUSD)
}
if input.MonthlyLimitUSD != nil {
group.MonthlyLimitUSD = input.MonthlyLimitUSD
group.MonthlyLimitUSD = normalizeLimit(input.MonthlyLimitUSD)
}
if err := s.groupRepo.Update(ctx, group); err != nil {
@@ -598,6 +606,19 @@ func (s *adminServiceImpl) GetAccount(ctx context.Context, id int64) (*Account,
return s.accountRepo.GetByID(ctx, id)
}
// GetAccountsByIDs loads the accounts with the given IDs through the account
// repository. An empty ids list short-circuits to an empty, non-nil slice
// without touching the repository; repository errors are wrapped and returned.
func (s *adminServiceImpl) GetAccountsByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
	if len(ids) == 0 {
		return make([]*Account, 0), nil
	}

	found, lookupErr := s.accountRepo.GetByIDs(ctx, ids)
	if lookupErr == nil {
		return found, nil
	}
	return nil, fmt.Errorf("failed to get accounts by IDs: %w", lookupErr)
}
func (s *adminServiceImpl) CreateAccount(ctx context.Context, input *CreateAccountInput) (*Account, error) {
account := &Account{
Name: input.Name,

View File

@@ -18,7 +18,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
Email: "user@test.com",
Password: "strong-pass",
Username: "tester",
Wechat: "wx",
Notes: "note",
Balance: 12.5,
Concurrency: 7,
@@ -31,7 +30,6 @@ func TestAdminService_CreateUser_Success(t *testing.T) {
require.Equal(t, int64(10), user.ID)
require.Equal(t, input.Email, user.Email)
require.Equal(t, input.Username, user.Username)
require.Equal(t, input.Wechat, user.Wechat)
require.Equal(t, input.Notes, user.Notes)
require.Equal(t, input.Balance, user.Balance)
require.Equal(t, input.Concurrency, user.Concurrency)

View File

@@ -66,7 +66,7 @@ func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationPar
panic("unexpected List call")
}
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error) {
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}

View File

@@ -25,7 +25,7 @@ const (
antigravityRetryMaxDelay = 16 * time.Second
)
// Antigravity 直接支持的模型
// Antigravity 直接支持的模型(精确匹配透传)
var antigravitySupportedModels = map[string]bool{
"claude-opus-4-5-thinking": true,
"claude-sonnet-4-5": true,
@@ -36,23 +36,26 @@ var antigravitySupportedModels = map[string]bool{
"gemini-3-flash": true,
"gemini-3-pro-low": true,
"gemini-3-pro-high": true,
"gemini-3-pro-preview": true,
"gemini-3-pro-image": true,
}
// Antigravity 系统默认模型映射表(不支持 → 支持
var antigravityModelMapping = map[string]string{
"claude-3-5-sonnet-20241022": "claude-sonnet-4-5",
"claude-3-5-sonnet-20240620": "claude-sonnet-4-5",
"claude-sonnet-4-5-20250929": "claude-sonnet-4-5-thinking",
"claude-opus-4": "claude-opus-4-5-thinking",
"claude-opus-4-5-20251101": "claude-opus-4-5-thinking",
"claude-haiku-4": "gemini-3-flash",
"claude-haiku-4-5": "gemini-3-flash",
"claude-3-haiku-20240307": "gemini-3-flash",
"claude-haiku-4-5-20251001": "gemini-3-flash",
// 生图模型:官方名 → Antigravity 内部名
"gemini-3-pro-image-preview": "gemini-3-pro-image",
// antigravityPrefixMapping is the prefix-based model mapping table for
// Antigravity. It absorbs model-name variations such as a -20251111, -thinking
// or -preview suffix.
//
// Entries MUST stay ordered by descending prefix length: the table is scanned
// front to back and the first matching prefix wins, so a longer prefix has to
// appear before any shorter prefix it extends (e.g. "claude-sonnet-4-5" before
// "claude-sonnet-4").
var antigravityPrefixMapping = []struct {
	prefix string
	target string
}{
	{"gemini-3-pro-image", "gemini-3-pro-image"}, // gemini-3-pro-image-preview, etc.
	{"claude-3-5-sonnet", "claude-sonnet-4-5"},   // legacy claude-3-5-sonnet-xxx
	{"claude-sonnet-4-5", "claude-sonnet-4-5"},   // claude-sonnet-4-5-xxx
	{"claude-haiku-4-5", "claude-sonnet-4-5"},    // claude-haiku-4-5-xxx → sonnet
	{"claude-opus-4-5", "claude-opus-4-5-thinking"},
	{"claude-sonnet-4", "claude-sonnet-4-5"},
	{"claude-3-haiku", "claude-sonnet-4-5"}, // legacy claude-3-haiku-xxx → sonnet
	{"claude-haiku-4", "claude-sonnet-4-5"}, // → sonnet
	{"claude-opus-4", "claude-opus-4-5-thinking"},
	{"gemini-3-pro", "gemini-3-pro-high"}, // gemini-3-pro, gemini-3-pro-preview, etc.
}
// AntigravityGatewayService 处理 Antigravity 平台的 API 转发
@@ -84,24 +87,27 @@ func (s *AntigravityGatewayService) GetTokenProvider() *AntigravityTokenProvider
}
// getMappedModel 获取映射后的模型名
// 逻辑:账户映射 → 直接支持透传 → 前缀映射 → gemini透传 → 默认值
func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedModel string) string {
// 1. 优先使用账户级映射(复用现有方法
// 1. 账户级映射(用户自定义优先
if mapped := account.GetMappedModel(requestedModel); mapped != requestedModel {
return mapped
}
// 2. 系统默认映射
if mapped, ok := antigravityModelMapping[requestedModel]; ok {
return mapped
}
// 3. Gemini 模型透传
if strings.HasPrefix(requestedModel, "gemini-") {
// 2. 直接支持的模型透传
if antigravitySupportedModels[requestedModel] {
return requestedModel
}
// 4. Claude 前缀透传直接支持的模型
if antigravitySupportedModels[requestedModel] {
// 3. 前缀映射(处理版本号变化,如 -20251111, -thinking, -preview
for _, pm := range antigravityPrefixMapping {
if strings.HasPrefix(requestedModel, pm.prefix) {
return pm.target
}
}
// 4. Gemini 模型透传(未匹配到前缀的 gemini 模型)
if strings.HasPrefix(requestedModel, "gemini-") {
return requestedModel
}
@@ -110,24 +116,10 @@ func (s *AntigravityGatewayService) getMappedModel(account *Account, requestedMo
}
// IsModelSupported 检查模型是否被支持
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func (s *AntigravityGatewayService) IsModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射)
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// TestConnectionResult 测试连接结果
@@ -358,6 +350,15 @@ func (s *AntigravityGatewayService) Forward(ctx context.Context, c *gin.Context,
return nil, fmt.Errorf("transform request: %w", err)
}
// 调试:记录转换后的请求体(仅记录前 2000 字符)
if bodyJSON, err := json.Marshal(geminiBody); err == nil {
truncated := string(bodyJSON)
if len(truncated) > 2000 {
truncated = truncated[:2000] + "..."
}
log.Printf("[Debug] Transformed Gemini request: %s", truncated)
}
// 构建上游 action
action := "generateContent"
if claudeReq.Stream {
@@ -466,8 +467,19 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
}
switch action {
case "generateContent", "streamGenerateContent", "countTokens":
case "generateContent", "streamGenerateContent":
// ok
case "countTokens":
// 直接返回空值,不透传上游
c.JSON(http.StatusOK, map[string]any{"totalTokens": 0})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(time.Now()),
FirstTokenMs: nil,
}, nil
default:
return nil, s.writeGoogleError(c, http.StatusNotFound, "Unsupported action: "+action)
}
@@ -522,18 +534,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
sleepAntigravityBackoff(attempt)
continue
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
return nil, s.writeGoogleError(c, http.StatusBadGateway, "Upstream request failed after retries")
}
@@ -550,18 +550,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
if resp.StatusCode == 429 {
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
}
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: "",
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
resp = &http.Response{
StatusCode: resp.StatusCode,
Header: resp.Header.Clone(),
@@ -584,19 +572,6 @@ func (s *AntigravityGatewayService) ForwardGemini(ctx context.Context, c *gin.Co
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 2<<20))
s.handleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
if action == "countTokens" {
estimated := estimateGeminiCountTokens(body)
c.JSON(http.StatusOK, map[string]any{"totalTokens": estimated})
return &ForwardResult{
RequestID: requestID,
Usage: ClaudeUsage{},
Model: originalModel,
Stream: false,
Duration: time.Since(startTime),
FirstTokenMs: nil,
}, nil
}
if s.shouldFailoverUpstreamError(resp.StatusCode) {
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}

View File

@@ -104,34 +104,34 @@ func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
expected: "claude-opus-4-5-thinking",
},
{
name: "系统映射 - claude-haiku-4 → gemini-3-flash",
name: "系统映射 - claude-haiku-4 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5 → gemini-3-flash",
name: "系统映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-3-haiku-20240307 → gemini-3-flash",
name: "系统映射 - claude-3-haiku-20240307 → claude-sonnet-4-5",
requestedModel: "claude-3-haiku-20240307",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-haiku-4-5-20251001 → gemini-3-flash",
name: "系统映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "gemini-3-flash",
expected: "claude-sonnet-4-5",
},
{
name: "系统映射 - claude-sonnet-4-5-20250929",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
expected: "claude-sonnet-4-5",
},
// 3. Gemini 透传

View File

@@ -2,6 +2,7 @@ package service
import (
"context"
"fmt"
"time"
)
@@ -28,7 +29,7 @@ func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
}
// NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的10分钟刷新窗口,忽略全局配置
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) {
return false
@@ -37,7 +38,13 @@ func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Durati
if expiresAt == nil {
return false
}
return time.Until(*expiresAt) < antigravityRefreshWindow
timeUntilExpiry := time.Until(*expiresAt)
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
}
return needsRefresh
}
// Refresh 执行 token 刷新

View File

@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
)

View File

@@ -8,7 +8,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/golang-jwt/jwt/v5"
"golang.org/x/crypto/bcrypt"

View File

@@ -9,7 +9,7 @@ import (
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// 错误定义

View File

@@ -18,6 +18,11 @@ type ConcurrencyCache interface {
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
@@ -27,6 +32,12 @@ type ConcurrencyCache interface {
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
}
// generateRequestID generates a unique request ID for concurrency slot tracking
@@ -61,6 +72,18 @@ type AcquireResult struct {
ReleaseFunc func() // Must be called when done (typically via defer)
}
// AccountWithConcurrency pairs an account ID with its concurrency limit.
// It is the input element type of the batch load query (GetAccountsLoadBatch).
type AccountWithConcurrency struct {
ID int64 // account ID
MaxConcurrency int // configured concurrency limit for the account (presumably the denominator of the load rate — confirm in the cache implementation)
}
// AccountLoadInfo is the per-account result of the batch load query
// (GetAccountsLoadBatch returns a map of account ID to *AccountLoadInfo).
// NOTE(review): how LoadRate is derived is not visible here; check the cache implementation.
type AccountLoadInfo struct {
AccountID int64 // account the figures belong to
CurrentConcurrency int // number of slots currently in use
WaitingCount int // number of requests in the account's wait queue
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
@@ -177,6 +200,42 @@ func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int6
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account,
// passing maxWait through to the cache, and returns the cache's verdict.
//
// The method fails open: with no cache configured, or when the cache call
// fails (the failure is logged), it returns true with a nil error.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
	if s.cache == nil {
		return true, nil
	}

	admitted, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
	if err == nil {
		return admitted, nil
	}

	// A cache failure must not block the request; log it and let the caller proceed.
	log.Printf("Warning: increment wait count failed for account %d: %v", accountID, err)
	return true, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
// It is a no-op without a cache, and a cache failure is only logged.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
	if s.cache == nil {
		return
	}

	// The caller's ctx is not used: the decrement runs on a fresh background
	// context with a 5-second timeout (presumably so it still happens after
	// the request context has been cancelled — confirm with callers).
	releaseCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	err := s.cache.DecrementAccountWaitCount(releaseCtx, accountID)
	if err != nil {
		log.Printf("Warning: decrement wait count failed for account %d: %v", accountID, err)
	}
}
// GetAccountWaitingCount returns the current wait queue count for an account
// as reported by the cache. Without a cache it reports 0 and no error.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	if s.cache != nil {
		return s.cache.GetAccountWaitingCount(ctx, accountID)
	}
	return 0, nil
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
@@ -186,6 +245,57 @@ func CalculateMaxWait(userConcurrency int) int {
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts by delegating
// to the cache. Without a cache it returns an empty, non-nil map.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	if s.cache != nil {
		return s.cache.GetAccountsLoadBatch(ctx, accounts)
	}
	return make(map[int64]*AccountLoadInfo), nil
}
// CleanupExpiredAccountSlots removes expired slots for one account by
// delegating to the cache (used by the background cleanup task).
// Without a cache it does nothing and returns nil.
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
	if s.cache != nil {
		return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
	}
	return nil
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
//
// It is a no-op when the receiver, its cache, or accountRepo is nil, or when
// interval is not positive. Otherwise it launches a goroutine that runs one
// cleanup pass immediately and then one pass per tick of interval.
//
// NOTE(review): the goroutine has no stop signal. The loop ranges over
// ticker.C, which is never closed, so the worker runs for the life of the
// process and the deferred ticker.Stop is in practice never reached. Confirm
// this is intended, or add a context/stop channel.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
// runCleanup performs a single pass: list the schedulable accounts, then
// clean each one's expired slots. Failures are logged and never abort the worker.
runCleanup := func() {
// The listing gets its own 5-second budget.
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
log.Printf("Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
// Each account gets an independent 2-second budget; cancel is called
// explicitly rather than deferred so contexts are released per iteration.
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
log.Printf("Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
// Run once right away instead of waiting for the first tick.
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {

View File

@@ -91,6 +91,9 @@ const (
// 管理员 API Key
SettingKeyAdminApiKey = "admin_api_key" // 全局管理员 API Key用于外部系统集成
// Gemini 配额策略JSON
SettingKeyGeminiQuotaPolicy = "gemini_quota_policy"
)
// Admin API Key prefix (distinct from user "sk-" keys)

View File

@@ -10,7 +10,7 @@ import (
"strconv"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (

View File

@@ -32,6 +32,16 @@ func (m *mockAccountRepoForPlatform) GetByID(ctx context.Context, id int64) (*Ac
return nil, errors.New("account not found")
}
// GetByIDs returns the mock's accounts for the requested IDs, in request
// order. IDs that are not present in accountsByID are skipped; the result is
// a nil slice when nothing matches. It never returns an error.
func (m *mockAccountRepoForPlatform) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
	var found []*Account
	for _, id := range ids {
		acc, known := m.accountsByID[id]
		if !known {
			continue
		}
		found = append(found, acc)
	}
	return found, nil
}
func (m *mockAccountRepoForPlatform) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil {
return false, nil
@@ -261,6 +271,34 @@ func TestGatewayService_SelectAccountForModelWithPlatform_PriorityAndLastUsed(t
require.Equal(t, int64(2), acc.ID, "同优先级应选择最久未用的账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference
// checks the tie-break between two Gemini accounts that share the same
// priority and have no last-used time: the OAuth account (ID 2) is expected to
// be selected over the API-key account (ID 1).
func TestGatewayService_SelectAccountForModelWithPlatform_GeminiOAuthPreference(t *testing.T) {
ctx := context.Background()
// Two otherwise identical Gemini accounts that differ only in Type.
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
// Index the accounts by ID, pointing into the slice so both views share state.
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountForModelWithPlatform(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
}
// TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts 测试无可用账户
func TestGatewayService_SelectAccountForModelWithPlatform_NoAvailableAccounts(t *testing.T) {
ctx := context.Background()
@@ -576,6 +614,32 @@ func TestGatewayService_isModelSupportedByAccount(t *testing.T) {
func TestGatewayService_selectAccountWithMixedScheduling(t *testing.T) {
ctx := context.Background()
t.Run("混合调度-Gemini优先选择OAuth账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeApiKey},
{ID: 2, Platform: PlatformGemini, Priority: 1, Status: StatusActive, Schedulable: true, Type: AccountTypeOAuth},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: testConfig(),
}
acc, err := svc.selectAccountWithMixedScheduling(ctx, nil, "", "gemini-2.5-pro", nil, PlatformGemini)
require.NoError(t, err)
require.NotNil(t, acc)
require.Equal(t, int64(2), acc.ID, "同优先级且未使用时应优先选择OAuth账户")
})
t.Run("混合调度-包含启用mixed_scheduling的antigravity账户", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
@@ -783,3 +847,160 @@ func TestAccount_IsMixedSchedulingEnabled(t *testing.T) {
})
}
}
// mockConcurrencyService is an in-memory test double that answers concurrency
// queries from preset maps keyed by account ID.
type mockConcurrencyService struct {
accountLoads map[int64]*AccountLoadInfo // preset results served by GetAccountsLoadBatch
accountWaitCounts map[int64]int // preset results served by GetAccountWaitingCount
acquireResults map[int64]bool // NOTE(review): not read by the methods visible here; presumably preset slot-acquire outcomes — confirm
}
// GetAccountsLoadBatch serves preset load info for the requested accounts.
//
// When no presets are configured at all, the result is an empty map. Otherwise
// every requested account gets an entry: its preset if one exists, or a
// zero-load AccountLoadInfo carrying just the account ID.
func (m *mockConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
	loads := make(map[int64]*AccountLoadInfo)
	if m.accountLoads == nil {
		return loads, nil
	}

	for _, acc := range accounts {
		info, preset := m.accountLoads[acc.ID]
		if !preset {
			// Unknown account: report it as idle.
			info = &AccountLoadInfo{AccountID: acc.ID}
		}
		loads[acc.ID] = info
	}
	return loads, nil
}
// GetAccountWaitingCount returns the canned waiting-queue length for
// accountID, or 0 when none was configured. It never returns an error.
func (m *mockConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
	// Reading a missing key (or a nil map) yields the zero value, so no
	// explicit nil check is needed.
	waiting := m.accountWaitCounts[accountID]
	return waiting, nil
}
// TestGatewayService_SelectAccountWithLoadAwareness tests load-aware account selection
func TestGatewayService_SelectAccountWithLoadAwareness(t *testing.T) {
ctx := context.Background()
t.Run("禁用负载批量查询-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil, // No concurrency service
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(1), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("无ConcurrencyService-降级到传统选择", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = true
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "应选择优先级最高的账号")
})
t.Run("排除账号-不选择被排除的账号", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{
{ID: 1, Platform: PlatformAnthropic, Priority: 1, Status: StatusActive, Schedulable: true, Concurrency: 5},
{ID: 2, Platform: PlatformAnthropic, Priority: 2, Status: StatusActive, Schedulable: true, Concurrency: 5},
},
accountsByID: map[int64]*Account{},
}
for i := range repo.accounts {
repo.accountsByID[repo.accounts[i].ID] = &repo.accounts[i]
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
excludedIDs := map[int64]struct{}{1: {}}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", excludedIDs)
require.NoError(t, err)
require.NotNil(t, result)
require.NotNil(t, result.Account)
require.Equal(t, int64(2), result.Account.ID, "不应选择被排除的账号")
})
t.Run("无可用账号-返回错误", func(t *testing.T) {
repo := &mockAccountRepoForPlatform{
accounts: []Account{},
accountsByID: map[int64]*Account{},
}
cache := &mockGatewayCacheForPlatform{}
cfg := testConfig()
cfg.Gateway.Scheduling.LoadBatchEnabled = false
svc := &GatewayService{
accountRepo: repo,
cache: cache,
cfg: cfg,
concurrencyService: nil,
}
result, err := svc.SelectAccountWithLoadAwareness(ctx, nil, "", "claude-3-5-sonnet-20241022", nil)
require.Error(t, err)
require.Nil(t, result)
require.Contains(t, err.Error(), "no available accounts")
})
}

View File

@@ -13,6 +13,7 @@ import (
"log"
"net/http"
"regexp"
"sort"
"strings"
"time"
@@ -21,6 +22,7 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/Wei-Shaw/sub2api/internal/util/responseheaders"
"github.com/Wei-Shaw/sub2api/internal/util/urlvalidator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
@@ -68,6 +70,20 @@ type GatewayCache interface {
RefreshSessionTTL(ctx context.Context, sessionHash string, ttl time.Duration) error
}
// AccountWaitPlan carries the parameters for queueing on an account whose
// concurrency slot could not be acquired immediately. It is produced by
// SelectAccountWithLoadAwareness.
type AccountWaitPlan struct {
// AccountID is the account to wait on.
AccountID int64
// MaxConcurrency is copied from the account's Concurrency setting.
MaxConcurrency int
// Timeout is filled from the scheduling config (sticky or fallback wait timeout).
Timeout time.Duration
// MaxWaiting is filled from the scheduling config (sticky or fallback max waiting).
MaxWaiting int
}
// AccountSelectionResult is the outcome of SelectAccountWithLoadAwareness:
// the chosen account plus either an acquired slot or a plan for waiting.
type AccountSelectionResult struct {
// Account is the selected account.
Account *Account
// Acquired reports whether a concurrency slot was obtained during selection.
Acquired bool
// ReleaseFunc is set together with Acquired; it comes from the slot acquisition result.
ReleaseFunc func()
WaitPlan *AccountWaitPlan // nil means no wait allowed
}
// ClaudeUsage 表示Claude API返回的usage信息
type ClaudeUsage struct {
InputTokens int `json:"input_tokens"`
@@ -110,6 +126,7 @@ type GatewayService struct {
identityService *IdentityService
httpUpstream HTTPUpstream
deferredService *DeferredService
concurrencyService *ConcurrencyService
}
// NewGatewayService creates a new GatewayService
@@ -121,6 +138,7 @@ func NewGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
@@ -136,6 +154,7 @@ func NewGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
@@ -185,6 +204,14 @@ func (s *GatewayService) GenerateSessionHash(parsed *ParsedRequest) string {
return ""
}
// BindStickySession stores the session -> account binding in the cache using
// the standard sticky-session TTL. An empty session hash or a non-positive
// account ID is a no-op that returns nil.
func (s *GatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
	if bindable := sessionHash != "" && accountID > 0; !bindable {
		return nil
	}
	return s.cache.SetSessionAccountID(ctx, sessionHash, accountID, stickySessionTTL)
}
func (s *GatewayService) extractCacheableContent(parsed *ParsedRequest) string {
if parsed == nil {
return ""
@@ -334,8 +361,354 @@ func (s *GatewayService) SelectAccountForModelWithExclusions(ctx context.Context
return s.selectAccountForModelWithPlatform(ctx, groupID, sessionHash, requestedModel, excludedIDs, platform)
}
// SelectAccountWithLoadAwareness selects an account for the request and tries
// to acquire a concurrency slot on it.
//
// The result either has Acquired == true (with ReleaseFunc set) or carries a
// WaitPlan describing how the caller may queue on the chosen account.
//
// When no concurrency service is configured, or load batching is disabled,
// selection is delegated to SelectAccountForModelWithExclusions. Otherwise
// three layers are tried in order:
//
//  1. the sticky-session account, if it is still eligible;
//  2. the eligible accounts ordered by priority, load rate and last use;
//  3. a fallback wait plan on the best candidate by priority and last use.
//
// It returns an error when platform resolution or account listing fails, or
// when no eligible account exists ("no available accounts").
func (s *GatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
	cfg := s.schedulingConfig()

	// Look up the sticky binding once; it is reused by both paths below.
	// A positive stickyAccountID implies sessionHash != "" and s.cache != nil.
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash); err == nil {
			stickyAccountID = accountID
		}
	}

	acquired := func(account *Account, slot *AcquireResult) *AccountSelectionResult {
		return &AccountSelectionResult{
			Account:     account,
			Acquired:    true,
			ReleaseFunc: slot.ReleaseFunc,
		}
	}
	waitOn := func(account *Account, timeout time.Duration, maxWaiting int) *AccountSelectionResult {
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      account.ID,
				MaxConcurrency: account.Concurrency,
				Timeout:        timeout,
				MaxWaiting:     maxWaiting,
			},
		}
	}

	// Legacy path: no concurrency service, or load batching disabled.
	if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
		account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		if err != nil {
			return nil, err
		}
		slot, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
		if err == nil && slot != nil && slot.Acquired {
			return acquired(account, slot), nil
		}
		if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
			// Best effort: a failed lookup leaves the count at its zero value.
			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
			if waitingCount < cfg.StickySessionMaxWaiting {
				return waitOn(account, cfg.StickySessionWaitTimeout, cfg.StickySessionMaxWaiting), nil
			}
		}
		return waitOn(account, cfg.FallbackWaitTimeout, cfg.FallbackMaxWaiting), nil
	}

	platform, hasForcePlatform, err := s.resolvePlatform(ctx, groupID)
	if err != nil {
		return nil, err
	}
	preferOAuth := platform == PlatformGemini

	accounts, useMixed, err := s.listSchedulableAccounts(ctx, groupID, platform, hasForcePlatform)
	if err != nil {
		return nil, err
	}
	if len(accounts) == 0 {
		return nil, errors.New("no available accounts")
	}

	isExcluded := func(accountID int64) bool {
		// Lookup on a nil map is safe and reports "not present".
		_, excluded := excludedIDs[accountID]
		return excluded
	}

	// ============ Layer 1: sticky session first ============
	if stickyAccountID > 0 && !isExcluded(stickyAccountID) {
		account, err := s.accountRepo.GetByID(ctx, stickyAccountID)
		if err == nil && s.isAccountAllowedForPlatform(account, platform, useMixed) &&
			account.IsSchedulable() &&
			(requestedModel == "" || s.isModelSupportedByAccount(account, requestedModel)) {
			slot, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, account.Concurrency)
			if err == nil && slot != nil && slot.Acquired {
				_ = s.cache.RefreshSessionTTL(ctx, sessionHash, stickySessionTTL) // best effort
				return acquired(account, slot), nil
			}

			// Slot busy: queue on the sticky account unless its queue is full,
			// in which case fall through to load-aware selection.
			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
			if waitingCount < cfg.StickySessionMaxWaiting {
				return waitOn(account, cfg.StickySessionWaitTimeout, cfg.StickySessionMaxWaiting), nil
			}
		}
	}

	// ============ Layer 2: load-aware selection ============
	candidates := make([]*Account, 0, len(accounts))
	for i := range accounts {
		acc := &accounts[i]
		if isExcluded(acc.ID) {
			continue
		}
		if !s.isAccountAllowedForPlatform(acc, platform, useMixed) {
			continue
		}
		if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
			continue
		}
		candidates = append(candidates, acc)
	}
	if len(candidates) == 0 {
		return nil, errors.New("no available accounts")
	}

	accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
	for _, acc := range candidates {
		accountLoads = append(accountLoads, AccountWithConcurrency{
			ID:             acc.ID,
			MaxConcurrency: acc.Concurrency,
		})
	}

	loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
	if err != nil {
		// Load data unavailable: try candidates in priority / last-used order.
		if result, ok := s.tryAcquireByLegacyOrder(ctx, candidates, sessionHash, preferOAuth); ok {
			return result, nil
		}
	} else {
		type accountWithLoad struct {
			account  *Account
			loadInfo *AccountLoadInfo
		}
		var available []accountWithLoad
		for _, acc := range candidates {
			loadInfo := loadMap[acc.ID]
			if loadInfo == nil {
				loadInfo = &AccountLoadInfo{AccountID: acc.ID}
			}
			// Accounts at or above a load rate of 100 are skipped in this layer.
			if loadInfo.LoadRate < 100 {
				available = append(available, accountWithLoad{account: acc, loadInfo: loadInfo})
			}
		}

		// Order: priority, then load rate, then never-used first, then least
		// recently used; OAuth breaks ties among never-used accounts when
		// preferOAuth is set.
		sort.SliceStable(available, func(i, j int) bool {
			a, b := available[i], available[j]
			if a.account.Priority != b.account.Priority {
				return a.account.Priority < b.account.Priority
			}
			if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
				return a.loadInfo.LoadRate < b.loadInfo.LoadRate
			}
			switch {
			case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
				return true
			case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
				return false
			case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
				if preferOAuth && a.account.Type != b.account.Type {
					return a.account.Type == AccountTypeOAuth
				}
				return false
			default:
				return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
			}
		})

		for _, item := range available {
			slot, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
			if err != nil || slot == nil || !slot.Acquired {
				continue
			}
			if sessionHash != "" && s.cache != nil {
				_ = s.cache.SetSessionAccountID(ctx, sessionHash, item.account.ID, stickySessionTTL) // best effort
			}
			return acquired(item.account, slot), nil
		}
	}

	// ============ Layer 3: fallback queueing ============
	// candidates is non-empty here (checked above), so index 0 is safe.
	sortAccountsByPriorityAndLastUsed(candidates, preferOAuth)
	return waitOn(candidates[0], cfg.FallbackWaitTimeout, cfg.FallbackMaxWaiting), nil
}
// tryAcquireByLegacyOrder walks the candidates in priority / last-used order
// and returns the first one whose concurrency slot can be acquired. The
// caller's slice is not reordered; a copy is sorted instead.
//
// On success the session (if any) is bound to the chosen account on a
// best-effort basis and ok is true. When no slot can be acquired it returns
// (nil, false).
func (s *GatewayService) tryAcquireByLegacyOrder(ctx context.Context, candidates []*Account, sessionHash string, preferOAuth bool) (*AccountSelectionResult, bool) {
	ordered := append([]*Account(nil), candidates...)
	sortAccountsByPriorityAndLastUsed(ordered, preferOAuth)

	for _, acc := range ordered {
		slot, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
		if err != nil || slot == nil || !slot.Acquired {
			continue
		}
		// The cache is optional; skip the binding rather than panic on nil.
		if sessionHash != "" && s.cache != nil {
			_ = s.cache.SetSessionAccountID(ctx, sessionHash, acc.ID, stickySessionTTL) // best effort
		}
		return &AccountSelectionResult{
			Account:     acc,
			Acquired:    true,
			ReleaseFunc: slot.ReleaseFunc,
		}, true
	}
	return nil, false
}
// schedulingConfig returns the scheduling section of the service config, or
// a built-in set of defaults when the service has no config.
func (s *GatewayService) schedulingConfig() config.GatewaySchedulingConfig {
	if s.cfg == nil {
		defaults := config.GatewaySchedulingConfig{
			StickySessionMaxWaiting:  3,
			StickySessionWaitTimeout: 45 * time.Second,
			FallbackWaitTimeout:      30 * time.Second,
			FallbackMaxWaiting:       100,
			LoadBatchEnabled:         true,
			SlotCleanupInterval:      30 * time.Second,
		}
		return defaults
	}
	return s.cfg.Gateway.Scheduling
}
// resolvePlatform determines which platform a request should be scheduled on.
//
// A non-empty ctxkey.ForcePlatform string in ctx wins and is reported with
// forced == true. Otherwise the platform of the given group is used, and
// without a group the default is PlatformAnthropic. A group lookup failure is
// returned wrapped.
func (s *GatewayService) resolvePlatform(ctx context.Context, groupID *int64) (string, bool, error) {
	if forced, ok := ctx.Value(ctxkey.ForcePlatform).(string); ok && forced != "" {
		return forced, true, nil
	}
	if groupID == nil {
		return PlatformAnthropic, false, nil
	}

	group, err := s.groupRepo.GetByID(ctx, *groupID)
	if err != nil {
		return "", false, fmt.Errorf("get group failed: %w", err)
	}
	return group.Platform, false, nil
}
// listSchedulableAccounts loads the schedulable accounts for a platform and
// reports whether mixed scheduling is in effect.
//
// Mixed scheduling applies to the Anthropic and Gemini platforms when the
// platform was not forced: the native platform is queried together with
// Antigravity, and Antigravity accounts that have not enabled mixed
// scheduling are dropped. Otherwise only the given platform is queried:
// across all groups in simple run mode or when there is no group, and within
// the group otherwise, falling back to all groups when the platform is
// forced and the group has no accounts.
func (s *GatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64, platform string, hasForcePlatform bool) ([]Account, bool, error) {
	nativeSupportsMixed := platform == PlatformAnthropic || platform == PlatformGemini
	useMixed := nativeSupportsMixed && !hasForcePlatform

	var (
		accounts []Account
		err      error
	)

	if !useMixed {
		switch {
		case s.cfg != nil && s.cfg.RunMode == config.RunModeSimple:
			accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
		case groupID != nil:
			accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, platform)
			if err == nil && len(accounts) == 0 && hasForcePlatform {
				accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
			}
		default:
			accounts, err = s.accountRepo.ListSchedulableByPlatform(ctx, platform)
		}
		if err != nil {
			return nil, useMixed, err
		}
		return accounts, useMixed, nil
	}

	platforms := []string{platform, PlatformAntigravity}
	if groupID == nil {
		accounts, err = s.accountRepo.ListSchedulableByPlatforms(ctx, platforms)
	} else {
		accounts, err = s.accountRepo.ListSchedulableByGroupIDAndPlatforms(ctx, *groupID, platforms)
	}
	if err != nil {
		return nil, useMixed, err
	}

	kept := make([]Account, 0, len(accounts))
	for _, candidate := range accounts {
		optedIn := candidate.Platform != PlatformAntigravity || candidate.IsMixedSchedulingEnabled()
		if optedIn {
			kept = append(kept, candidate)
		}
	}
	return kept, useMixed, nil
}
// isAccountAllowedForPlatform reports whether account may serve a request for
// platform. A native-platform account always qualifies; under mixed
// scheduling an Antigravity account qualifies too if it has mixed scheduling
// enabled. A nil account never qualifies.
func (s *GatewayService) isAccountAllowedForPlatform(account *Account, platform string, useMixed bool) bool {
	switch {
	case account == nil:
		return false
	case account.Platform == platform:
		return true
	case !useMixed:
		return false
	default:
		return account.Platform == PlatformAntigravity && account.IsMixedSchedulingEnabled()
	}
}
// tryAcquireAccountSlot asks the concurrency service for a slot on the
// account. Without a concurrency service every request is treated as
// acquired, with a no-op release function.
func (s *GatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
	if s.concurrencyService != nil {
		return s.concurrencyService.AcquireAccountSlot(ctx, accountID, maxConcurrency)
	}
	noop := func() {}
	return &AcquireResult{Acquired: true, ReleaseFunc: noop}, nil
}
// sortAccountsByPriorityAndLastUsed stably sorts accounts in place:
// lower Priority value first, then never-used accounts before used ones,
// then least recently used first. Among never-used accounts of equal
// priority, OAuth accounts go first when preferOAuth is set.
func sortAccountsByPriorityAndLastUsed(accounts []*Account, preferOAuth bool) {
	sort.SliceStable(accounts, func(i, j int) bool {
		left, right := accounts[i], accounts[j]
		if left.Priority != right.Priority {
			return left.Priority < right.Priority
		}

		leftUsed, rightUsed := left.LastUsedAt, right.LastUsedAt
		if leftUsed != nil && rightUsed != nil {
			return leftUsed.Before(*rightUsed)
		}
		if leftUsed == nil && rightUsed == nil {
			// Both never used: only the OAuth preference can order them.
			return preferOAuth && left.Type != right.Type && left.Type == AccountTypeOAuth
		}
		// Exactly one was never used; it sorts first.
		return leftUsed == nil
	})
}
// selectAccountForModelWithPlatform 选择单平台账户(完全隔离)
func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, platform string) (*Account, error) {
preferOAuth := platform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
accountID, err := s.cache.GetSessionAccountID(ctx, sessionHash)
@@ -391,7 +764,9 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
if preferOAuth && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
@@ -421,6 +796,7 @@ func (s *GatewayService) selectAccountForModelWithPlatform(ctx context.Context,
// 查询原生平台账户 + 启用 mixed_scheduling 的 antigravity 账户
func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}, nativePlatform string) (*Account, error) {
platforms := []string{nativePlatform, PlatformAntigravity}
preferOAuth := nativePlatform == PlatformGemini
// 1. 查询粘性会话
if sessionHash != "" {
@@ -480,7 +856,9 @@ func (s *GatewayService) selectAccountWithMixedScheduling(ctx context.Context, g
case acc.LastUsedAt != nil && selected.LastUsedAt == nil:
// keep selected (never used is preferred)
case acc.LastUsedAt == nil && selected.LastUsedAt == nil:
// keep selected (both never used)
if preferOAuth && acc.Platform == PlatformGemini && selected.Platform == PlatformGemini && acc.Type != selected.Type && acc.Type == AccountTypeOAuth {
selected = acc
}
default:
if acc.LastUsedAt.Before(*selected.LastUsedAt) {
selected = acc
@@ -517,24 +895,10 @@ func (s *GatewayService) isModelSupportedByAccount(account *Account, requestedMo
}
// IsAntigravityModelSupported 检查 Antigravity 平台是否支持指定模型
// 所有 claude- 和 gemini- 前缀的模型都能通过映射或透传支持
func IsAntigravityModelSupported(requestedModel string) bool {
// 直接支持的模型
if antigravitySupportedModels[requestedModel] {
return true
}
// 可映射的模型
if _, ok := antigravityModelMapping[requestedModel]; ok {
return true
}
// Gemini 前缀透传
if strings.HasPrefix(requestedModel, "gemini-") {
return true
}
// Claude 模型支持(通过默认映射到 claude-sonnet-4-5
if strings.HasPrefix(requestedModel, "claude-") {
return true
}
return false
return strings.HasPrefix(requestedModel, "claude-") ||
strings.HasPrefix(requestedModel, "gemini-")
}
// GetAccessToken 获取账号凭证
@@ -686,6 +1050,30 @@ func (s *GatewayService) Forward(ctx context.Context, c *gin.Context, account *A
// 处理错误响应(不可重试的错误)
if resp.StatusCode >= 400 {
// 可选:对部分 400 触发 failover(默认关闭以保持语义)
if resp.StatusCode == 400 && s.cfg != nil && s.cfg.Gateway.FailoverOn400 {
respBody, readErr := io.ReadAll(resp.Body)
if readErr != nil {
// ReadAll failed, fall back to normal error handling without consuming the stream
return s.handleErrorResponse(ctx, resp, c, account)
}
_ = resp.Body.Close()
resp.Body = io.NopCloser(bytes.NewReader(respBody))
if s.shouldFailoverOn400(respBody) {
if s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Account %d: 400 error, attempting failover: %s",
account.ID,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
} else {
log.Printf("Account %d: 400 error, attempting failover", account.ID)
}
s.handleFailoverSideEffects(ctx, resp, account)
return nil, &UpstreamFailoverError{StatusCode: resp.StatusCode}
}
}
return s.handleErrorResponse(ctx, resp, c, account)
}
@@ -794,6 +1182,13 @@ func (s *GatewayService) buildUpstreamRequest(ctx context.Context, c *gin.Contex
// 处理 anthropic-beta header(OAuth 账号需要特殊处理)
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key仅在请求显式使用 beta 特性且客户端未提供时,按需补齐(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
@@ -846,6 +1241,83 @@ func (s *GatewayService) getBetaHeader(modelID string, clientBetaHeader string)
return claude.DefaultBetaHeader
}
// requestNeedsBetaFeatures reports whether the JSON request body has a
// non-empty "tools" array or has "thinking.type" equal to "enabled"
// (compared case-insensitively).
func requestNeedsBetaFeatures(body []byte) bool {
	if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
		return true
	}
	thinkingType := gjson.GetBytes(body, "thinking.type").String()
	return strings.EqualFold(thinkingType, "enabled")
}
// defaultApiKeyBetaHeader picks the anthropic-beta header value for an
// API-key request based on the "model" field of the JSON body: the Haiku
// variant when the model name contains "haiku" (case-insensitive), otherwise
// the general API-key value.
func defaultApiKeyBetaHeader(body []byte) string {
	model := strings.ToLower(gjson.GetBytes(body, "model").String())
	if !strings.Contains(model, "haiku") {
		return claude.ApiKeyBetaHeader
	}
	return claude.ApiKeyHaikuBetaHeader
}
// truncateForLog renders b as a single-line string of at most maxBytes bytes
// for logging. A non-positive maxBytes selects the default limit of 2048.
//
// Truncation never splits a UTF-8 encoded rune: if the cut would land inside
// a multi-byte sequence, the whole partial rune is dropped, so the result may
// be up to three bytes shorter than maxBytes. Line feeds and carriage returns
// are replaced by the two-character escapes `\n` and `\r`, which are added
// after truncation and are not counted against the limit.
func truncateForLog(b []byte, maxBytes int) string {
	if maxBytes <= 0 {
		maxBytes = 2048
	}
	if len(b) > maxBytes {
		cut := maxBytes
		// b[cut] being a continuation byte (10xxxxxx) means the cut falls
		// inside a rune. Step back to that rune's lead byte so it is excluded
		// whole. A rune has at most three continuation bytes, so the walk is
		// bounded even for malformed input.
		for back := 0; back < 3 && cut > 0 && b[cut]&0xC0 == 0x80; back++ {
			cut--
		}
		b = b[:cut]
	}
	s := string(b)
	// Keep the entry on one line so it cannot break the log format.
	s = strings.ReplaceAll(s, "\n", "\\n")
	s = strings.ReplaceAll(s, "\r", "\\r")
	return s
}
// shouldFailoverOn400 decides whether an upstream 400 response should trigger
// a switch to another account. It is conservative: failover is allowed only
// when the lower-cased upstream error message contains one of a fixed set of
// markers for likely compatibility problems; an empty or unrecognized message
// never triggers failover.
func (s *GatewayService) shouldFailoverOn400(respBody []byte) bool {
	msg := strings.ToLower(strings.TrimSpace(extractUpstreamErrorMessage(respBody)))
	if msg == "" {
		return false
	}

	markers := [...]string{
		// Missing or wrong beta header: another account or route may succeed,
		// particularly under mixed scheduling.
		"anthropic-beta",
		"beta feature",
		"requires beta",
		// Thinking / signature constraints, common in conversion chains.
		"thinking",
		"thought_signature",
		"signature",
		// Tool-related constraints.
		"tool_use",
		"tool_result",
		"tools",
	}
	for _, marker := range markers {
		if strings.Contains(msg, marker) {
			return true
		}
	}
	return false
}
// extractUpstreamErrorMessage pulls the human-readable message out of an
// upstream JSON error body.
//
// It prefers the Claude-style path "error.message". If that value is itself a
// JSON object serialized as a string, the nested "error.message" is returned
// when it is non-blank. When "error.message" is blank or absent, the
// top-level "message" field is returned (possibly empty).
func extractUpstreamErrorMessage(body []byte) string {
	message := gjson.GetBytes(body, "error.message").String()
	if strings.TrimSpace(message) == "" {
		// Fallback: top-level message.
		return gjson.GetBytes(body, "message").String()
	}

	// Some upstreams put a complete JSON document inside the message string.
	if trimmed := strings.TrimSpace(message); strings.HasPrefix(trimmed, "{") {
		nested := gjson.Get(trimmed, "error.message").String()
		if strings.TrimSpace(nested) != "" {
			return nested
		}
	}
	return message
}
func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Response, c *gin.Context, account *Account) (*ForwardResult, error) {
body, _ := io.ReadAll(resp.Body)
@@ -858,6 +1330,16 @@ func (s *GatewayService) handleErrorResponse(ctx context.Context, resp *http.Res
switch resp.StatusCode {
case 400:
// 仅记录上游错误摘要(避免输出请求内容);需要时可通过配置打开
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"Upstream 400 error (account=%d platform=%s type=%s): %s",
account.ID,
account.Platform,
account.Type,
truncateForLog(body, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
c.Data(http.StatusBadRequest, "application/json", body)
return nil, fmt.Errorf("upstream error: %d", resp.StatusCode)
case 401:
@@ -1271,10 +1753,9 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
body := parsed.Body
reqModel := parsed.Model
// Antigravity 账户不支持 count_tokens 转发,返回估算
// 参考 Antigravity-Manager 和 proxycast 实现
// Antigravity 账户不支持 count_tokens 转发,直接返回空
if account.Platform == PlatformAntigravity {
c.JSON(http.StatusOK, gin.H{"input_tokens": 100})
c.JSON(http.StatusOK, gin.H{"input_tokens": 0})
return nil
}
@@ -1332,6 +1813,18 @@ func (s *GatewayService) ForwardCountTokens(ctx context.Context, c *gin.Context,
// 标记账号状态429/529等
s.rateLimitService.HandleUpstreamError(ctx, account, resp.StatusCode, resp.Header, respBody)
// 记录上游错误摘要便于排障(不回显请求内容)
if s.cfg != nil && s.cfg.Gateway.LogUpstreamErrorBody {
log.Printf(
"count_tokens upstream error %d (account=%d platform=%s type=%s): %s",
resp.StatusCode,
account.ID,
account.Platform,
account.Type,
truncateForLog(respBody, s.cfg.Gateway.LogUpstreamErrorBodyMaxBytes),
)
}
// 返回简化的错误响应
errMsg := "Upstream request failed"
switch resp.StatusCode {
@@ -1418,6 +1911,13 @@ func (s *GatewayService) buildCountTokensRequest(ctx context.Context, c *gin.Con
// OAuth 账号:处理 anthropic-beta header
if tokenType == "oauth" {
req.Header.Set("anthropic-beta", s.getBetaHeader(modelID, c.GetHeader("anthropic-beta")))
} else if s.cfg != nil && s.cfg.Gateway.InjectBetaForApiKey && req.Header.Get("anthropic-beta") == "" {
// API-key与 messages 同步的按需 beta 注入(默认关闭)
if requestNeedsBetaFeatures(body) {
if beta := defaultApiKeyBetaHeader(body); beta != "" {
req.Header.Set("anthropic-beta", beta)
}
}
}
return req, nil
@@ -1445,3 +1945,58 @@ func (s *GatewayService) validateUpstreamBaseURL(raw string) (string, error) {
}
return normalized, nil
}
// GetAvailableModels returns the models available to a group (or to all
// schedulable accounts when groupID is nil), restricted to platform when it
// is non-empty.
//
// The list is the union of the model_mapping keys of the matching accounts,
// sorted ascending so repeated calls return the same order. It returns nil
// when the account lookup fails, when there are no matching accounts, or when
// none of them defines a model mapping; per the original comment, nil tells
// the caller to use its default list.
func (s *GatewayService) GetAvailableModels(ctx context.Context, groupID *int64, platform string) []string {
	var accounts []Account
	var err error
	if groupID != nil {
		accounts, err = s.accountRepo.ListSchedulableByGroupID(ctx, *groupID)
	} else {
		accounts, err = s.accountRepo.ListSchedulable(ctx)
	}
	if err != nil || len(accounts) == 0 {
		return nil
	}

	// Collect unique mapped models, filtering by platform in the same pass.
	modelSet := make(map[string]struct{})
	for i := range accounts {
		acc := &accounts[i]
		if platform != "" && acc.Platform != platform {
			continue
		}
		for model := range acc.GetModelMapping() {
			modelSet[model] = struct{}{}
		}
	}

	// No account has a model mapping: return nil (use default).
	if len(modelSet) == 0 {
		return nil
	}

	models := make([]string, 0, len(modelSet))
	for model := range modelSet {
		models = append(models, model)
	}
	// Map iteration order is random; sort for a stable result.
	sort.Strings(models)
	return models
}

View File

@@ -122,8 +122,20 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
valid = true
}
if valid {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
usable := true
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, account, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", account.ID, err)
}
if !ok {
usable = false
}
}
if usable {
_ = s.cache.RefreshSessionTTL(ctx, cacheKey, geminiStickySessionTTL)
return account, nil
}
}
}
}
@@ -163,6 +175,15 @@ func (s *GeminiMessagesCompatService) SelectAccountForModelWithExclusions(ctx co
if requestedModel != "" && !s.isModelSupportedByAccount(acc, requestedModel) {
continue
}
if s.rateLimitService != nil && requestedModel != "" {
ok, err := s.rateLimitService.PreCheckUsage(ctx, acc, requestedModel)
if err != nil {
log.Printf("[Gemini PreCheck] Account %d precheck error: %v", acc.ID, err)
}
if !ok {
continue
}
}
if selected == nil {
selected = acc
continue
@@ -1939,13 +1960,44 @@ func (s *GeminiMessagesCompatService) handleGeminiUpstreamError(ctx context.Cont
if statusCode != 429 {
return
}
oauthType := account.GeminiOAuthType()
tierID := account.GeminiTierID()
projectID := strings.TrimSpace(account.GetCredential("project_id"))
isCodeAssist := account.IsGeminiCodeAssist()
resetAt := ParseGeminiRateLimitResetTime(body)
if resetAt == nil {
ra := time.Now().Add(5 * time.Minute)
// 根据账号类型使用不同的默认重置时间
var ra time.Time
if isCodeAssist {
// Code Assist: fallback cooldown by tier
cooldown := geminiCooldownForTier(tierID)
if s.rateLimitService != nil {
cooldown = s.rateLimitService.GeminiCooldown(ctx, account)
}
ra = time.Now().Add(cooldown)
log.Printf("[Gemini 429] Account %d (Code Assist, tier=%s, project=%s) rate limited, cooldown=%v", account.ID, tierID, projectID, time.Until(ra).Truncate(time.Second))
} else {
// API Key / AI Studio OAuth: PST 午夜
if ts := nextGeminiDailyResetUnix(); ts != nil {
ra = time.Unix(*ts, 0)
log.Printf("[Gemini 429] Account %d (API Key/AI Studio, type=%s) rate limited, reset at PST midnight (%v)", account.ID, account.Type, ra)
} else {
// 兜底5 分钟
ra = time.Now().Add(5 * time.Minute)
log.Printf("[Gemini 429] Account %d rate limited, fallback to 5min", account.ID)
}
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, ra)
return
}
_ = s.accountRepo.SetRateLimited(ctx, account.ID, time.Unix(*resetAt, 0))
// 使用解析到的重置时间
resetTime := time.Unix(*resetAt, 0)
_ = s.accountRepo.SetRateLimited(ctx, account.ID, resetTime)
log.Printf("[Gemini 429] Account %d rate limited until %v (oauth_type=%s, tier=%s)",
account.ID, resetTime, oauthType, tierID)
}
// ParseGeminiRateLimitResetTime 解析 Gemini 格式的 429 响应,返回重置时间的 Unix 时间戳
@@ -2001,16 +2053,7 @@ func looksLikeGeminiDailyQuota(message string) bool {
}
func nextGeminiDailyResetUnix() *int64 {
loc, err := time.LoadLocation("America/Los_Angeles")
if err != nil {
// Fallback: PST without DST.
loc = time.FixedZone("PST", -8*3600)
}
now := time.Now().In(loc)
reset := time.Date(now.Year(), now.Month(), now.Day(), 0, 5, 0, 0, loc)
if !reset.After(now) {
reset = reset.Add(24 * time.Hour)
}
reset := geminiDailyResetTime(time.Now())
ts := reset.Unix()
return &ts
}
@@ -2298,16 +2341,46 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
if !ok {
continue
}
name, _ := tm["name"].(string)
desc, _ := tm["description"].(string)
params := tm["input_schema"]
var name, desc string
var params any
// 检查是否为 custom 类型工具 (MCP)
toolType, _ := tm["type"].(string)
if toolType == "custom" {
// Custom 格式: 从 custom 字段获取 description 和 input_schema
custom, ok := tm["custom"].(map[string]any)
if !ok {
continue
}
name, _ = tm["name"].(string)
desc, _ = custom["description"].(string)
params = custom["input_schema"]
} else {
// 标准格式: 从顶层字段获取
name, _ = tm["name"].(string)
desc, _ = tm["description"].(string)
params = tm["input_schema"]
}
if name == "" {
continue
}
// 为 nil params 提供默认值
if params == nil {
params = map[string]any{
"type": "object",
"properties": map[string]any{},
}
}
// 清理 JSON Schema
cleanedParams := cleanToolSchema(params)
funcDecls = append(funcDecls, map[string]any{
"name": name,
"description": desc,
"parameters": params,
"parameters": cleanedParams,
})
}
@@ -2321,6 +2394,41 @@ func convertClaudeToolsToGeminiTools(tools any) []any {
}
}
// cleanToolSchema returns a cleaned copy of a tool's JSON Schema with the
// keywords that this converter strips for Gemini removed.
//
// For every schema object it drops the "$schema", "$id", "$ref",
// "additionalProperties", "minLength", "maxLength", "minItems" and "maxItems"
// keywords and upper-cases a string-valued "type". Objects and arrays are
// cleaned recursively; other values are returned unchanged, and nil yields
// nil. The input is not modified.
//
// The keys of a "properties" object are parameter names, not schema keywords,
// so they are always kept (a parameter may legitimately be called
// "maxLength"); only each property's own schema is cleaned.
func cleanToolSchema(schema any) any {
	if schema == nil {
		return nil
	}

	switch v := schema.(type) {
	case map[string]any:
		cleaned := make(map[string]any, len(v))
		for key, value := range v {
			switch key {
			case "$schema", "$id", "$ref",
				"additionalProperties", "minLength",
				"maxLength", "minItems", "maxItems":
				// Stripped keyword.
				continue
			case "properties":
				if props, ok := value.(map[string]any); ok {
					cleanedProps := make(map[string]any, len(props))
					for propName, propSchema := range props {
						cleanedProps[propName] = cleanToolSchema(propSchema)
					}
					cleaned[key] = cleanedProps
					continue
				}
			}
			// Recurse into nested schemas and values.
			cleaned[key] = cleanToolSchema(value)
		}
		// Normalize the type keyword to upper case.
		if typeVal, ok := cleaned["type"].(string); ok {
			cleaned["type"] = strings.ToUpper(typeVal)
		}
		return cleaned
	case []any:
		cleaned := make([]any, len(v))
		for i, item := range v {
			cleaned[i] = cleanToolSchema(item)
		}
		return cleaned
	default:
		return v
	}
}
func convertClaudeGenerationConfig(req map[string]any) map[string]any {
out := make(map[string]any)
if mt, ok := asInt(req["max_tokens"]); ok && mt > 0 {

View File

@@ -0,0 +1,128 @@
package service
import (
"testing"
)
// TestConvertClaudeToolsToGeminiTools_CustomType verifies that both standard
// tools and "custom" (MCP) tools are converted into Gemini function
// declarations, and that a custom tool lacking its "custom" object is skipped.
//
// expectedLen is only compared against zero: 0 means the conversion must
// return nil, any other value means exactly one tool declaration is expected.
func TestConvertClaudeToolsToGeminiTools_CustomType(t *testing.T) {
	cases := []struct {
		name        string
		tools       any
		expectedLen int
		description string
	}{
		{
			name: "Standard tools",
			tools: []any{
				map[string]any{
					"name":         "get_weather",
					"description":  "Get weather info",
					"input_schema": map[string]any{"type": "object"},
				},
			},
			expectedLen: 1,
			description: "标准工具格式应该正常转换",
		},
		{
			name: "Custom type tool (MCP format)",
			tools: []any{
				map[string]any{
					"type": "custom",
					"name": "mcp_tool",
					"custom": map[string]any{
						"description":  "MCP tool description",
						"input_schema": map[string]any{"type": "object"},
					},
				},
			},
			expectedLen: 1,
			description: "Custom类型工具应该从custom字段读取",
		},
		{
			name: "Mixed standard and custom tools",
			tools: []any{
				map[string]any{
					"name":         "standard_tool",
					"description":  "Standard",
					"input_schema": map[string]any{"type": "object"},
				},
				map[string]any{
					"type": "custom",
					"name": "custom_tool",
					"custom": map[string]any{
						"description":  "Custom",
						"input_schema": map[string]any{"type": "object"},
					},
				},
			},
			expectedLen: 1,
			description: "混合工具应该都能正确转换",
		},
		{
			name: "Custom tool without custom field",
			tools: []any{
				map[string]any{
					"type": "custom",
					"name": "invalid_custom",
					// The "custom" object is intentionally missing.
				},
			},
			expectedLen: 0, // the tool should be skipped
			description: "缺少custom字段的custom工具应该被跳过",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := convertClaudeToolsToGeminiTools(tc.tools)

			// Nothing convertible: the converter must return nil.
			if tc.expectedLen == 0 {
				if got != nil {
					t.Errorf("%s: expected nil result, got %v", tc.description, got)
				}
				return
			}

			if got == nil {
				t.Fatalf("%s: expected non-nil result", tc.description)
			}
			// All function declarations are expected in a single tool entry.
			if len(got) != 1 {
				t.Errorf("%s: expected 1 tool declaration, got %d", tc.description, len(got))
				return
			}

			decl, isMap := got[0].(map[string]any)
			if !isMap {
				t.Fatalf("%s: result[0] is not map[string]any", tc.description)
			}
			declared, isSlice := decl["functionDeclarations"].([]any)
			if !isSlice {
				t.Fatalf("%s: functionDeclarations is not []any", tc.description)
			}

			// Count the inputs that should survive conversion: named tools,
			// except custom tools that have no "custom" object.
			want := 0
			inputs, _ := tc.tools.([]any)
			for _, raw := range inputs {
				entry, _ := raw.(map[string]any)
				if entry["name"] == "" {
					continue
				}
				if entry["type"] == "custom" && entry["custom"] == nil {
					continue
				}
				want++
			}

			if len(declared) != want {
				t.Errorf("%s: expected %d function declarations, got %d",
					tc.description, want, len(declared))
			}
		})
	}
}

View File

@@ -25,6 +25,16 @@ func (m *mockAccountRepoForGemini) GetByID(ctx context.Context, id int64) (*Acco
return nil, errors.New("account not found")
}
// GetByIDs returns the mock accounts registered under the given IDs, in the
// order the IDs are supplied. Unknown IDs are skipped; the error is always nil
// and the slice is nil when nothing matched.
func (m *mockAccountRepoForGemini) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
	var found []*Account
	for i := range ids {
		account, known := m.accountsByID[ids[i]]
		if !known {
			continue
		}
		found = append(found, account)
	}
	return found, nil
}
func (m *mockAccountRepoForGemini) ExistsByID(ctx context.Context, id int64) (bool, error) {
if m.accountsByID == nil {
return false, nil

View File

@@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
@@ -16,6 +17,26 @@ import (
"github.com/Wei-Shaw/sub2api/internal/pkg/httpclient"
)
// Tier identifiers for Google One accounts. They are the values returned by
// inferGoogleOneTier, which derives the tier from the Drive storage limit.
const (
TierAIPremium = "AI_PREMIUM"
TierGoogleOneStandard = "GOOGLE_ONE_STANDARD"
TierGoogleOneBasic = "GOOGLE_ONE_BASIC"
TierFree = "FREE"
// TierGoogleOneUnknown is used when the tier cannot be determined.
TierGoogleOneUnknown = "GOOGLE_ONE_UNKNOWN"
TierGoogleOneUnlimited = "GOOGLE_ONE_UNLIMITED"
)
// Byte-size units and the Drive storage limits (in bytes) that
// inferGoogleOneTier uses as tier boundaries.
const (
GB = 1024 * 1024 * 1024
TB = 1024 * GB
// A limit strictly above StorageTierUnlimited is classified as unlimited;
// every other threshold is an inclusive lower bound for its tier.
StorageTierUnlimited = 100 * TB // 100TB
StorageTierAIPremium = 2 * TB // 2TB
StorageTierStandard = 200 * GB // 200GB
StorageTierBasic = 100 * GB // 100GB
StorageTierFree = 15 * GB // 15GB
)
type GeminiOAuthService struct {
sessionStore *geminicli.SessionStore
proxyRepo ProxyRepository
@@ -88,13 +109,14 @@ func (s *GeminiOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64
// OAuth client selection:
// - code_assist: always use built-in Gemini CLI OAuth client (public), regardless of configured client_id/secret.
// - google_one: same as code_assist, uses built-in client for personal Google accounts.
// - ai_studio: requires a user-provided OAuth client.
oauthCfg := geminicli.OAuthConfig{
ClientID: s.cfg.Gemini.OAuth.ClientID,
ClientSecret: s.cfg.Gemini.OAuth.ClientSecret,
Scopes: s.cfg.Gemini.OAuth.Scopes,
}
if oauthType == "code_assist" {
if oauthType == "code_assist" || oauthType == "google_one" {
oauthCfg.ClientID = ""
oauthCfg.ClientSecret = ""
}
@@ -155,14 +177,152 @@ type GeminiExchangeCodeInput struct {
}
type GeminiTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Scope string `json:"scope,omitempty"`
ProjectID string `json:"project_id,omitempty"`
OAuthType string `json:"oauth_type,omitempty"` // "code_assist" 或 "ai_studio"
TierID string `json:"tier_id,omitempty"` // Gemini Code Assist tier: LEGACY/PRO/ULTRA
Extra map[string]any `json:"extra,omitempty"` // Drive metadata
}
// tierIDPattern matches a non-empty tier ID made only of ASCII letters,
// digits, underscore, hyphen and slash (slash allows tier paths). It is
// compiled once at package scope instead of on every validateTierID call.
var tierIDPattern = regexp.MustCompile(`^[a-zA-Z0-9_/-]+$`)

// validateTierID checks the format and length of a tier ID.
//
// An empty tierID is valid. Otherwise it must be at most 64 bytes long and
// match tierIDPattern; a non-nil error describes which rule was violated.
func validateTierID(tierID string) error {
	if tierID == "" {
		return nil // Empty is allowed
	}
	if len(tierID) > 64 {
		return fmt.Errorf("tier_id exceeds maximum length of 64 characters")
	}
	if !tierIDPattern.MatchString(tierID) {
		return fmt.Errorf("tier_id contains invalid characters")
	}
	return nil
}
// extractTierIDFromAllowedTiers picks a tier ID from the allowed tiers of a
// LoadCodeAssist response.
//
// The first tier flagged IsDefault with a non-blank ID wins. If there is no
// such tier, or if its ID is literally "LEGACY", the first tier with a
// non-blank ID is used instead. With no usable tier the result is "LEGACY".
// Returned IDs are whitespace-trimmed.
func extractTierIDFromAllowedTiers(allowedTiers []geminicli.AllowedTier) string {
	const fallback = "LEGACY"

	chosen := fallback
	for i := range allowedTiers {
		id := strings.TrimSpace(allowedTiers[i].ID)
		if allowedTiers[i].IsDefault && id != "" {
			chosen = id
			break
		}
	}
	if chosen != fallback {
		return chosen
	}

	// No default found (or the default is LEGACY): take the first named tier.
	for i := range allowedTiers {
		if id := strings.TrimSpace(allowedTiers[i].ID); id != "" {
			return id
		}
	}
	return fallback
}
// inferGoogleOneTier maps a Drive storage limit in bytes to a Google One tier
// constant.
//
// Non-positive limits and limits below StorageTierFree yield
// TierGoogleOneUnknown. A limit strictly greater than StorageTierUnlimited
// yields TierGoogleOneUnlimited; the remaining thresholds are inclusive lower
// bounds, checked from largest to smallest.
func inferGoogleOneTier(storageBytes int64) string {
	switch {
	case storageBytes <= 0:
		return TierGoogleOneUnknown
	case storageBytes > StorageTierUnlimited:
		return TierGoogleOneUnlimited
	case storageBytes >= StorageTierAIPremium:
		return TierAIPremium
	case storageBytes >= StorageTierStandard:
		return TierGoogleOneStandard
	case storageBytes >= StorageTierBasic:
		return TierGoogleOneBasic
	case storageBytes >= StorageTierFree:
		return TierFree
	default:
		return TierGoogleOneUnknown
	}
}
// FetchGoogleOneTier queries the Drive storage quota with the given access
// token (optionally through proxyURL) and infers the Google One tier from the
// storage limit.
//
// On success it returns the inferred tier and the storage info. On failure it
// prints a warning, distinguishing a "status 403" error (Drive scope not
// granted) from other errors, and returns TierGoogleOneUnknown, nil and the
// original error.
func (s *GeminiOAuthService) FetchGoogleOneTier(ctx context.Context, accessToken, proxyURL string) (string, *geminicli.DriveStorageInfo, error) {
	quota, err := geminicli.NewDriveClient().GetStorageQuota(ctx, accessToken, proxyURL)
	if err == nil {
		return inferGoogleOneTier(quota.Limit), quota, nil
	}

	// Both failure kinds are reported to the caller identically; only the
	// log line differs.
	if scopeDenied := strings.Contains(err.Error(), "status 403"); scopeDenied {
		fmt.Printf("[GeminiOAuth] Drive API scope not available: %v\n", err)
	} else {
		fmt.Printf("[GeminiOAuth] Failed to fetch Drive storage: %v\n", err)
	}
	return TierGoogleOneUnknown, nil, err
}
// RefreshAccountGoogleOneTier re-detects the Google One tier of a single
// account via the Drive API.
//
// The account must carry oauth_type "google_one" and a non-empty access_token
// in its credentials. On success it returns the detected tier, a copy of the
// account's Extra map extended with drive_storage_limit, drive_storage_usage
// and drive_tier_updated_at, and a copy of the credentials with tier_id set.
// The account itself is not modified. On any failure the tier is "" and both
// maps are nil.
func (s *GeminiOAuthService) RefreshAccountGoogleOneTier(
	ctx context.Context,
	account *Account,
) (tierID string, extra map[string]any, credentials map[string]any, err error) {
	if account == nil {
		return "", nil, nil, fmt.Errorf("account is nil")
	}

	// Only google_one OAuth accounts have a Drive-derived tier.
	if kind, isString := account.Credentials["oauth_type"].(string); !isString || kind != "google_one" {
		return "", nil, nil, fmt.Errorf("not a google_one OAuth account")
	}

	token, _ := account.Credentials["access_token"].(string)
	if token == "" {
		return "", nil, nil, fmt.Errorf("missing access_token")
	}

	// Route through the account's proxy when one is attached.
	proxyURL := ""
	if account.ProxyID != nil && account.Proxy != nil {
		proxyURL = account.Proxy.URL()
	}

	detectedTier, quota, fetchErr := s.FetchGoogleOneTier(ctx, token, proxyURL)
	if fetchErr != nil {
		return "", nil, nil, fetchErr
	}

	// Copy the existing extra fields, then record the fresh Drive data.
	mergedExtra := make(map[string]any)
	for key, value := range account.Extra {
		mergedExtra[key] = value
	}
	if quota != nil {
		mergedExtra["drive_storage_limit"] = quota.Limit
		mergedExtra["drive_storage_usage"] = quota.Usage
		mergedExtra["drive_tier_updated_at"] = time.Now().Format(time.RFC3339)
	}

	// Copy the credentials and set the refreshed tier.
	mergedCreds := make(map[string]any)
	for key, value := range account.Credentials {
		mergedCreds[key] = value
	}
	mergedCreds["tier_id"] = detectedTier

	return detectedTier, mergedExtra, mergedCreds, nil
}
func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExchangeCodeInput) (*GeminiTokenInfo, error) {
@@ -219,26 +379,78 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
sessionProjectID := strings.TrimSpace(session.ProjectID)
s.sessionStore.Delete(input.SessionID)
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
projectID := sessionProjectID
var tierID string
// 对于 code_assist 模式project_id 是必需的
// 对于 code_assist 模式project_id 是必需的,需要调用 Code Assist API
// 对于 google_one 模式,使用个人 Google 账号,不需要 project_id配额由 Google 网关自动识别
// 对于 ai_studio 模式project_id 是可选的(不影响使用 AI Studio API
if oauthType == "code_assist" {
switch oauthType {
case "code_assist":
if projectID == "" {
var err error
projectID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
projectID, tierID, err = s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// 记录警告但不阻断流程,允许后续补充 project_id
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch project_id during token exchange: %v\n", err)
}
} else {
// 用户手动填了 project_id仍需调用 LoadCodeAssist 获取 tierID
_, fetchedTierID, err := s.fetchProjectID(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch tierID: %v\n", err)
} else {
tierID = fetchedTierID
}
}
if strings.TrimSpace(projectID) == "" {
return nil, fmt.Errorf("missing project_id for Code Assist OAuth: please fill Project ID (optional field) and regenerate the auth URL, or ensure your Google account has an ACTIVE GCP project")
}
// tierID 缺失时使用默认值
if tierID == "" {
tierID = "LEGACY"
}
case "google_one":
// Attempt to fetch Drive storage tier
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenResp.AccessToken, proxyURL)
if err != nil {
// Log warning but don't block - use fallback
fmt.Printf("[GeminiOAuth] Warning: Failed to fetch Drive tier: %v\n", err)
tierID = TierGoogleOneUnknown
}
// Store Drive info in extra field for caching
if storageInfo != nil {
tokenInfo := &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
TokenType: tokenResp.TokenType,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
Extra: map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
},
}
return tokenInfo, nil
}
}
// ai_studio 模式不设置 tierID保持为空
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
@@ -248,6 +460,7 @@ func (s *GeminiOAuthService) ExchangeCode(ctx context.Context, input *GeminiExch
ExpiresAt: expiresAt,
Scope: tokenResp.Scope,
ProjectID: projectID,
TierID: tierID,
OAuthType: oauthType,
}, nil
}
@@ -266,8 +479,15 @@ func (s *GeminiOAuthService) RefreshToken(ctx context.Context, oauthType, refres
tokenResp, err := s.oauthClient.RefreshToken(ctx, oauthType, refreshToken, proxyURL)
if err == nil {
// 计算过期时间减去 5 分钟安全时间窗口,考虑网络延迟和时钟偏差
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
// 计算过期时间减去 5 分钟安全时间窗口考虑网络延迟和时钟偏差
// 同时设置下界保护,防止 expires_in 过小导致过去时间(引发刷新风暴)
const safetyWindow = 300 // 5 minutes
const minTTL = 30 // minimum 30 seconds
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - safetyWindow
minExpiresAt := time.Now().Unix() + minTTL
if expiresAt < minExpiresAt {
expiresAt = minExpiresAt
}
return &GeminiTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
@@ -354,18 +574,75 @@ func (s *GeminiOAuthService) RefreshAccountToken(ctx context.Context, account *A
tokenInfo.ProjectID = existingProjectID
}
// 尝试从账号凭证获取 tierID向后兼容
existingTierID := strings.TrimSpace(account.GetCredential("tier_id"))
// For Code Assist, project_id is required. Auto-detect if missing.
// For AI Studio OAuth, project_id is optional and should not block refresh.
if oauthType == "code_assist" && strings.TrimSpace(tokenInfo.ProjectID) == "" {
projectID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
return nil, fmt.Errorf("failed to auto-detect project_id: %w", err)
switch oauthType {
case "code_assist":
// 先设置默认值或保留旧值,确保 tier_id 始终有值
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = "LEGACY" // 默认值
}
projectID = strings.TrimSpace(projectID)
if projectID == "" {
// 尝试自动探测 project_id 和 tier_id
needDetect := strings.TrimSpace(tokenInfo.ProjectID) == "" || existingTierID == ""
if needDetect {
projectID, tierID, err := s.fetchProjectID(ctx, tokenInfo.AccessToken, proxyURL)
if err != nil {
fmt.Printf("[GeminiOAuth] Warning: failed to auto-detect project/tier: %v\n", err)
} else {
if strings.TrimSpace(tokenInfo.ProjectID) == "" && projectID != "" {
tokenInfo.ProjectID = projectID
}
// 只有当原来没有 tier_id 且探测成功时才更新
if existingTierID == "" && tierID != "" {
tokenInfo.TierID = tierID
}
}
}
if strings.TrimSpace(tokenInfo.ProjectID) == "" {
return nil, fmt.Errorf("failed to auto-detect project_id: empty result")
}
tokenInfo.ProjectID = projectID
case "google_one":
// Check if tier cache is stale (> 24 hours)
needsRefresh := true
if account.Extra != nil {
if updatedAtStr, ok := account.Extra["drive_tier_updated_at"].(string); ok {
if updatedAt, err := time.Parse(time.RFC3339, updatedAtStr); err == nil {
if time.Since(updatedAt) <= 24*time.Hour {
needsRefresh = false
// Use cached tier
if existingTierID != "" {
tokenInfo.TierID = existingTierID
}
}
}
}
}
if needsRefresh {
tierID, storageInfo, err := s.FetchGoogleOneTier(ctx, tokenInfo.AccessToken, proxyURL)
if err == nil && storageInfo != nil {
tokenInfo.TierID = tierID
tokenInfo.Extra = map[string]any{
"drive_storage_limit": storageInfo.Limit,
"drive_storage_usage": storageInfo.Usage,
"drive_tier_updated_at": time.Now().Format(time.RFC3339),
}
} else {
// Fallback to cached or unknown
if existingTierID != "" {
tokenInfo.TierID = existingTierID
} else {
tokenInfo.TierID = TierGoogleOneUnknown
}
}
}
}
return tokenInfo, nil
@@ -388,9 +665,22 @@ func (s *GeminiOAuthService) BuildAccountCredentials(tokenInfo *GeminiTokenInfo)
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
if tokenInfo.TierID != "" {
// Validate tier_id before storing
if err := validateTierID(tokenInfo.TierID); err == nil {
creds["tier_id"] = tokenInfo.TierID
}
// Silently skip invalid tier_id (don't block account creation)
}
if tokenInfo.OAuthType != "" {
creds["oauth_type"] = tokenInfo.OAuthType
}
// Store extra metadata (Drive info) if present
if len(tokenInfo.Extra) > 0 {
for k, v := range tokenInfo.Extra {
creds[k] = v
}
}
return creds
}
@@ -398,33 +688,22 @@ func (s *GeminiOAuthService) Stop() {
s.sessionStore.Stop()
}
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, error) {
func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, proxyURL string) (string, string, error) {
if s.codeAssist == nil {
return "", errors.New("code assist client not configured")
return "", "", errors.New("code assist client not configured")
}
loadResp, loadErr := s.codeAssist.LoadCodeAssist(ctx, accessToken, proxyURL, nil)
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), nil
}
// Pick tier from allowedTiers; if no default tier is marked, pick the first non-empty tier ID.
// Extract tierID from response (works whether CloudAICompanionProject is set or not)
tierID := "LEGACY"
if loadResp != nil {
for _, tier := range loadResp.AllowedTiers {
if tier.IsDefault && strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
if strings.TrimSpace(tierID) == "" || tierID == "LEGACY" {
for _, tier := range loadResp.AllowedTiers {
if strings.TrimSpace(tier.ID) != "" {
tierID = strings.TrimSpace(tier.ID)
break
}
}
}
tierID = extractTierIDFromAllowedTiers(loadResp.AllowedTiers)
}
// If LoadCodeAssist returned a project, use it
if loadErr == nil && loadResp != nil && strings.TrimSpace(loadResp.CloudAICompanionProject) != "" {
return strings.TrimSpace(loadResp.CloudAICompanionProject), tierID, nil
}
req := &geminicli.OnboardUserRequest{
@@ -443,39 +722,39 @@ func (s *GeminiOAuthService) fetchProjectID(ctx context.Context, accessToken, pr
// If Code Assist onboarding fails (e.g. INVALID_ARGUMENT), fallback to Cloud Resource Manager projects.
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
return "", err
return "", tierID, err
}
if resp.Done {
if resp.Response != nil && resp.Response.CloudAICompanionProject != nil {
switch v := resp.Response.CloudAICompanionProject.(type) {
case string:
return strings.TrimSpace(v), nil
return strings.TrimSpace(v), tierID, nil
case map[string]any:
if id, ok := v["id"].(string); ok {
return strings.TrimSpace(id), nil
return strings.TrimSpace(id), tierID, nil
}
}
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
return "", errors.New("onboardUser completed but no project_id returned")
return "", tierID, errors.New("onboardUser completed but no project_id returned")
}
time.Sleep(2 * time.Second)
}
fallback, fbErr := fetchProjectIDFromResourceManager(ctx, accessToken, proxyURL)
if fbErr == nil && strings.TrimSpace(fallback) != "" {
return strings.TrimSpace(fallback), nil
return strings.TrimSpace(fallback), tierID, nil
}
if loadErr != nil {
return "", fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
return "", tierID, fmt.Errorf("loadCodeAssist failed (%v) and onboardUser timeout after %d attempts", loadErr, maxAttempts)
}
return "", fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
return "", tierID, fmt.Errorf("onboardUser timeout after %d attempts", maxAttempts)
}
type googleCloudProject struct {

View File

@@ -0,0 +1,51 @@
package service
import "testing"
// TestInferGoogleOneTier checks inferGoogleOneTier at and around every storage
// threshold, including the non-positive inputs and the strictly-greater rule
// for the unlimited tier.
func TestInferGoogleOneTier(t *testing.T) {
	cases := []struct {
		name         string
		storageBytes int64
		expectedTier string
	}{
		{"Negative storage", -1, TierGoogleOneUnknown},
		{"Zero storage", 0, TierGoogleOneUnknown},

		// Free tier boundary (15GB)
		{"Below free tier", 10 * GB, TierGoogleOneUnknown},
		{"Just below free tier", StorageTierFree - 1, TierGoogleOneUnknown},
		{"Free tier (15GB)", StorageTierFree, TierFree},

		// Basic tier boundary (100GB)
		{"Between free and basic", 50 * GB, TierFree},
		{"Just below basic tier", StorageTierBasic - 1, TierFree},
		{"Basic tier (100GB)", StorageTierBasic, TierGoogleOneBasic},

		// Standard tier boundary (200GB)
		{"Between basic and standard", 150 * GB, TierGoogleOneBasic},
		{"Just below standard tier", StorageTierStandard - 1, TierGoogleOneBasic},
		{"Standard tier (200GB)", StorageTierStandard, TierGoogleOneStandard},

		// AI Premium tier boundary (2TB)
		{"Between standard and premium", 1 * TB, TierGoogleOneStandard},
		{"Just below AI Premium tier", StorageTierAIPremium - 1, TierGoogleOneStandard},
		{"AI Premium tier (2TB)", StorageTierAIPremium, TierAIPremium},

		// Unlimited tier boundary (> 100TB)
		{"Between premium and unlimited", 50 * TB, TierAIPremium},
		{"At unlimited threshold (100TB)", StorageTierUnlimited, TierAIPremium},
		{"Unlimited tier (100TB+)", StorageTierUnlimited + 1, TierGoogleOneUnlimited},
		{"Unlimited tier (101TB+)", 101 * TB, TierGoogleOneUnlimited},
		{"Very large storage", 1000 * TB, TierGoogleOneUnlimited},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := inferGoogleOneTier(tc.storageBytes); got != tc.expectedTier {
				t.Errorf("inferGoogleOneTier(%d) = %s, want %s",
					tc.storageBytes, got, tc.expectedTier)
			}
		})
	}
}

View File

@@ -0,0 +1,268 @@
package service
import (
"context"
"encoding/json"
"errors"
"log"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
// geminiModelClass groups Gemini models into the classes that usage is
// aggregated by (see geminiModelClassFromName and geminiAggregateUsage).
type geminiModelClass string
const (
// geminiModelPro is the class for every model not classified as flash.
geminiModelPro geminiModelClass = "pro"
// geminiModelFlash is the class for models whose name contains "flash" or "lite".
geminiModelFlash geminiModelClass = "flash"
)
// GeminiDailyQuota holds the per-day request limits of a tier, one per model
// class. NOTE(review): "RPD" presumably stands for requests per day — confirm.
type GeminiDailyQuota struct {
// ProRPD is the daily limit for pro-class models.
ProRPD int64
// FlashRPD is the daily limit for flash-class models.
FlashRPD int64
}
// GeminiTierPolicy bundles the daily quota and the cooldown duration
// configured for one tier.
type GeminiTierPolicy struct {
Quota GeminiDailyQuota
// Cooldown is the tier's cooldown duration; CooldownForTier substitutes
// 5 minutes when it is not positive.
Cooldown time.Duration
}
// GeminiQuotaPolicy maps normalized (trimmed, upper-case) tier IDs to their
// tier policy. Build it with newGeminiQuotaPolicy and adjust it with
// ApplyOverrides.
type GeminiQuotaPolicy struct {
// tiers is keyed by normalized tier ID, e.g. "LEGACY".
tiers map[string]GeminiTierPolicy
}
// GeminiUsageTotals is the sum of per-model usage statistics split by model
// class, as produced by geminiAggregateUsage.
type GeminiUsageTotals struct {
ProRequests int64
FlashRequests int64
ProTokens int64
FlashTokens int64
ProCost float64
FlashCost float64
}
// geminiQuotaCacheTTL is how long GeminiQuotaService.Policy reuses a computed
// policy before rebuilding it from config and settings.
const geminiQuotaCacheTTL = time.Minute
// geminiQuotaOverrides is the JSON shape of a quota policy override document:
// an object with a "tiers" map from tier ID to per-tier overrides.
type geminiQuotaOverrides struct {
Tiers map[string]config.GeminiTierQuotaConfig `json:"tiers"`
}
// GeminiQuotaService resolves the effective Gemini quota policy from built-in
// defaults, static config and the settings repository, and caches the result
// for geminiQuotaCacheTTL.
type GeminiQuotaService struct {
cfg *config.Config
settingRepo SettingRepository
// mu guards cachedAt and policy.
mu sync.Mutex
// cachedAt is the time at which the cached policy was computed.
cachedAt time.Time
// policy is the most recently computed policy; nil until first use.
policy *GeminiQuotaPolicy
}
// NewGeminiQuotaService creates a quota service backed by the given config and
// settings repository. Either argument may be nil; Policy skips nil sources.
func NewGeminiQuotaService(cfg *config.Config, settingRepo SettingRepository) *GeminiQuotaService {
	svc := new(GeminiQuotaService)
	svc.cfg = cfg
	svc.settingRepo = settingRepo
	return svc
}
// Policy returns the effective quota policy.
//
// A policy computed less than geminiQuotaCacheTTL ago is returned from cache.
// Otherwise a new one is built from the built-in defaults, then overridden in
// order by: the config tier map, the config JSON policy string, and the JSON
// stored under SettingKeyGeminiQuotaPolicy. Parse and load failures are logged
// and that source is skipped. A nil receiver yields the built-in defaults.
//
// The lock is not held while the policy is rebuilt, so concurrent callers may
// each rebuild it; the last one to finish populates the cache.
func (s *GeminiQuotaService) Policy(ctx context.Context) *GeminiQuotaPolicy {
	if s == nil {
		return newGeminiQuotaPolicy()
	}

	startedAt := time.Now()

	s.mu.Lock()
	cached, cachedAt := s.policy, s.cachedAt
	s.mu.Unlock()
	if cached != nil && startedAt.Sub(cachedAt) < geminiQuotaCacheTTL {
		return cached
	}

	fresh := newGeminiQuotaPolicy()

	// Static configuration: structured tiers first, then the JSON policy.
	if cfg := s.cfg; cfg != nil {
		fresh.ApplyOverrides(cfg.Gemini.Quota.Tiers)
		if raw := cfg.Gemini.Quota.Policy; strings.TrimSpace(raw) != "" {
			var fromConfig geminiQuotaOverrides
			if err := json.Unmarshal([]byte(raw), &fromConfig); err == nil {
				fresh.ApplyOverrides(fromConfig.Tiers)
			} else {
				log.Printf("gemini quota: parse config policy failed: %v", err)
			}
		}
	}

	// Stored setting: applied last, so it wins over static configuration.
	if s.settingRepo != nil {
		raw, err := s.settingRepo.GetValue(ctx, SettingKeyGeminiQuotaPolicy)
		switch {
		case err != nil && !errors.Is(err, ErrSettingNotFound):
			log.Printf("gemini quota: load setting failed: %v", err)
		case strings.TrimSpace(raw) != "":
			var fromSetting geminiQuotaOverrides
			if err := json.Unmarshal([]byte(raw), &fromSetting); err == nil {
				fresh.ApplyOverrides(fromSetting.Tiers)
			} else {
				log.Printf("gemini quota: parse setting failed: %v", err)
			}
		}
	}

	s.mu.Lock()
	s.policy, s.cachedAt = fresh, startedAt
	s.mu.Unlock()

	return fresh
}
// QuotaForAccount returns the daily quota for the account's tier. The boolean
// is false, with a zero quota, when the account is nil, is not a Gemini Code
// Assist account, or no policy entry can be resolved for its tier.
func (s *GeminiQuotaService) QuotaForAccount(ctx context.Context, account *Account) (GeminiDailyQuota, bool) {
	if account != nil && account.IsGeminiCodeAssist() {
		return s.Policy(ctx).QuotaForTier(account.GeminiTierID())
	}
	return GeminiDailyQuota{}, false
}
// CooldownForTier returns the cooldown duration that the current effective
// policy assigns to tierID.
func (s *GeminiQuotaService) CooldownForTier(ctx context.Context, tierID string) time.Duration {
	return s.Policy(ctx).CooldownForTier(tierID)
}
// newGeminiQuotaPolicy builds a policy holding the built-in defaults for the
// LEGACY, PRO and ULTRA tiers. Every call returns an independent map, so the
// result may be modified freely.
func newGeminiQuotaPolicy() *GeminiQuotaPolicy {
	defaults := make(map[string]GeminiTierPolicy, 3)
	defaults["LEGACY"] = GeminiTierPolicy{
		Quota:    GeminiDailyQuota{ProRPD: 50, FlashRPD: 1500},
		Cooldown: 30 * time.Minute,
	}
	defaults["PRO"] = GeminiTierPolicy{
		Quota:    GeminiDailyQuota{ProRPD: 1500, FlashRPD: 4000},
		Cooldown: 5 * time.Minute,
	}
	defaults["ULTRA"] = GeminiTierPolicy{
		Quota:    GeminiDailyQuota{ProRPD: 2000, FlashRPD: 0},
		Cooldown: 5 * time.Minute,
	}
	return &GeminiQuotaPolicy{tiers: defaults}
}
// ApplyOverrides merges per-tier overrides into the policy.
//
// Tier IDs are normalized (trimmed, upper-cased); blank IDs are ignored. Only
// the fields set in an override (non-nil pointers) replace the current values,
// and negative numbers are clamped to zero. An override for a tier that does
// not exist yet creates it with a 5-minute cooldown and zero quotas before the
// override is applied. A nil receiver or an empty override map is a no-op.
func (p *GeminiQuotaPolicy) ApplyOverrides(tiers map[string]config.GeminiTierQuotaConfig) {
	if p == nil || len(tiers) == 0 {
		return
	}
	// A zero-value GeminiQuotaPolicy has a nil map; writing to it would panic.
	if p.tiers == nil {
		p.tiers = make(map[string]GeminiTierPolicy, len(tiers))
	}
	for rawID, override := range tiers {
		tierID := normalizeGeminiTierID(rawID)
		if tierID == "" {
			continue
		}
		policy, ok := p.tiers[tierID]
		if !ok {
			policy = GeminiTierPolicy{Cooldown: 5 * time.Minute}
		}
		if override.ProRPD != nil {
			policy.Quota.ProRPD = clampGeminiQuotaInt64(*override.ProRPD)
		}
		if override.FlashRPD != nil {
			policy.Quota.FlashRPD = clampGeminiQuotaInt64(*override.FlashRPD)
		}
		if override.CooldownMinutes != nil {
			minutes := clampGeminiQuotaInt(*override.CooldownMinutes)
			policy.Cooldown = time.Duration(minutes) * time.Minute
		}
		p.tiers[tierID] = policy
	}
}
// QuotaForTier returns the daily quota for tierID, using the same tier
// resolution as policyForTier. The boolean is false, with a zero quota, when
// no tier policy could be resolved.
func (p *GeminiQuotaPolicy) QuotaForTier(tierID string) (GeminiDailyQuota, bool) {
	if tier, found := p.policyForTier(tierID); found {
		return tier.Quota, true
	}
	return GeminiDailyQuota{}, false
}
// CooldownForTier returns the cooldown configured for tierID. It falls back to
// 5 minutes when no tier policy can be resolved or the configured cooldown is
// not positive.
func (p *GeminiQuotaPolicy) CooldownForTier(tierID string) time.Duration {
	const defaultCooldown = 5 * time.Minute

	tier, found := p.policyForTier(tierID)
	if !found || tier.Cooldown <= 0 {
		return defaultCooldown
	}
	return tier.Cooldown
}
// policyForTier resolves the policy entry for tierID. The ID is normalized
// first; a blank ID and any ID without its own entry both resolve to the
// "LEGACY" entry. The boolean is false only for a nil receiver or when no
// "LEGACY" entry exists either.
func (p *GeminiQuotaPolicy) policyForTier(tierID string) (GeminiTierPolicy, bool) {
	if p == nil {
		return GeminiTierPolicy{}, false
	}

	key := normalizeGeminiTierID(tierID)
	if key == "" {
		key = "LEGACY"
	}

	entry, found := p.tiers[key]
	if !found {
		// Unknown tier: fall back to the LEGACY entry.
		entry, found = p.tiers["LEGACY"]
	}
	return entry, found
}
// normalizeGeminiTierID returns tierID with surrounding whitespace removed and
// all letters upper-cased, the form used for policy map keys.
func normalizeGeminiTierID(tierID string) string {
	trimmed := strings.TrimSpace(tierID)
	return strings.ToUpper(trimmed)
}
// clampGeminiQuotaInt64 returns value unchanged when it is non-negative and 0
// otherwise.
func clampGeminiQuotaInt64(value int64) int64 {
	if value >= 0 {
		return value
	}
	return 0
}
// clampGeminiQuotaInt returns value unchanged when it is non-negative and 0
// otherwise.
func clampGeminiQuotaInt(value int) int {
	if value >= 0 {
		return value
	}
	return 0
}
// geminiCooldownForTier returns the cooldown for tierID according to the
// built-in default policy only; config and settings overrides are not
// consulted.
func geminiCooldownForTier(tierID string) time.Duration {
	return newGeminiQuotaPolicy().CooldownForTier(tierID)
}
// geminiModelClassFromName classifies a model by name: any name containing
// "flash" or "lite" (case-insensitive, surrounding whitespace ignored) is
// flash-class, and everything else, including an empty name, is pro-class.
func geminiModelClassFromName(model string) geminiModelClass {
	lowered := strings.ToLower(strings.TrimSpace(model))
	for _, marker := range []string{"flash", "lite"} {
		if strings.Contains(lowered, marker) {
			return geminiModelFlash
		}
	}
	return geminiModelPro
}
// geminiAggregateUsage sums requests, total tokens and actual cost over the
// per-model stats, split into flash-class and pro-class totals according to
// geminiModelClassFromName.
func geminiAggregateUsage(stats []usagestats.ModelStat) GeminiUsageTotals {
	var sum GeminiUsageTotals
	for i := range stats {
		entry := &stats[i]
		if geminiModelClassFromName(entry.Model) == geminiModelFlash {
			sum.FlashRequests += entry.Requests
			sum.FlashTokens += entry.TotalTokens
			sum.FlashCost += entry.ActualCost
			continue
		}
		sum.ProRequests += entry.Requests
		sum.ProTokens += entry.TotalTokens
		sum.ProCost += entry.ActualCost
	}
	return sum
}
// geminiQuotaLocation returns the time zone in which the daily quota window is
// evaluated: America/Los_Angeles, or a fixed UTC-8 zone named "PST" when the
// zone database cannot be loaded.
func geminiQuotaLocation() *time.Location {
	loc, err := time.LoadLocation("America/Los_Angeles")
	if err != nil {
		return time.FixedZone("PST", -8*3600)
	}
	return loc
}

// geminiDailyWindowStart returns local midnight, in the quota time zone, of
// the day that contains now.
func geminiDailyWindowStart(now time.Time) time.Time {
	loc := geminiQuotaLocation()
	localNow := now.In(loc)
	return time.Date(localNow.Year(), localNow.Month(), localNow.Day(), 0, 0, 0, 0, loc)
}

// geminiDailyResetTime returns the next local midnight, in the quota time
// zone, strictly after now.
//
// The result is built from calendar fields (day+1, normalized by time.Date)
// rather than by adding 24 hours to the window start: on daylight-saving
// transition days a local day lasts 23 or 25 hours, so a fixed 24-hour offset
// would land at 01:00 or 23:00 instead of midnight.
func geminiDailyResetTime(now time.Time) time.Time {
	start := geminiDailyWindowStart(now)
	return time.Date(start.Year(), start.Month(), start.Day()+1, 0, 0, 0, 0, start.Location())
}

View File

@@ -112,17 +112,21 @@ func (p *GeminiTokenProvider) GetAccessToken(ctx context.Context, account *Accou
}
}
detected, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
detected, tierID, err := p.geminiOAuthService.fetchProjectID(ctx, accessToken, proxyURL)
if err != nil {
log.Printf("[GeminiTokenProvider] Auto-detect project_id failed: %v, fallback to AI Studio API mode", err)
return accessToken, nil
}
detected = strings.TrimSpace(detected)
tierID = strings.TrimSpace(tierID)
if detected != "" {
if account.Credentials == nil {
account.Credentials = make(map[string]any)
}
account.Credentials["project_id"] = detected
if tierID != "" {
account.Credentials["tier_id"] = tierID
}
_ = p.accountRepo.Update(ctx, account)
}
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

View File

@@ -13,6 +13,7 @@ import (
"log"
"net/http"
"regexp"
"sort"
"strconv"
"strings"
"time"
@@ -82,6 +83,7 @@ type OpenAIGatewayService struct {
userSubRepo UserSubscriptionRepository
cache GatewayCache
cfg *config.Config
concurrencyService *ConcurrencyService
billingService *BillingService
rateLimitService *RateLimitService
billingCacheService *BillingCacheService
@@ -97,6 +99,7 @@ func NewOpenAIGatewayService(
userSubRepo UserSubscriptionRepository,
cache GatewayCache,
cfg *config.Config,
concurrencyService *ConcurrencyService,
billingService *BillingService,
rateLimitService *RateLimitService,
billingCacheService *BillingCacheService,
@@ -110,6 +113,7 @@ func NewOpenAIGatewayService(
userSubRepo: userSubRepo,
cache: cache,
cfg: cfg,
concurrencyService: concurrencyService,
billingService: billingService,
rateLimitService: rateLimitService,
billingCacheService: billingCacheService,
@@ -128,6 +132,14 @@ func (s *OpenAIGatewayService) GenerateSessionHash(c *gin.Context) string {
return hex.EncodeToString(hash[:])
}
// BindStickySession stores the session -> account binding in the cache under
// the "openai:"-prefixed session hash with the standard sticky-session TTL.
// It does nothing and returns nil when sessionHash is empty or accountID is
// not positive.
func (s *OpenAIGatewayService) BindStickySession(ctx context.Context, sessionHash string, accountID int64) error {
	if sessionHash != "" && accountID > 0 {
		cacheKey := "openai:" + sessionHash
		return s.cache.SetSessionAccountID(ctx, cacheKey, accountID, openaiStickySessionTTL)
	}
	return nil
}
// SelectAccount selects an OpenAI account with sticky session support
func (s *OpenAIGatewayService) SelectAccount(ctx context.Context, groupID *int64, sessionHash string) (*Account, error) {
return s.SelectAccountForModel(ctx, groupID, sessionHash, "")
@@ -220,6 +232,254 @@ func (s *OpenAIGatewayService) SelectAccountForModelWithExclusions(ctx context.C
return selected, nil
}
// SelectAccountWithLoadAwareness selects an OpenAI account for a request and
// reports whether a concurrency slot is already held or has to be waited for.
//
// The returned AccountSelectionResult either has Acquired set together with
// the ReleaseFunc obtained from slot acquisition, or carries a WaitPlan naming
// the account the caller should queue on.
//
// When no concurrency service is configured, or batch load lookup is disabled,
// the account is chosen by SelectAccountForModelWithExclusions and only slot
// acquisition / wait planning happens here. Otherwise selection runs in layers:
//
//  1. Sticky session: reuse the account bound to sessionHash if it is still
//     schedulable, not excluded and supports requestedModel.
//  2. Load-aware: try accounts with LoadRate < 100 ordered by priority, then
//     load rate, then least recently used. If the batch load lookup fails,
//     accounts are tried in priority / last-used order instead.
//  3. Fallback: return a wait plan for the first candidate in priority /
//     last-used order.
//
// An error is returned when account lookup fails or no candidate remains
// after applying excludedIDs and the model filter.
func (s *OpenAIGatewayService) SelectAccountWithLoadAwareness(ctx context.Context, groupID *int64, sessionHash string, requestedModel string, excludedIDs map[int64]struct{}) (*AccountSelectionResult, error) {
	cfg := s.schedulingConfig()
	stickyKey := "openai:" + sessionHash

	// Resolve the sticky binding once. It stays 0 when there is no session
	// hash, no cache, or the lookup failed, so every later use of the binding
	// is implicitly guarded against a nil cache.
	var stickyAccountID int64
	if sessionHash != "" && s.cache != nil {
		if accountID, err := s.cache.GetSessionAccountID(ctx, stickyKey); err == nil {
			stickyAccountID = accountID
		}
	}

	// bindSession remembers the chosen account for the session. It is best
	// effort: a failed cache write must not fail the request.
	bindSession := func(accountID int64) {
		if sessionHash == "" || s.cache == nil {
			return
		}
		_ = s.cache.SetSessionAccountID(ctx, stickyKey, accountID, openaiStickySessionTTL)
	}

	if s.concurrencyService == nil || !cfg.LoadBatchEnabled {
		account, err := s.SelectAccountForModelWithExclusions(ctx, groupID, sessionHash, requestedModel, excludedIDs)
		if err != nil {
			return nil, err
		}
		result, err := s.tryAcquireAccountSlot(ctx, account.ID, account.Concurrency)
		if err == nil && result.Acquired {
			return &AccountSelectionResult{
				Account:     account,
				Acquired:    true,
				ReleaseFunc: result.ReleaseFunc,
			}, nil
		}
		if stickyAccountID > 0 && stickyAccountID == account.ID && s.concurrencyService != nil {
			// The lookup error is deliberately ignored; the returned count is used as-is.
			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, account.ID)
			if waitingCount < cfg.StickySessionMaxWaiting {
				return &AccountSelectionResult{
					Account: account,
					WaitPlan: &AccountWaitPlan{
						AccountID:      account.ID,
						MaxConcurrency: account.Concurrency,
						Timeout:        cfg.StickySessionWaitTimeout,
						MaxWaiting:     cfg.StickySessionMaxWaiting,
					},
				}, nil
			}
		}
		return &AccountSelectionResult{
			Account: account,
			WaitPlan: &AccountWaitPlan{
				AccountID:      account.ID,
				MaxConcurrency: account.Concurrency,
				Timeout:        cfg.FallbackWaitTimeout,
				MaxWaiting:     cfg.FallbackMaxWaiting,
			},
		}, nil
	}

	accounts, err := s.listSchedulableAccounts(ctx, groupID)
	if err != nil {
		return nil, err
	}
	if len(accounts) == 0 {
		return nil, errors.New("no available accounts")
	}

	isExcluded := func(accountID int64) bool {
		if excludedIDs == nil {
			return false
		}
		_, excluded := excludedIDs[accountID]
		return excluded
	}

	// ============ Layer 1: Sticky session ============
	// stickyAccountID > 0 implies a non-empty session hash and a non-nil cache.
	if stickyAccountID > 0 && !isExcluded(stickyAccountID) {
		account, err := s.accountRepo.GetByID(ctx, stickyAccountID)
		if err == nil && account != nil && account.IsSchedulable() && account.IsOpenAI() &&
			(requestedModel == "" || account.IsModelSupported(requestedModel)) {
			result, err := s.tryAcquireAccountSlot(ctx, stickyAccountID, account.Concurrency)
			if err == nil && result.Acquired {
				_ = s.cache.RefreshSessionTTL(ctx, stickyKey, openaiStickySessionTTL)
				return &AccountSelectionResult{
					Account:     account,
					Acquired:    true,
					ReleaseFunc: result.ReleaseFunc,
				}, nil
			}
			// The slot is busy: queue on the sticky account unless its wait
			// queue is already at the configured limit.
			waitingCount, _ := s.concurrencyService.GetAccountWaitingCount(ctx, stickyAccountID)
			if waitingCount < cfg.StickySessionMaxWaiting {
				return &AccountSelectionResult{
					Account: account,
					WaitPlan: &AccountWaitPlan{
						AccountID:      stickyAccountID,
						MaxConcurrency: account.Concurrency,
						Timeout:        cfg.StickySessionWaitTimeout,
						MaxWaiting:     cfg.StickySessionMaxWaiting,
					},
				}, nil
			}
		}
	}

	// ============ Layer 2: Load-aware selection ============
	candidates := make([]*Account, 0, len(accounts))
	for i := range accounts {
		acc := &accounts[i]
		if isExcluded(acc.ID) {
			continue
		}
		if requestedModel != "" && !acc.IsModelSupported(requestedModel) {
			continue
		}
		candidates = append(candidates, acc)
	}
	if len(candidates) == 0 {
		return nil, errors.New("no available accounts")
	}

	accountLoads := make([]AccountWithConcurrency, 0, len(candidates))
	for _, acc := range candidates {
		accountLoads = append(accountLoads, AccountWithConcurrency{
			ID:             acc.ID,
			MaxConcurrency: acc.Concurrency,
		})
	}

	loadMap, err := s.concurrencyService.GetAccountsLoadBatch(ctx, accountLoads)
	if err != nil {
		// Load information is unavailable: fall back to plain priority /
		// last-used ordering on a copy so candidates keeps its order for now.
		ordered := append([]*Account(nil), candidates...)
		sortAccountsByPriorityAndLastUsed(ordered, false)
		for _, acc := range ordered {
			result, err := s.tryAcquireAccountSlot(ctx, acc.ID, acc.Concurrency)
			if err == nil && result.Acquired {
				bindSession(acc.ID)
				return &AccountSelectionResult{
					Account:     acc,
					Acquired:    true,
					ReleaseFunc: result.ReleaseFunc,
				}, nil
			}
		}
	} else {
		type accountWithLoad struct {
			account  *Account
			loadInfo *AccountLoadInfo
		}
		var available []accountWithLoad
		for _, acc := range candidates {
			loadInfo := loadMap[acc.ID]
			if loadInfo == nil {
				// No entry in the batch result: use a zero-valued load record.
				loadInfo = &AccountLoadInfo{AccountID: acc.ID}
			}
			if loadInfo.LoadRate < 100 {
				available = append(available, accountWithLoad{
					account:  acc,
					loadInfo: loadInfo,
				})
			}
		}
		if len(available) > 0 {
			sort.SliceStable(available, func(i, j int) bool {
				a, b := available[i], available[j]
				if a.account.Priority != b.account.Priority {
					return a.account.Priority < b.account.Priority
				}
				if a.loadInfo.LoadRate != b.loadInfo.LoadRate {
					return a.loadInfo.LoadRate < b.loadInfo.LoadRate
				}
				// Never-used accounts sort before used ones; otherwise the
				// least recently used account comes first.
				switch {
				case a.account.LastUsedAt == nil && b.account.LastUsedAt != nil:
					return true
				case a.account.LastUsedAt != nil && b.account.LastUsedAt == nil:
					return false
				case a.account.LastUsedAt == nil && b.account.LastUsedAt == nil:
					return false
				default:
					return a.account.LastUsedAt.Before(*b.account.LastUsedAt)
				}
			})
			for _, item := range available {
				result, err := s.tryAcquireAccountSlot(ctx, item.account.ID, item.account.Concurrency)
				if err == nil && result.Acquired {
					bindSession(item.account.ID)
					return &AccountSelectionResult{
						Account:     item.account,
						Acquired:    true,
						ReleaseFunc: result.ReleaseFunc,
					}, nil
				}
			}
		}
	}

	// ============ Layer 3: Fallback wait ============
	// candidates is known to be non-empty here (checked above), so the first
	// entry after sorting is the account to wait on.
	sortAccountsByPriorityAndLastUsed(candidates, false)
	best := candidates[0]
	return &AccountSelectionResult{
		Account: best,
		WaitPlan: &AccountWaitPlan{
			AccountID:      best.ID,
			MaxConcurrency: best.Concurrency,
			Timeout:        cfg.FallbackWaitTimeout,
			MaxWaiting:     cfg.FallbackMaxWaiting,
		},
	}, nil
}
// listSchedulableAccounts returns the schedulable OpenAI accounts for a request.
// The lookup is scoped to groupID only when a group is given and the service is
// not running in simple mode; in every other case all schedulable OpenAI
// accounts are listed. Repository errors are wrapped.
func (s *OpenAIGatewayService) listSchedulableAccounts(ctx context.Context, groupID *int64) ([]Account, error) {
	simpleMode := s.cfg != nil && s.cfg.RunMode == config.RunModeSimple

	var (
		list     []Account
		queryErr error
	)
	switch {
	case !simpleMode && groupID != nil:
		list, queryErr = s.accountRepo.ListSchedulableByGroupIDAndPlatform(ctx, *groupID, PlatformOpenAI)
	default:
		list, queryErr = s.accountRepo.ListSchedulableByPlatform(ctx, PlatformOpenAI)
	}
	if queryErr != nil {
		return nil, fmt.Errorf("query accounts failed: %w", queryErr)
	}
	return list, nil
}
// tryAcquireAccountSlot asks the concurrency service for a slot on the account.
// Without a concurrency service every request is treated as acquired and the
// release callback does nothing.
func (s *OpenAIGatewayService) tryAcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
	if svc := s.concurrencyService; svc != nil {
		return svc.AcquireAccountSlot(ctx, accountID, maxConcurrency)
	}
	unlimited := &AcquireResult{
		Acquired:    true,
		ReleaseFunc: func() {},
	}
	return unlimited, nil
}
// schedulingConfig returns the gateway scheduling settings from the service
// configuration, or built-in defaults when no configuration is attached.
func (s *OpenAIGatewayService) schedulingConfig() config.GatewaySchedulingConfig {
	if s.cfg == nil {
		var defaults config.GatewaySchedulingConfig
		defaults.StickySessionMaxWaiting = 3
		defaults.StickySessionWaitTimeout = 45 * time.Second
		defaults.FallbackWaitTimeout = 30 * time.Second
		defaults.FallbackMaxWaiting = 100
		defaults.LoadBatchEnabled = true
		defaults.SlotCleanupInterval = 30 * time.Second
		return defaults
	}
	return s.cfg.Gateway.Scheduling
}
// GetAccessToken gets the access token for an OpenAI account
func (s *OpenAIGatewayService) GetAccessToken(ctx context.Context, account *Account) (string, string, error) {
switch account.Type {

View File

@@ -4,7 +4,7 @@ import (
"context"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

View File

@@ -5,6 +5,8 @@ import (
"log"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
@@ -12,15 +14,30 @@ import (
// RateLimitService handles rate limiting and overload state management for accounts.
// NOTE(review): accountRepo and cfg are listed twice below — this looks like the
// pre-change and post-change lines of a diff shown together; confirm which set applies.
type RateLimitService struct {
accountRepo AccountRepository
cfg *config.Config
accountRepo AccountRepository
// usageRepo supplies per-account model usage statistics for the Gemini pre-check.
usageRepo UsageLogRepository
cfg *config.Config
// geminiQuotaService resolves Gemini quotas and cooldowns; may be nil (PreCheckUsage checks for that).
geminiQuotaService *GeminiQuotaService
// usageCacheMu guards usageCache.
usageCacheMu sync.RWMutex
// usageCache holds the most recent Gemini usage totals per account ID.
usageCache map[int64]*geminiUsageCacheEntry
}
// geminiUsageCacheEntry is one cached usage snapshot stored in
// RateLimitService.usageCache under an account ID.
type geminiUsageCacheEntry struct {
// windowStart is the start of the usage window the totals belong to; a lookup
// for a different window start is treated as a miss.
windowStart time.Time
// cachedAt is the time the snapshot was stored; it drives TTL expiry.
cachedAt time.Time
// totals is the aggregated usage for the window.
totals GeminiUsageTotals
}
// geminiPrecheckCacheTTL is how long a cached usage snapshot is served before
// the usage statistics are queried again.
const geminiPrecheckCacheTTL = time.Minute
// NewRateLimitService creates a RateLimitService from its dependencies and
// initializes an empty usage cache.
// NOTE(review): two signatures and two sets of accountRepo/cfg initializers
// appear below — this looks like the pre-change and post-change lines of a diff
// shown together; confirm which set applies.
func NewRateLimitService(accountRepo AccountRepository, cfg *config.Config) *RateLimitService {
func NewRateLimitService(accountRepo AccountRepository, usageRepo UsageLogRepository, cfg *config.Config, geminiQuotaService *GeminiQuotaService) *RateLimitService {
return &RateLimitService{
accountRepo: accountRepo,
cfg: cfg,
accountRepo: accountRepo,
usageRepo: usageRepo,
cfg: cfg,
geminiQuotaService: geminiQuotaService,
usageCache: make(map[int64]*geminiUsageCacheEntry),
}
}
@@ -62,6 +79,106 @@ func (s *RateLimitService) HandleUpstreamError(ctx context.Context, account *Acc
}
}
// PreCheckUsage compares an account's locally recorded usage for the current
// daily window with its requests-per-day quota before a request is dispatched.
//
// It returns false when the account should be skipped; in that case the
// account has also been marked rate limited until the daily reset time.
// Accounts that are nil or not Gemini Code Assist, blank model names, missing
// dependencies, unknown quotas and non-positive limits all pass the check.
// If usage statistics cannot be loaded the result is true together with the error.
func (s *RateLimitService) PreCheckUsage(ctx context.Context, account *Account, requestedModel string) (bool, error) {
	if account == nil || !account.IsGeminiCodeAssist() || strings.TrimSpace(requestedModel) == "" {
		return true, nil
	}
	if s.usageRepo == nil || s.geminiQuotaService == nil {
		return true, nil
	}

	quota, found := s.geminiQuotaService.QuotaForAccount(ctx, account)
	if !found {
		return true, nil
	}

	// byClass returns the flash figure for flash-class models and the pro
	// figure for everything else.
	byClass := func(flash, pro int64) int64 {
		if geminiModelClassFromName(requestedModel) == geminiModelFlash {
			return flash
		}
		return pro
	}

	limit := byClass(quota.FlashRPD, quota.ProRPD)
	if limit <= 0 {
		return true, nil
	}

	now := time.Now()
	windowStart := geminiDailyWindowStart(now)

	// Serve totals from the short-lived cache when possible; otherwise query
	// the usage repository and cache the aggregate.
	totals, cached := s.getGeminiUsageTotals(account.ID, windowStart, now)
	if !cached {
		stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, windowStart, now, 0, 0, account.ID)
		if err != nil {
			return true, err
		}
		totals = geminiAggregateUsage(stats)
		s.setGeminiUsageTotals(account.ID, windowStart, now, totals)
	}

	used := byClass(totals.FlashRequests, totals.ProRequests)
	if used < limit {
		return true, nil
	}

	// Quota exhausted: mark the account rate limited until the daily reset. A
	// failure to persist that is only logged; the account is skipped either way.
	resetAt := geminiDailyResetTime(now)
	if err := s.accountRepo.SetRateLimited(ctx, account.ID, resetAt); err != nil {
		log.Printf("SetRateLimited failed for account %d: %v", account.ID, err)
	}
	log.Printf("[Gemini PreCheck] Account %d reached daily quota (%d/%d), rate limited until %v", account.ID, used, limit, resetAt)
	return false, nil
}
// getGeminiUsageTotals returns the cached usage totals for accountID. The
// second result is false when nothing is cached, the cached entry belongs to a
// different window start, or the entry is at least geminiPrecheckCacheTTL old
// relative to now.
func (s *RateLimitService) getGeminiUsageTotals(accountID int64, windowStart, now time.Time) (GeminiUsageTotals, bool) {
	s.usageCacheMu.RLock()
	defer s.usageCacheMu.RUnlock()

	// Indexing a nil map yields the zero value, so an uninitialized cache is
	// simply a miss.
	entry := s.usageCache[accountID]
	fresh := entry != nil &&
		entry.windowStart.Equal(windowStart) &&
		now.Sub(entry.cachedAt) < geminiPrecheckCacheTTL
	if !fresh {
		return GeminiUsageTotals{}, false
	}
	return entry.totals, true
}
// setGeminiUsageTotals stores totals for accountID, tagged with the window
// start and the time of caching, replacing any previous entry. The cache map
// is created on first use.
func (s *RateLimitService) setGeminiUsageTotals(accountID int64, windowStart, now time.Time, totals GeminiUsageTotals) {
	snapshot := &geminiUsageCacheEntry{
		windowStart: windowStart,
		cachedAt:    now,
		totals:      totals,
	}

	s.usageCacheMu.Lock()
	defer s.usageCacheMu.Unlock()

	if s.usageCache == nil {
		s.usageCache = map[int64]*geminiUsageCacheEntry{}
	}
	s.usageCache[accountID] = snapshot
}
// GeminiCooldown returns the fallback cooldown to apply after a Gemini 429,
// looked up by the account's tier ID. A nil account gets five minutes.
// NOTE(review): geminiQuotaService is not nil-checked here, unlike in
// PreCheckUsage — confirm CooldownForTier tolerates a nil receiver.
func (s *RateLimitService) GeminiCooldown(ctx context.Context, account *Account) time.Duration {
	if account != nil {
		tierID := account.GeminiTierID()
		return s.geminiQuotaService.CooldownForTier(ctx, tierID)
	}
	return 5 * time.Minute
}
// handleAuthError 处理认证类错误(401/403),停止账号调度
func (s *RateLimitService) handleAuthError(ctx context.Context, account *Account, errorMsg string) {
if err := s.accountRepo.SetError(ctx, account.ID, errorMsg); err != nil {

View File

@@ -10,7 +10,7 @@ import (
"time"
dbent "github.com/Wei-Shaw/sub2api/ent"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)

View File

@@ -9,7 +9,7 @@ import (
"strconv"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
var (

View File

@@ -6,7 +6,7 @@ import (
"log"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -490,6 +490,7 @@ func (s *SubscriptionService) CheckAndResetWindows(ctx context.Context, sub *Use
}
// CheckUsageLimits checks the usage limits and returns an error when one is exceeded.
// It is intended as a fast pre-check in middleware; additionalCost is usually 0.
func (s *SubscriptionService) CheckUsageLimits(ctx context.Context, sub *UserSubscription, group *Group, additionalCost float64) error {
if !sub.CheckDailyLimit(group, additionalCost) {
return ErrDailyLimitExceeded

View File

@@ -5,12 +5,13 @@ import (
"fmt"
"log"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Errors returned by the Turnstile service.
var (
// ErrTurnstileVerificationFailed is a bad-request error for a failed Turnstile verification.
ErrTurnstileVerificationFailed = infraerrors.BadRequest("TURNSTILE_VERIFICATION_FAILED", "turnstile verification failed")
// ErrTurnstileNotConfigured is a service-unavailable error for a missing Turnstile configuration.
ErrTurnstileNotConfigured = infraerrors.ServiceUnavailable("TURNSTILE_NOT_CONFIGURED", "turnstile not configured")
// ErrTurnstileInvalidSecretKey is returned by ValidateSecretKey when the
// verifier reports the "invalid-input-secret" error code.
ErrTurnstileInvalidSecretKey = infraerrors.BadRequest("TURNSTILE_INVALID_SECRET_KEY", "invalid turnstile secret key")
)
// TurnstileVerifier 验证 Turnstile token 的接口
@@ -83,3 +84,22 @@ func (s *TurnstileService) VerifyToken(ctx context.Context, token string, remote
// IsEnabled reports whether Turnstile is enabled according to the setting service.
func (s *TurnstileService) IsEnabled(ctx context.Context) bool {
	enabled := s.settingService.IsTurnstileEnabled(ctx)
	return enabled
}
// ValidateSecretKey checks whether a Turnstile secret key is valid by
// verifying a dummy token with it. It returns ErrTurnstileInvalidSecretKey
// when the verifier reports "invalid-input-secret", a wrapped error when the
// verification request itself fails, and nil otherwise.
func (s *TurnstileService) ValidateSecretKey(ctx context.Context, secretKey string) error {
	// Only the error codes of the reply matter; the dummy token is not
	// expected to verify.
	resp, verifyErr := s.verifier.VerifyToken(ctx, secretKey, "test-validation", "")
	if verifyErr != nil {
		return fmt.Errorf("validate secret key: %w", verifyErr)
	}

	for i := range resp.ErrorCodes {
		if resp.ErrorCodes[i] == "invalid-input-secret" {
			return ErrTurnstileInvalidSecretKey
		}
	}

	// Any other error code (such as invalid-input-response) means the secret
	// key itself was accepted.
	return nil
}

View File

@@ -5,7 +5,7 @@ import (
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)

View File

@@ -10,7 +10,6 @@ type User struct {
ID int64
Email string
Username string
Wechat string
Notes string
PasswordHash string
Role string

View File

@@ -0,0 +1,125 @@
package service
import (
"context"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// Error definitions for user attribute operations.
var (
// ErrAttributeDefinitionNotFound is a not-found error for an unknown attribute definition.
ErrAttributeDefinitionNotFound = infraerrors.NotFound("ATTRIBUTE_DEFINITION_NOT_FOUND", "attribute definition not found")
// ErrAttributeKeyExists is a conflict error for an attribute key that is already in use.
ErrAttributeKeyExists = infraerrors.Conflict("ATTRIBUTE_KEY_EXISTS", "attribute key already exists")
// ErrInvalidAttributeType is a bad-request error for an unsupported attribute type.
ErrInvalidAttributeType = infraerrors.BadRequest("INVALID_ATTRIBUTE_TYPE", "invalid attribute type")
// ErrAttributeValidationFailed is a bad-request error for a value that fails validation.
ErrAttributeValidationFailed = infraerrors.BadRequest("ATTRIBUTE_VALIDATION_FAILED", "attribute value validation failed")
)
// UserAttributeType identifies the kind of a custom user attribute. Only the
// constants declared below are supported types.
type UserAttributeType string
// Supported attribute types.
const (
AttributeTypeText UserAttributeType = "text"
AttributeTypeTextarea UserAttributeType = "textarea"
AttributeTypeNumber UserAttributeType = "number"
AttributeTypeEmail UserAttributeType = "email"
AttributeTypeURL UserAttributeType = "url"
AttributeTypeDate UserAttributeType = "date"
// AttributeTypeSelect values must match one of the definition's options.
AttributeTypeSelect UserAttributeType = "select"
// AttributeTypeMultiSelect values are a list of option values.
AttributeTypeMultiSelect UserAttributeType = "multi_select"
)
// UserAttributeOption is one selectable choice of a select or multi_select attribute.
type UserAttributeOption struct {
// Value is the option's stored value; submitted values are compared against it.
Value string `json:"value"`
// Label is the option's label (presumably the text shown to users — confirm with the UI).
Label string `json:"label"`
}
// UserAttributeValidation holds the optional validation rules of an attribute.
// A nil field means the corresponding rule is not set.
type UserAttributeValidation struct {
// MinLength and MaxLength bound the length of the value.
MinLength *int `json:"min_length,omitempty"`
MaxLength *int `json:"max_length,omitempty"`
// Min and Max bound the numeric value of number attributes.
Min *int `json:"min,omitempty"`
Max *int `json:"max,omitempty"`
// Pattern is a regular expression the value must match.
Pattern *string `json:"pattern,omitempty"`
// Message, when non-empty, replaces the default error text for a pattern mismatch.
Message *string `json:"message,omitempty"`
}
// UserAttributeDefinition describes a custom user attribute: its identity,
// type, allowed options and validation rules.
type UserAttributeDefinition struct {
ID int64
// Key is the attribute's key; creating a definition with an existing key is rejected.
Key string
// Name is the attribute's name; it is used in validation error messages.
Name string
Description string
Type UserAttributeType
// Options lists the allowed choices for select and multi_select attributes.
Options []UserAttributeOption
// Required makes an empty value invalid.
Required bool
Validation UserAttributeValidation
Placeholder string
// DisplayOrder is the position set through the reorder operation.
DisplayOrder int
// Enabled controls whether the definition is included when only enabled definitions are listed.
Enabled bool
CreatedAt time.Time
UpdatedAt time.Time
}
// UserAttributeValue is the value one user has for one attribute.
type UserAttributeValue struct {
ID int64
// UserID identifies the user the value belongs to.
UserID int64
// AttributeID identifies the attribute definition the value is for.
AttributeID int64
// Value is the stored value as a string.
Value string
CreatedAt time.Time
UpdatedAt time.Time
}
// CreateAttributeDefinitionInput carries the fields for creating a new
// attribute definition. Each field is copied to the UserAttributeDefinition
// field of the same name.
type CreateAttributeDefinitionInput struct {
// Key must not already exist.
Key string
Name string
Description string
// Type must be one of the supported attribute types.
Type UserAttributeType
Options []UserAttributeOption
Required bool
Validation UserAttributeValidation
Placeholder string
Enabled bool
}
// UpdateAttributeDefinitionInput carries a partial update of an attribute
// definition. Nil fields leave the stored value unchanged; non-nil fields
// overwrite it.
type UpdateAttributeDefinitionInput struct {
Name *string
Description *string
// Type, when set, must be one of the supported attribute types.
Type *UserAttributeType
Options *[]UserAttributeOption
Required *bool
Validation *UserAttributeValidation
Placeholder *string
Enabled *bool
}
// UpdateUserAttributeInput is a single attribute value to write for a user.
type UpdateUserAttributeInput struct {
// AttributeID identifies the attribute definition the value is for.
AttributeID int64
// Value is the new value as a string.
Value string
}
// UserAttributeDefinitionRepository is the persistence interface for attribute definitions.
type UserAttributeDefinitionRepository interface {
Create(ctx context.Context, def *UserAttributeDefinition) error
GetByID(ctx context.Context, id int64) (*UserAttributeDefinition, error)
GetByKey(ctx context.Context, key string) (*UserAttributeDefinition, error)
Update(ctx context.Context, def *UserAttributeDefinition) error
Delete(ctx context.Context, id int64) error
// List returns the definitions; enabledOnly restricts the result to enabled ones.
List(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error)
// UpdateDisplayOrders applies new display orders; orders presumably maps
// definition ID to display order — verify against the implementation.
UpdateDisplayOrders(ctx context.Context, orders map[int64]int) error
// ExistsByKey reports whether a definition with the given key exists.
ExistsByKey(ctx context.Context, key string) (bool, error)
}
// UserAttributeValueRepository is the persistence interface for user attribute values.
type UserAttributeValueRepository interface {
// GetByUserID returns the attribute values of one user.
GetByUserID(ctx context.Context, userID int64) ([]UserAttributeValue, error)
// GetByUserIDs returns the attribute values of several users.
GetByUserIDs(ctx context.Context, userIDs []int64) ([]UserAttributeValue, error)
// UpsertBatch writes the given values for one user.
UpsertBatch(ctx context.Context, userID int64, values []UpdateUserAttributeInput) error
// DeleteByAttributeID removes the values stored for one attribute.
DeleteByAttributeID(ctx context.Context, attributeID int64) error
// DeleteByUserID removes the values stored for one user.
DeleteByUserID(ctx context.Context, userID int64) error
}

View File

@@ -0,0 +1,295 @@
package service
import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
)
// UserAttributeService manages custom user attribute definitions and the
// values users have for them.
type UserAttributeService struct {
// defRepo stores attribute definitions.
defRepo UserAttributeDefinitionRepository
// valueRepo stores per-user attribute values.
valueRepo UserAttributeValueRepository
}
// NewUserAttributeService builds a UserAttributeService on top of the given
// definition and value repositories.
func NewUserAttributeService(
	defRepo UserAttributeDefinitionRepository,
	valueRepo UserAttributeValueRepository,
) *UserAttributeService {
	svc := new(UserAttributeService)
	svc.defRepo = defRepo
	svc.valueRepo = valueRepo
	return svc
}
// CreateDefinition stores a new attribute definition built from input.
// It returns ErrInvalidAttributeType for an unsupported type and
// ErrAttributeKeyExists when the key is already taken; repository failures are
// wrapped.
func (s *UserAttributeService) CreateDefinition(ctx context.Context, input CreateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
	if supported := isValidAttributeType(input.Type); !supported {
		return nil, ErrInvalidAttributeType
	}

	// Reject duplicate keys before attempting the insert.
	switch taken, err := s.defRepo.ExistsByKey(ctx, input.Key); {
	case err != nil:
		return nil, fmt.Errorf("check key exists: %w", err)
	case taken:
		return nil, ErrAttributeKeyExists
	}

	created := new(UserAttributeDefinition)
	created.Key = input.Key
	created.Name = input.Name
	created.Description = input.Description
	created.Type = input.Type
	created.Options = input.Options
	created.Required = input.Required
	created.Validation = input.Validation
	created.Placeholder = input.Placeholder
	created.Enabled = input.Enabled

	if err := s.defRepo.Create(ctx, created); err != nil {
		return nil, fmt.Errorf("create definition: %w", err)
	}
	return created, nil
}
// GetDefinition looks up an attribute definition by its ID.
func (s *UserAttributeService) GetDefinition(ctx context.Context, id int64) (*UserAttributeDefinition, error) {
	def, err := s.defRepo.GetByID(ctx, id)
	return def, err
}
// ListDefinitions returns the attribute definitions; with enabledOnly set,
// only enabled definitions are requested from the repository.
func (s *UserAttributeService) ListDefinitions(ctx context.Context, enabledOnly bool) ([]UserAttributeDefinition, error) {
	defs, err := s.defRepo.List(ctx, enabledOnly)
	return defs, err
}
// UpdateDefinition applies a partial update to the definition with the given
// ID. Only the non-nil fields of input are written. It returns
// ErrInvalidAttributeType when a new type is given but unsupported; a failed
// save is wrapped, a failed lookup is returned as-is.
func (s *UserAttributeService) UpdateDefinition(ctx context.Context, id int64, input UpdateAttributeDefinitionInput) (*UserAttributeDefinition, error) {
	current, err := s.defRepo.GetByID(ctx, id)
	if err != nil {
		return nil, err
	}

	if name := input.Name; name != nil {
		current.Name = *name
	}
	if description := input.Description; description != nil {
		current.Description = *description
	}
	if newType := input.Type; newType != nil {
		if !isValidAttributeType(*newType) {
			return nil, ErrInvalidAttributeType
		}
		current.Type = *newType
	}
	if options := input.Options; options != nil {
		current.Options = *options
	}
	if required := input.Required; required != nil {
		current.Required = *required
	}
	if rules := input.Validation; rules != nil {
		current.Validation = *rules
	}
	if placeholder := input.Placeholder; placeholder != nil {
		current.Placeholder = *placeholder
	}
	if enabled := input.Enabled; enabled != nil {
		current.Enabled = *enabled
	}

	if saveErr := s.defRepo.Update(ctx, current); saveErr != nil {
		return nil, fmt.Errorf("update definition: %w", saveErr)
	}
	return current, nil
}
// DeleteDefinition removes an attribute definition together with every value
// stored for it. The values are deleted first, then the definition itself.
// A definition that cannot be loaded aborts the operation with the lookup error.
func (s *UserAttributeService) DeleteDefinition(ctx context.Context, id int64) error {
	// Make sure the definition exists before touching any values.
	if _, lookupErr := s.defRepo.GetByID(ctx, id); lookupErr != nil {
		return lookupErr
	}

	valuesErr := s.valueRepo.DeleteByAttributeID(ctx, id)
	if valuesErr != nil {
		return fmt.Errorf("delete values: %w", valuesErr)
	}

	defErr := s.defRepo.Delete(ctx, id)
	if defErr != nil {
		return fmt.Errorf("delete definition: %w", defErr)
	}
	return nil
}
// ReorderDefinitions passes the new display orders for several definitions on
// to the definition repository.
func (s *UserAttributeService) ReorderDefinitions(ctx context.Context, orders map[int64]int) error {
	err := s.defRepo.UpdateDisplayOrders(ctx, orders)
	return err
}
// GetUserAttributes returns every attribute value stored for the given user.
func (s *UserAttributeService) GetUserAttributes(ctx context.Context, userID int64) ([]UserAttributeValue, error) {
	values, err := s.valueRepo.GetByUserID(ctx, userID)
	return values, err
}
// GetBatchUserAttributes loads the attribute values of several users at once.
// The result maps userID -> attributeID -> value; users without any stored
// value have no entry.
func (s *UserAttributeService) GetBatchUserAttributes(ctx context.Context, userIDs []int64) (map[int64]map[int64]string, error) {
	rows, err := s.valueRepo.GetByUserIDs(ctx, userIDs)
	if err != nil {
		return nil, err
	}

	byUser := make(map[int64]map[int64]string)
	for i := range rows {
		row := &rows[i]
		attrs, seen := byUser[row.UserID]
		if !seen {
			attrs = make(map[int64]string)
			byUser[row.UserID] = attrs
		}
		attrs[row.AttributeID] = row.Value
	}
	return byUser, nil
}
// UpdateUserAttributes validates and then writes a batch of attribute values
// for one user. Every input must refer to an enabled definition
// (ErrAttributeDefinitionNotFound otherwise) and pass that definition's
// validation; nothing is written unless all inputs are valid.
func (s *UserAttributeService) UpdateUserAttributes(ctx context.Context, userID int64, inputs []UpdateUserAttributeInput) error {
	enabled, err := s.defRepo.List(ctx, true)
	if err != nil {
		return fmt.Errorf("list definitions: %w", err)
	}

	// Index the enabled definitions by ID for the per-input lookups below.
	byID := make(map[int64]*UserAttributeDefinition, len(enabled))
	for idx := range enabled {
		def := &enabled[idx]
		byID[def.ID] = def
	}

	for _, in := range inputs {
		def, known := byID[in.AttributeID]
		if !known {
			return ErrAttributeDefinitionNotFound
		}
		if invalid := s.validateValue(def, in.Value); invalid != nil {
			return invalid
		}
	}

	return s.valueRepo.UpsertBatch(ctx, userID, inputs)
}
// validateValue checks value against the rules of its attribute definition.
//
// An empty value is valid unless the attribute is required. For non-empty
// values the following rules apply, in order:
//   - MinLength / MaxLength, measured in characters (Unicode code points);
//   - number attributes must parse as an integer and respect Min / Max;
//   - Pattern, when set, must match the value;
//   - select attributes must equal one of the definition's option values;
//   - multi_select attributes are a JSON string array (or, as a fallback, a
//     comma-separated list) whose items must all be option values.
//
// It returns nil for a valid value and an ATTRIBUTE_VALIDATION_FAILED
// bad-request error describing the first violated rule otherwise.
func (s *UserAttributeService) validateValue(def *UserAttributeDefinition, value string) error {
	if value == "" {
		if def.Required {
			return validationError(fmt.Sprintf("%s is required", def.Name))
		}
		// Empty optional values skip all further validation.
		return nil
	}

	rules := def.Validation

	// Count code points rather than bytes so that multi-byte text (for
	// example CJK) is measured in characters, as the messages promise.
	length := len([]rune(value))
	if rules.MinLength != nil && length < *rules.MinLength {
		return validationError(fmt.Sprintf("%s must be at least %d characters", def.Name, *rules.MinLength))
	}
	if rules.MaxLength != nil && length > *rules.MaxLength {
		return validationError(fmt.Sprintf("%s must be at most %d characters", def.Name, *rules.MaxLength))
	}

	if def.Type == AttributeTypeNumber {
		num, err := strconv.Atoi(value)
		if err != nil {
			return validationError(fmt.Sprintf("%s must be a number", def.Name))
		}
		if rules.Min != nil && num < *rules.Min {
			return validationError(fmt.Sprintf("%s must be at least %d", def.Name, *rules.Min))
		}
		if rules.Max != nil && num > *rules.Max {
			return validationError(fmt.Sprintf("%s must be at most %d", def.Name, *rules.Max))
		}
	}

	if rules.Pattern != nil && *rules.Pattern != "" {
		// A pattern that does not compile is skipped rather than reported, so
		// a broken definition does not reject every value.
		re, err := regexp.Compile(*rules.Pattern)
		if err == nil && !re.MatchString(value) {
			msg := def.Name + " format is invalid"
			if rules.Message != nil && *rules.Message != "" {
				msg = *rules.Message
			}
			return validationError(msg)
		}
	}

	// hasOption reports whether candidate is one of the definition's option values.
	hasOption := func(candidate string) bool {
		for _, opt := range def.Options {
			if opt.Value == candidate {
				return true
			}
		}
		return false
	}

	switch def.Type {
	case AttributeTypeSelect:
		if !hasOption(value) {
			return validationError(fmt.Sprintf("%s: invalid option", def.Name))
		}
	case AttributeTypeMultiSelect:
		// Stored as a JSON array; fall back to a comma-separated list when
		// the value is not valid JSON.
		var selected []string
		if err := json.Unmarshal([]byte(value), &selected); err != nil {
			selected = strings.Split(value, ",")
		}
		for _, item := range selected {
			item = strings.TrimSpace(item)
			if !hasOption(item) {
				return validationError(fmt.Sprintf("%s: invalid option %s", def.Name, item))
			}
		}
	}

	return nil
}
// validationError builds a bad-request error with code
// ATTRIBUTE_VALIDATION_FAILED and the given message.
func validationError(msg string) error {
	const code = "ATTRIBUTE_VALIDATION_FAILED"
	return infraerrors.BadRequest(code, msg)
}
// isValidAttributeType reports whether t is one of the supported attribute types.
func isValidAttributeType(t UserAttributeType) bool {
	supported := [...]UserAttributeType{
		AttributeTypeText,
		AttributeTypeTextarea,
		AttributeTypeNumber,
		AttributeTypeEmail,
		AttributeTypeURL,
		AttributeTypeDate,
		AttributeTypeSelect,
		AttributeTypeMultiSelect,
	}
	for _, candidate := range supported {
		if candidate == t {
			return true
		}
	}
	return false
}

View File

@@ -4,7 +4,7 @@ import (
"context"
"fmt"
infraerrors "github.com/Wei-Shaw/sub2api/internal/infrastructure/errors"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
@@ -14,6 +14,14 @@ var (
ErrInsufficientPerms = infraerrors.Forbidden("INSUFFICIENT_PERMISSIONS", "insufficient permissions")
)
// UserListFilters bundles the filter options accepted when listing users.
// NOTE(review): empty fields presumably mean "no filter" — confirm against the
// repository implementation.
type UserListFilters struct {
Status string // Filter by user status
Role string // Filter by user role
Search string // Search term matched against email and username
Attributes map[int64]string // Custom attribute filters: attribute ID -> value
}
type UserRepository interface {
Create(ctx context.Context, user *User) error
GetByID(ctx context.Context, id int64) (*User, error)
@@ -23,7 +31,7 @@ type UserRepository interface {
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, status, role, search string) ([]User, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error)
UpdateBalance(ctx context.Context, id int64, amount float64) error
DeductBalance(ctx context.Context, id int64, amount float64) error
@@ -36,7 +44,6 @@ type UserRepository interface {
// UpdateProfileRequest carries the profile fields a caller may change.
// Fields are pointers so that a nil field can be told apart from an explicit
// value; UpdateProfile only applies fields that are non-nil.
type UpdateProfileRequest struct {
Email *string `json:"email"`
Username *string `json:"username"`
Wechat *string `json:"wechat"`
Concurrency *int `json:"concurrency"`
}
@@ -100,10 +107,6 @@ func (s *UserService) UpdateProfile(ctx context.Context, userID int64, req Updat
user.Username = *req.Username
}
if req.Wechat != nil {
user.Wechat = *req.Wechat
}
if req.Concurrency != nil {
user.Concurrency = *req.Concurrency
}

View File

@@ -73,6 +73,15 @@ func ProvideDeferredService(accountRepo AccountRepository, timingWheel *TimingWh
return svc
}
// ProvideConcurrencyService builds the ConcurrencyService for dependency
// injection. When a configuration is available it also starts the slot cleanup
// worker with the configured cleanup interval.
func ProvideConcurrencyService(cache ConcurrencyCache, accountRepo AccountRepository, cfg *config.Config) *ConcurrencyService {
	service := NewConcurrencyService(cache)
	if cfg == nil {
		return service
	}
	interval := cfg.Gateway.Scheduling.SlotCleanupInterval
	service.StartSlotCleanupWorker(accountRepo, interval)
	return service
}
// ProviderSet is the Wire provider set for all services
var ProviderSet = wire.NewSet(
// Core services
@@ -94,6 +103,7 @@ var ProviderSet = wire.NewSet(
NewOAuthService,
NewOpenAIOAuthService,
NewGeminiOAuthService,
NewGeminiQuotaService,
NewAntigravityOAuthService,
NewGeminiTokenProvider,
NewGeminiMessagesCompatService,
@@ -107,7 +117,7 @@ var ProviderSet = wire.NewSet(
ProvideEmailQueueService,
NewTurnstileService,
NewSubscriptionService,
NewConcurrencyService,
ProvideConcurrencyService,
NewIdentityService,
NewCRSSyncService,
ProvideUpdateService,
@@ -115,4 +125,5 @@ var ProviderSet = wire.NewSet(
ProvideTimingWheelService,
ProvideDeferredService,
ProvideAntigravityQuotaRefresher,
NewUserAttributeService,
)